首页 > 解决方案 > Scikit learn:忘记之前的训练数据

问题描述

在 scikit learn 我有一个模型(在我的例子中是一个线性模型)

clf = linear_model.LinearRegression()

我可以用一些数据训练这个模型

clf.fit(x1,y1)

但如果我再次打电话fit,它将继续训练模型。

clf.fit(x2,y2)

现在 clf 是一个用 (x1,y1) 和 (x2,y2) 训练的模型

如果我想从 0 开始训练,我可以通过重新定义重新创建模型clf

clf = linear_model.LinearRegression()
clf.fit(x1,y1)
# save the model
# ...
clf = linear_model.LinearRegression()
clf.fit(x2,y2)

但是我不想再次定义 clf :

基本上之前选择了回归量的类型,例如:

if params.linear_algorithm == 'least_squares':
    clf = linear_model.LinearRegression()
elif params.linear_algorithm == 'ridge':
    clf = linear_model.Ridge()
elif params.linear_algorithm == 'lasso':
    clf = linear_model.Lasso()

所以我不想在我的 train 函数中重新定义clf所有条件块,而是我只想clf从以前的训练中清除它并重用它来训练另一组数据。

clf 是否有一种方法来清理到目前为止所学的内容,所以当我调用 clf.fit(x2,y2) 时只对这些数据进行训练?

编辑:你们是对的,培训每次都被覆盖。

我的问题是我将模型保存在字典中,它只是引用 clf,所以每次重新训练 clf 时,所有以前的保存都会改变。

每次重新定义 clf 都会创建一个新对象,因此每个保存点现在都是不同的模型

例子

for i in range(3):
   # get the x and y
   # ...
   clf.fit(x,y)
   model[i] = clf

知道如何每次保存不同的模型而不是将所有模型 [i] 指向同一个 clf 吗?

标签: pythonscikit-learn

解决方案


你的假设是错误的。根据Scikit-Learn 文档

多次调用 fit() 将覆盖之前任何 fit() 学到的内容。

因此,您可以安全地使用您的代码,它将实现您的需求。


推荐阅读