python - 如何在没有滑雪套件学习的情况下为 K 折交叉验证创建训练集?
问题描述
我有一个包含 95 行和 9 列的数据集,并且想要进行 5 折交叉验证。在训练中,前 8 列(特征)用于预测第 9 列。我的测试集是正确的,但是当它应该只有 8 列时,我的 x 训练集的大小为 (4,19,9),而当它应该有 19 行时,我的 y 训练集的大小为 (4,9)。我是否错误地索引了子数组?
kdata = data[0:95,:] # Need total rows to be divisible by 5, so ignore last 2 rows
np.random.shuffle(kdata) # Shuffle all rows
folds = np.array_split(kdata, k) # each fold is 19 rows x 9 columns
for i in range (k-1):
xtest = folds[i][:,0:7] # Set ith fold to be test
ytest = folds[i][:,8]
new_folds = np.delete(folds,i,0)
xtrain = new_folds[:][:][0:7] # training set is all folds, all rows x 8 cols
ytrain = new_folds[:][:][8] # training y is all folds, all rows x 1 col
解决方案
欢迎来到堆栈溢出。
创建新折叠后,您需要使用np.row_stack()
.
另外,我认为您在 Python 或 Numpy 中对数组进行了错误的切片,[inclusive:exclusive]
因此切片行为是,当您指定切片时,[0:7]
您只取 7 列,而不是您想要的 8 个特征列。
同样,如果您在 for 循环中指定 5 折,它应该是range(k)
which 给你[0,1,2,3,4]
而不是range(k-1)
which 只给你[0,1,2,3]
.
修改后的代码如下:
folds = np.array_split(kdata, k) # each fold is 19 rows x 9 columns
np.random.shuffle(kdata) # Shuffle all rows
folds = np.array_split(kdata, k)
for i in range (k):
xtest = folds[i][:,:8] # Set ith fold to be test
ytest = folds[i][:,8]
new_folds = np.row_stack(np.delete(folds,i,0))
xtrain = new_folds[:, :8]
ytrain = new_folds[:,8]
# some print functions to help you debug
print(f'Fold {i}')
print(f'xtest shape : {xtest.shape}')
print(f'ytest shape : {ytest.shape}')
print(f'xtrain shape : {xtrain.shape}')
print(f'ytrain shape : {ytrain.shape}\n')
这将为您打印出折叠和所需形状的训练和测试集:
Fold 0
xtest shape : (19, 8)
ytest shape : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
Fold 1
xtest shape : (19, 8)
ytest shape : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
Fold 2
xtest shape : (19, 8)
ytest shape : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
Fold 3
xtest shape : (19, 8)
ytest shape : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
Fold 4
xtest shape : (19, 8)
ytest shape : (19,)
xtrain shape : (76, 8)
ytrain shape : (76,)
推荐阅读
- r - 如何正确访问嵌套 R 列表中的项目?
- python - 如何使用 Python 从带有身份验证密钥的 URL 下载 JSON 文件
- swift - 如何解决firebase的异步问题
- javascript - 完成Ajax后,变量Js无法在href标签中呈现
- php - 获取 Laravel 中的所有 ENV 变量
- mysql - MySQL 从列中选择和删除 JSON 字符
- javascript - Vue错误:在严格模式代码中,函数只能在顶层或块内声明
- java - 安卓 NSD。尝试解决时出现错误
- makefile - 如何从内核源代码树中编译工具和示例?(例如 bpftool、bpf 样本)
- ios - 同时使用 2 个 SDK(CALL KIT 和 VoIP SDK)管理来电?