How do I use the list of categories an example belongs to as a feature for a classification problem?

Problem description

One of the features looks like this:

1       170,169,205,174,173,246,247,249,380,377,383,38...
2       448,104,239,277,276,99,154,155,76,412,139,333,...
3       268,422,419,124,1,17,431,343,341,435,130,331,5...
4       50,53,449,106,279,420,161,74,123,364,231,18,23...
5       170,169,205,174,173,246,247,249,380,377,383,38...

It tells us which categories an example belongs to. How should I use it when solving a classification problem?

I tried using dummy variables,

df=df.join(features['cat'].str.get_dummies(',').add_prefix('contains_'))

but we don't know about other categories that are not mentioned in the training set, so I don't know how to preprocess all the objects consistently.

Tags: pandas, machine-learning, scikit-learn

Solution


That's interesting. I didn't know about str.get_dummies, but maybe I can help you with the rest.

You basically have two problems:

  1. The set of categories you get later contains categories that were unknown when the model was trained. You have to get rid of those later.

  2. The set of categories you get later does not contain all categories. You have to make sure dummies are generated for the missing ones as well.

Problem 1: filtering out unknown/unwanted categories

The first problem is easy to solve:

# create a set of all categories you want to allow;
# either define it as a fixed set, or extract it from your
# column like this (the output of the map is actually irrelevant);
# the result will be in valid_categories
valid_categories= set()
df['categories'].str.split(',').map(valid_categories.update)

# now if you want to normalize your data before the
# dummy encoding, you can cleanse it by splitting,
# intersecting with the valid set, and joining it
# back into a string that str.get_dummies can work on
df['categories'].str.split(',').map(lambda l: valid_categories.intersection(l)).str.join(',')
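To see the filter in action, here is a minimal sketch; the toy frame and the unknown category '9' are made up for illustration (sorting the intersection only serves to make the joined string deterministic):

```python
import pandas as pd

# toy data: category '9' is unknown to the model
df = pd.DataFrame({'categories': ['1,2,9', '2,3']})
valid_categories = {'1', '2', '3'}

# same filter as above; sorting the intersection just
# makes the joined string deterministic
cleaned = (df['categories']
           .str.split(',')
           .map(lambda l: sorted(valid_categories.intersection(l)))
           .str.join(','))
print(cleaned.tolist())  # ['1,2', '2,3']
```

The unknown '9' is gone from the first row, while the valid categories survive untouched.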

Problem 2: generating dummies for all known categories

The second problem can be solved by appending a dummy row that contains all categories, e.g. with df.append, before you call get_dummies, and removing it again afterwards:

# e.g. you can do it like this:
# get a new index value to be able to
# remove the row later (this only works
# if you have a numeric index)
dummy_index= df.index.max()+1

# assign all valid categories to the dummy row
df.loc[dummy_index]= {'id':999, 'categories': ','.join(valid_categories)}

# now do the processing steps mentioned in
# the section above, then create the dummies;
# after that, remove the dummy line again
df.drop(labels=[dummy_index], inplace=True)
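The dummy-row trick can be sketched end to end like this; the toy frame, the category 'D' (known but absent from the data), and the id 999 are assumptions for illustration:

```python
import pandas as pd

df = pd.DataFrame({'id': [1, 2], 'categories': ['A,B', 'B,C']})
# 'D' is a known category that happens not to occur in this data
valid_categories = {'A', 'B', 'C', 'D'}

# append a dummy row containing every valid category
dummy_index = df.index.max() + 1
df.loc[dummy_index] = {'id': 999, 'categories': ','.join(valid_categories)}

# encode, then drop the dummy row again
dummies = df['categories'].str.get_dummies(',').drop(labels=[dummy_index])
print(sorted(dummies.columns))  # ['A', 'B', 'C', 'D']
```

Even though 'D' never occurs in the real rows, it still gets a (all-zero) column, so the feature matrix has the same shape as at training time.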

Example:

import io
import pandas as pd

raw= """id      categories
1       170,169,205,174,173,246,247
2       448,104,239,277,276,99,154
3       268,422,419,124,1,17,431,343
4       50,53,449,106,279,420,161,74
5       170,169,205,174,173,246,247"""
df= pd.read_fwf(io.StringIO(raw))

valid_categories= set()
df['categories'].str.split(',').map(valid_categories.update)
# remove 154 and 170 for demonstration purposes
valid_categories.remove('170')
valid_categories.remove('154')

df['categories'].str.split(',').map(lambda l: valid_categories.intersection(l)).str.join(',').str.get_dummies(',')
Out[622]: 
   1  104  106  124  161  169  17  173  174  205  239  246  247  268  276  277  279  343  419  420  422  431  448  449  50  53  74  99
0  0    0    0    0    0    1   0    1    1    1    0    1    1    0    0    0    0    0    0    0    0    0    0    0   0   0   0   0
1  0    1    0    0    0    0   0    0    0    0    1    0    0    0    1    1    0    0    0    0    0    0    1    0   0   0   0   1
2  1    0    0    1    0    0   1    0    0    0    0    0    0    1    0    0    0    1    1    0    1    1    0    0   0   0   0   0
3  0    0    1    0    1    0   0    0    0    0    0    0    0    0    0    0    1    0    0    1    0    0    0    1   1   1   1   0
4  0    0    0    0    0    1   0    1    1    1    0    1    1    0    0    0    0    0    0    0    0    0    0    0   0   0   0   0

You can see that there are no columns for 154 and 170.
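As an aside, since the question is tagged scikit-learn: `MultiLabelBinarizer` can cover both problems at once. Passing `classes=` fixes the output columns up front, and labels outside that set are ignored at transform time (scikit-learn emits a warning). A minimal sketch with made-up labels:

```python
from sklearn.preprocessing import MultiLabelBinarizer

# fix the column set up front; 'X' below is not in it
mlb = MultiLabelBinarizer(classes=['A', 'B', 'C', 'D'])
mlb.fit([])  # classes are given explicitly, so no data is needed to fit

# the unknown label 'X' is ignored (with a UserWarning)
codes = mlb.transform([s.split(',') for s in ['A,B,X', 'B,C']])
print(codes.tolist())  # [[1, 1, 0, 0], [0, 1, 1, 0]]
```

This avoids both the manual intersection filter and the dummy-row trick, at the cost of leaving pandas for the encoding step.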
