python - Feed batches of data to Keras, then train for multiple epochs
Problem description
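I have a dataset of roughly 980 rows (each row has 12 features and one class variable), and I want to run n-fold cross-validation on it. What I currently do is split the dataset into n partitions, store the index of the partition currently assigned as the test set (starting at 0 and iterating through n), and then train on the remaining partitions via train_on_batch(). The problem is that I can't seem to run multiple epochs this way. With fit(), I would need to
- split the dataset into n partitions,
- assign the index of the test partition,
- then build the training dataset by concatenating all partitions except the test one back together.
The third step seems very clumsy, which is why I'm trying it this way instead. Is it possible to feed batches to the model manually and then train for x epochs?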
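For comparison, the third step amounts to a single np.concatenate call over the remaining partitions. A minimal sketch, assuming the data is a 2-D numeric NumPy array with the class label in column 0 (as in the code below); `split_fold` is a hypothetical helper name used only for illustration:

```python
import numpy as np


def split_fold(data, k, test_index):
    """Hypothetical helper: return (train_x, train_y, test_x, test_y) for one CV fold.

    Assumes column 0 holds the numeric class label and the other columns are features.
    """
    parts = np.array_split(data, k)  # k row-wise partitions
    test = parts[test_index]
    # Re-join every partition except the test one with a single concatenate call
    train = np.concatenate([p for i, p in enumerate(parts) if i != test_index])
    return train[:, 1:], train[:, 0], test[:, 1:], test[:, 0]


data = np.arange(30, dtype=float).reshape(10, 3)  # toy data: 10 rows, label + 2 features
train_x, train_y, test_x, test_y = split_fold(data, k=5, test_index=1)
print(train_x.shape, test_x.shape)
```

The concatenation copies the training rows once per fold, which for about 1,000 rows of 13 columns is negligible.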
Current code (relevant parts only):
# Assuming TensorFlow's bundled Keras
import numpy as np
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

def partition(self, chunks):
    # Work on a copy so self.data keeps its original string labels
    temp_data = self.data.copy()
    # Encode the class column: "Normal" -> 0, "Abnormal" -> 1
    temp_data[temp_data == "Normal"] = 0
    temp_data[temp_data == "Abnormal"] = 1
    # Convert to floats for Keras and split row-wise into `chunks` partitions
    partitions = np.array_split(temp_data.astype("float32"), chunks)
    return partitions

def nfold_cv(self, chunks, ep_each):
    parts = self.partition(chunks)
    # Use each partition once as the test set and the other chunks-1 as training data
    for testPartition in range(chunks):
        # Build a fresh model for each fold
        self.model = Sequential()
        self.model.add(Dense(units=self.neurons, input_shape=(12,)))
        for j in range(self.hidden):
            self.model.add(Dense(units=self.neurons, activation=self.hidden_activation))
        self.model.add(Dense(units=1, activation=self.end_activation))
        self.model.compile(loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
        # Train on every partition except the test partition: one train_on_batch call
        # per partition, i.e. a single pass (one epoch); ep_each is not used yet
        for i in range(len(parts)):
            if i != testPartition:
                train_x = parts[i][:, 1:]  # features: columns 1..12
                train_y = parts[i][:, 0]   # label: column 0
                self.history = self.model.train_on_batch(train_x, train_y, reset_metrics=False)
        # The network is trained; now evaluate it on the held-out partition
        test_x = parts[testPartition][:, 1:]
        test_y = parts[testPartition][:, 0]
        predicted = self.model.predict(test_x)
        score = self.model.evaluate(test_x, test_y, verbose=2)
        # score[1] is the accuracy on this fold
        print("CV on test set at " + str(testPartition) + " is " + str(score[1]))
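The loop above calls train_on_batch() once per training partition, so each fold sees only a single pass over its training data and ep_each is never used. One way to get multiple epochs while still feeding batches manually is to wrap the partition loop in an epoch loop. A minimal sketch, assuming `model` is a compiled Keras model (any object with a Keras-style train_on_batch(x, y) works) and each partition has the label in column 0; `train_fold_epochs` and the `RecordingModel` stand-in are hypothetical names used only for illustration:

```python
import random

import numpy as np


def train_fold_epochs(model, parts, test_index, epochs, seed=0):
    """Hypothetical helper: `epochs` passes of train_on_batch over all non-test partitions.

    Assumes `model` has a Keras-style train_on_batch(x, y) method and that each
    partition is a 2-D array with the class label in column 0.
    """
    order = [i for i in range(len(parts)) if i != test_index]
    rng = random.Random(seed)
    for _ in range(epochs):
        rng.shuffle(order)  # visit the training partitions in a different order each epoch
        for i in order:
            part = parts[i]
            model.train_on_batch(part[:, 1:], part[:, 0])  # one gradient update per partition
    return model


class RecordingModel:
    """Stand-in for a compiled Keras model: only records the batch shapes it receives."""

    def __init__(self):
        self.batches = []

    def train_on_batch(self, x, y):
        self.batches.append((x.shape, y.shape))


parts = np.array_split(np.zeros((50, 13), dtype="float32"), 10)  # 10 partitions of 5 rows
recorder = train_fold_epochs(RecordingModel(), parts, test_index=0, epochs=3)
print(len(recorder.batches))  # 3 epochs x 9 training partitions
```

With the question's class, the inner training loop could become `train_fold_epochs(self.model, parts, testPartition, ep_each)`. Note that each update uses a whole partition (about 100 rows), so 10 epochs give only 90 weight updates per fold; fit() with a smaller batch_size performs more updates per epoch.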
This produces output such as:
CV on test set at 0 is 0.6000000238418579
CV on test set at 1 is 0.5400000214576721
CV on test set at 2 is 0.4699999988079071
CV on test set at 3 is 0.5099999904632568
CV on test set at 4 is 0.5199999809265137
CV on test set at 5 is 0.5600000023841858
CV on test set at 6 is 0.5858585834503174
CV on test set at 7 is 0.5252525210380554
CV on test set at 8 is 0.4747474789619446
CV on test set at 9 is 0.5656565427780151
Using n=10 for now.
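k-fold results are usually summarized as the mean and spread of the per-fold scores. A small sketch using the ten accuracies copied from the output above (`fold_acc` is just an illustrative variable name). The mean comes out at about 0.535; if the two classes are roughly balanced, that is not far above chance, and a single training pass per fold is one plausible reason.

```python
import numpy as np

# Per-fold accuracies copied from the output above
fold_acc = np.array([
    0.6000000238418579, 0.5400000214576721, 0.4699999988079071,
    0.5099999904632568, 0.5199999809265137, 0.5600000023841858,
    0.5858585834503174, 0.5252525210380554, 0.4747474789619446,
    0.5656565427780151,
])
mean_acc = fold_acc.mean()
std_acc = fold_acc.std(ddof=1)  # sample standard deviation across folds
print(f"10-fold CV accuracy: {mean_acc:.3f} +/- {std_acc:.3f}")
```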
Solution
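An alternative that avoids re-joining partitions by hand is to keep the data in one array and select each fold's training rows with an index mask, letting fit() handle epochs and batch size. A minimal sketch under stated assumptions: `kfold_fit`, `build_model`, and the `MajorityBaseline` stand-in are hypothetical names; `build_model` is assumed to return a freshly compiled Keras model with metrics=['accuracy'], so that evaluate() returns [loss, accuracy]. Rows are shuffled once before splitting, in case the file is ordered by class.

```python
import numpy as np


def kfold_fit(build_model, x, y, k=10, epochs=50, batch_size=32, seed=0):
    """Hypothetical helper: k-fold CV with fit(), selecting training rows by index.

    build_model: zero-argument callable returning a new, compiled Keras-style model,
    so every fold starts from fresh weights. Returns the per-fold accuracies.
    """
    order = np.random.default_rng(seed).permutation(len(x))  # shuffle rows once
    folds = np.array_split(order, k)                           # row indices per fold
    scores = []
    for test_idx in folds:
        train_mask = np.ones(len(x), dtype=bool)
        train_mask[test_idx] = False  # every row outside the current test fold
        model = build_model()
        # fit() handles the epochs and mini-batching over the selected rows
        model.fit(x[train_mask], y[train_mask], epochs=epochs, batch_size=batch_size, verbose=0)
        # evaluate() returns [loss, accuracy] when compiled with metrics=['accuracy']
        scores.append(model.evaluate(x[test_idx], y[test_idx], verbose=0)[1])
    return scores


class MajorityBaseline:
    """Stand-in with Keras-like fit/evaluate that always predicts the majority training label."""

    def fit(self, x, y, **kwargs):
        self.label = float(np.mean(y) > 0.5)  # 1.0 if most training labels are 1, else 0.0

    def evaluate(self, x, y, **kwargs):
        return [0.0, float(np.mean(y == self.label))]  # [placeholder loss, accuracy]


x = np.zeros((100, 12), dtype="float32")
y = np.array([0.0] * 60 + [1.0] * 40, dtype="float32")
scores = kfold_fit(MajorityBaseline, x, y, k=10, epochs=1)
print(np.mean(scores))
```

With a real network, `build_model` would wrap the Sequential/compile code from the question. scikit-learn's `KFold` and `StratifiedKFold` (the latter preserves class ratios per fold) yield the same kind of train/test index arrays and could replace the manual split.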