python - 为什么列转换器给出压缩的稀疏行格式?
问题描述
当我在此数据集上使用 OneHotEncoder 和列转换器时,它会产生压缩的稀疏行格式。编码后,我想使用 train_test_split 拆分数据,但显示此错误:
Singleton array array(<32561x105 sparse matrix of type '<class 'numpy.float64'>'
with 394963 stored elements in Compressed Sparse Row format>,
dtype=object) cannot be considered a valid collection.
首先我处理这样的缺失值
from sklearn.impute import SimpleImputer
imputer_nominal = SimpleImputer(missing_values = np.nan, strategy = 'most_frequent')
imputer_numerical = SimpleImputer(missing_values = np.nan, strategy = 'mean')
imputer_nominal.fit(x[:,[1,3,5,6,7,8,9,13]])
x[:,[1,3,5,6,7,8,9,13]] = imputer_nominal.transform(x[:,[1,3,5,6,7,8,9,13]])
imputer_numerical.fit(x[:,[0,2,4,10,11,12]])
x[:,[0,2,4,10,11,12]] = imputer_numerical.transform(x[:,[0,2,4,10,11,12]])
然后我对数据进行编码:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
ct = ColumnTransformer(transformers = [('encoder', OneHotEncoder(), [1,3,5,6,7,8,9,13])], remainder = 'passthrough')
x = np.array(ct.fit_transform(x))
当我输出 numpy 数组“x”时,它看起来像这样,这是一种压缩的稀疏行格式
(0, 6) 1.0
(0, 17) 1.0
(0, 28) 1.0
(0, 31) 1.0
(0, 46) 1.0
(0, 55) 1.0
(0, 57) 1.0
(0, 96) 1.0
(0, 99) 39.0
在此之后,我尝试拆分数据并显示上述错误。我之前使用过 uesd 列转换器和 OneHotEncoder,但我不知道这个出了什么问题。另外,我不在此代码中的任何地方使用 scipy 库。
解决方案
我的 ColumnTransformer 也创建了压缩的稀疏数据。我在函数中设置了 sparse_threshold=0 。其默认值为 0.3。这似乎是一个新的属性/值,因为我看过 ColumnTransformer 的视频不需要它并创建相同的结果。如果有帮助,这是我的代码。
原始代码:
#This data is for a car sales CSV
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
categorical_features = ["Make", "Colour", "Doors"]
one_hot = OneHotEncoder()
transformer = ColumnTransformer([('one_hot',
one_hot,
categorical_features)],
remainder ='passthrough',
sparse_threshold=0)
transformed_X = transformer.fit_transform(X)
transformed_X[:1], pd.DataFrame(transformed_X).head()
原始输出:
(<1x16 sparse matrix of type '<class 'numpy.float64'>'
with 4 stored elements in Compressed Sparse Row format>,
0
0 (0, 1)\t1.0\n (0, 9)\t1.0\n (0, 12)\t1.0\n...
1 (0, 0)\t1.0\n (0, 6)\t1.0\n (0, 13)\t1.0\n...
2 (0, 1)\t1.0\n (0, 9)\t1.0\n (0, 12)\t1.0\n...
3 (0, 3)\t1.0\n (0, 9)\t1.0\n (0, 12)\t1.0\n...
4 (0, 2)\t1.0\n (0, 6)\t1.0\n (0, 11)\t1.0\n...)
带有 Sparse_thresh 的代码:
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
categorical_features = ["Make", "Colour", "Doors"]
one_hot = OneHotEncoder()
transformer = ColumnTransformer([('one_hot',
one_hot,
categorical_features)],
remainder ='passthrough',
sparse_threshold=0)
transformed_X = transformer.fit_transform(X)
#put in data frame for viewing
transformed_X[:1], pd.DataFrame(transformed_X).head()
sparse_threshold=0 的输出代码(第 15 列是里程表值):
(array([[0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00,
0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 1.0000e+00,
0.0000e+00, 0.0000e+00, 1.0000e+00, 0.0000e+00, 0.0000e+00,
3.5431e+04]]),
0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 \
0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0
1 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0
2 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0
3 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 1.0 0.0 0.0
4 0.0 0.0 1.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0
15
0 35431.0
1 192714.0
2 84714.0
3 154365.0
4 181577.0 )
推荐阅读
- android - 如何避免与 2 个可观察订阅者的竞争条件?
- javascript - onClick 的箭头功能正在将所有实例的参数更新为最新使用情况
- r - 如何使用列名列表进行分组和汇总?
- excel - 使用 Excel VBA,如何从工作簿中自动生成图表标题列表?
- android - 为什么我不能从变量中删除 div 标签?
- reactjs - 为什么 ctx.strokeRect() 的行为与顺序 ctx.rect() ctx.stroke() 调用不同
- google-analytics - 如果我已经在使用 Google 跟踪代码管理器来引入 Google Analytics,我是否需要全局站点代码 (gtag.js) 才能与 Adwords 一起使用?
- fortran - Fortran 中的指针转换
- kubernetes - 我可以将容器的环境变量设置为集群中服务的 clusterIP 的值吗?
- docker - 如何从 Dockerfile 多行回显 json