python - Scikit-Learn OneHotEncoder 无法正常工作?
问题描述
我有一个如下所示的数据列表:
[['ocak' 2017]
['subat' 2017]
['mart' 2017]
['nisan' 2017]
['mayis' 2017]
['haziran' 2017]
['temuz' 2017]
['agustos' 2017]
['eylul' 2017]
['ekim' 2017]
['kasim' 2017]
['aralik' 2017]
['ocak' 2018]
['subat' 2018]
['mart' 2018]
['nisan' 2018]]
我想使用 OneHotEncoder 对列表的字符串部分('subat'、'mart' 等)进行编码,以便在我的回归模型中使用它。
我使用的代码是这样的:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder
ct = ColumnTransformer(transformers=[('encoder', OneHotEncoder(), [0])], remainder='passthrough')
X = np.array(ct.fit_transform(X))
但我的输出是这样的:
(0, 9) 1.0
(0, 13) 2017.0
(1, 10) 1.0
(1, 13) 2017.0
(2, 6) 1.0
(2, 13) 2017.0
(3, 8) 1.0
(3, 13) 2017.0
(4, 7) 1.0
(4, 13) 2017.0
(5, 4) 1.0
(5, 13) 2017.0
(6, 12) 1.0
(6, 13) 2017.0
(7, 0) 1.0
(7, 13) 2017.0
(8, 3) 1.0
这是 train_test_split 类不能接受的..
我需要这样的输出
[1.0 0.0 0.0 2017]
我怎样才能让它像上面一样给我输出。还是我的代码或数据集有问题?
解决方案
OneHotEncoder
默认情况下返回一个稀疏矩阵,所以当你用 包装返回值时np.array
,你会得到一个不想要的表示。你有两个选择:
- 传给ie
sparse=False
,OneHotEncoder
ct = ColumnTransformer(transformers=[('encoder', OneHotEncoder(sparse=False), [0])], remainder='passthrough')
- 或将结果转换为 numpy 数组
toarray
,即
X = ct.fit_transform(X).toarray()
采用第二种方式(我将其包裹起来pd.DataFrame
以便于检查结果):
>>> pd.DataFrame(X)
0 1 2 3 4 5 6 7 8 9 10 11 12
0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 2017.0
1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 2017.0
2 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2017.0
3 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 2017.0
4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 2017.0
5 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
6 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 2017.0
7 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
8 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
9 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
10 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
11 0.0 1.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 2017.0
12 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 2018.0
13 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 2018.0
14 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 0.0 0.0 2018.0
15 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1.0 0.0 0.0 0.0 2018.0
推荐阅读
- bash - 如何通过循环遍历 TSV 将选项分配给下标?
- google-cloud-platform - 创建实例模板时的 Terraform GCP,获取源图像的相对路径时出错
- python-3.x - 如何更改python环境
- javascript - 避免 keyup 输入 ajax 调用的重复结果;杀死以前的ajax调用
- java - 外部化属性文件 - Widlfly & Struts 1.x
- oracle - 将 SQL 脚本作为 Windows 批处理文件运行
- c# - 拆分多个 linq include()
- android - 如何使用 Kotlin 从非活动类中查找 ListView?
- php - 如何处理内联 SVG 的重复 ID?
- angularjs - angularjs 1.7 选择下拉菜单显示了一个额外的选项值