首页 > 解决方案 > DecisionTreeClassifier fit 方法错误(scikit learn)

问题描述

尝试使用以下方法训练我的DecisionTreeClassifierfit

from sklearn import tree
import skimage

features = []
labels = []

for i in range(5):
    img = skimage.io.imread("circle" + str(i+1) + ".jpg")
    img = skimage.img_as_float(img)
    features.append(img)
    labels.append(0)

    img = skimage.io.imread("square" + str(i+1) + ".jpg")
    img = skimage.img_as_float(img)
    features.append(img)
    labels.append(1)

clf = tree.DecisionTreeClassifier()
clf = clf.fit(features, labels)

接收错误:

ValueError:使用序列设置数组元素。

标签: pythonpython-3.xmachine-learningscikit-learnscikit-image

解决方案


如果您执行以下操作,您将只使用第一个像素值。

features.append(img[0][0])

尝试这个!

import numpy as np
features.append(np.array(img).flatten())

请检查您要附加的数据的维度,以了解实际发生的情况。

print(np.array(img).flatten().shape)

推荐阅读