首页 > 解决方案 > Python ValueError: n_splits=3 不能大于每个类的成员数

问题描述

我正在做人脸识别项目,我有两个人,每个人有 2 张脸

1. personA
    image1.jpg
    image2.jpg


2. personB
    image1.jpg
    image2.jpg

我正在尝试在上述数据集的人脸嵌入上训练模型,如下所示:

params = {"C": [0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0], "gamma": [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1]}
model = GridSearchCV(SVC(kernel="rbf", gamma="auto", probability=True), params, cv=3, n_jobs=-1)
model.fit(data["embeddings"], labels)

data["embeddings"]和的长度在labels哪里4data["embeddings']包含personA、personB的人脸嵌入的ndarray

data['embeddings'] = [
                         [0.02331057, -0.01995077, ..], 
                         [-0.00034041,  0.02753334, ..], 
                         [0.02454563, -0.03797123, ...], 
                         [0.10561685, -0.08444008, ...]
                     ]

labels = [0 0 1 1]

但我在以下错误model.fit(data["embeddings"], labels)

ValueError: n_splits=3 cannot be greater than the number of members in each class.

我无法理解这个错误。谁能解释一下这个问题,我该如何解决?

标签: pythonmachine-learningscikit-learncross-validation

解决方案


仔细阅读,错误信息清晰且不言自明;它只是告诉您,由于您的每个班级总共只有两 (2) 个样本,因此您不能进行 3 折交叉验证。这将需要您的每个班级至少3 个样本。

我想它应该可以正常工作而cv=2不会引发任何错误,但是您的整个方法(即只有 4 个样本的数据集)似乎非常值得怀疑。


推荐阅读