python - 多个 scikit 学习管道的奇怪行为
问题描述
我正在使用 sklearn 训练模型,并且我的训练序列需要运行两个不同的特征提取管道。
由于某种原因,每个管道都可以毫无问题地拟合数据,并且当它们按顺序出现时,它们也可以毫无问题地转换数据。
然而,当第二条管道已经安装后调用第一条管道时,第一条管道已被更改,这会导致尺寸不匹配错误。
在下面的代码中,您可以重新创建问题(我已经大大简化了它,实际上我的两个管道使用不同的参数,但这是一个最小可重现的示例)。
from sklearn.pipeline import Pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd
vectorizer = CountVectorizer()
data1 = ['foo bar', 'a foo bar duck', 'goose goose']
data2 = ['foo', 'duck duck swan', 'goose king queen goose']
pipeline1 = Pipeline([('vec', vectorizer),('svd', TruncatedSVD(n_components = 3))]).fit(data1)
print(pipeline1.transform(data1))
# Works fine
pipeline2 = Pipeline([('vec', vectorizer),('svd', TruncatedSVD(n_components = 3))]).fit(data2)
print(pipeline2.transform(data2))
# Works fine
print(pipeline1.transform(data1))
# ValueError: dimension mismatch
显然,“pipeline2”的拟合在某种程度上干扰了“pipeline1”,但我不知道为什么。我希望能够同时使用它们。
解决方案
怎么了 :
正如您vectorizer
首先定义的那样,会发生以下情况:
- 你创造
vectorizer
你适合第一个管道:
- 已安装矢量化器,输出暗淡为 (3,4),例如 3 个元素,4 个单词:foo、bar、duck、goose
- svd 适合有 4 列作为输入
你适合第二个管道:
- 再次安装矢量化器,这次使用 6 个单词(例如列)作为输出:foo、duck、swan、goose、king、queen
- 另一个svd已安装,此处不相关
你回调第一个管道:
- 向量器输出一个 (3,6) 矩阵,使用最后一次拟合的单词,例如第二个管道
- svd 已适合接受 4 列作为输入,引发异常。
如何验证这一点:
vectorizer = CountVectorizer()
data1 = ['foo bar', 'a foo bar duck', 'goose goose']
data2 = ['foo', 'duck duck swan', 'goose king queen goose']
pipeline1 = Pipeline([('vec', vectorizer)]).fit(data1)
print(pipeline1.transform(data1).shape)
(3, 4)
# Works fine
pipeline2 = Pipeline([('vec', vectorizer)]).fit(data2)
print(pipeline2.transform(data2).shape)
(3, 6)
# Works fine
# vectorizer = CountVectorizer()
print(pipeline1.transform(data1).shape)
(3, 6)
如何修复它:
您只需在管道中包含矢量化器的定义,如下所示:
from sklearn.pipeline import Pipeline
from sklearn.decomposition import TruncatedSVD
from sklearn.feature_extraction.text import CountVectorizer
import pandas as pd
data1 = ['foo bar', 'a foo bar duck', 'goose goose']
data2 = ['foo', 'duck duck swan', 'goose king queen goose']
pipeline1 = Pipeline([('vec', CountVectorizer()),('svd', TruncatedSVD(n_components = 3))]).fit(data1)
print(pipeline1.transform(data1))
# Works fine
pipeline2 = Pipeline([('vec', CountVectorizer()),('svd', TruncatedSVD(n_components = 3))]).fit(data2)
print(pipeline2.transform(data2))
# Works fine
print(pipeline1.transform(data1))
推荐阅读
- android - 如何实现 ProgressDialog 如下:开始、完成检查、关闭对话框然后打开新活动?
- ruby-on-rails - 在打开和关闭 HTML 标记之间使用 MathJax 时不起作用
- java - mapStruct 停止映射 DTO 的超类字段的配置是什么?
- python - 在python中使用正则表达式获取多个重复行
- mysql - hibernate.ddl.auto=更新创建 DDL 错误 CommandAcceptanceException
- vb.net-2010 - 如何修复 UPDATE 语句中的语法错误
- python - 使用python-telegram-bot按下开始时如何使用内联键盘发送gif?
- c# - 如何在 ASP.NET MVC 视图中解析 JSON 字符串
- javascript - Firebase CloudFirestore 参考数据类型
- firebase - Fluter-FutureBuilder-未获取AsyncSnapshot数据