python - keras中不同批量大小的训练模型
问题描述
我想针对不同的批量大小训练我的模型,即:[64, 128] 我正在使用如下的 for 循环来执行此操作
epoch=2
batch_sizes = [128,256]
for i in range(len(batch_sizes)):
history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs,
callbacks=[early_stopping, chk], validation_data=(x_test, y_test))
对于上面的代码,我的模型产生以下结果:
Epoch 1/2
311/311 [==============================] - 157s 494ms/step - loss: 0.2318 -
f1: 0.0723
Epoch 2/2
311/311 [==============================] - 152s 488ms/step - loss: 0.1402 -
f1: 0.4360
Epoch 1/2
156/156 [==============================] - 137s 877ms/step - loss: 0.1197 -
f1: **0.5450**
Epoch 2/2
156/156 [==============================] - 136s 871ms/step - loss: 0.1132 -
f1: 0.5756
看起来模型在完成批量大小 64 的训练后继续训练,即我想让我的模型从头开始训练下一批,我该怎么做,请指导我。ps:我尝试过的:
epoch=2
batch_sizes = [128,256]
for i in range(len(batch_sizes)):
history = model.fit(x_train, y_train, batch_sizes[i], epochs=epochs,
callbacks=[early_stopping, chk], validation_data=(x_test, y_test))
keras.backend.clear_session()
它也没有奏效
解决方案
您可以编写一个函数来定义一个模型,并且您需要在后续fit
调用之前调用它。如果您的模型包含在 中model
,则权重会在训练期间更新,并且在 fit 调用后保持不变。这就是为什么您需要重新定义模型。这可以帮助你
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
import numpy as np
X = np.random.rand(1000,5)
Y = np.random.rand(1000,1)
def build_model():
model = Sequential()
model.add(Dense(64,input_shape=(X.shape[1],)))
model.add(Dense(Y.shape[1]))
model.compile(loss='mse',optimizer='Adam')
return model
epoch=2
batch_sizes = [128,256]
for i in range(len(batch_sizes)):
model = build_model()
history = model.fit(X, Y, batch_sizes[i], epochs=epoch, verbose=2)
model.save('Model_' + str(batch_sizes[i]) + '.h5')
然后,输出如下所示:
Epoch 1/2
8/8 - 0s - loss: 0.3164
Epoch 2/2
8/8 - 0s - loss: 0.1367
Epoch 1/2
4/4 - 0s - loss: 0.7221
Epoch 2/2
4/4 - 0s - loss: 0.4787
推荐阅读
- angular - 在 Angular 6 中为托管 UI 配置 Amplify 时返回错误请求
- sql - 如何在 TimescaleDB 中的一张表上创建多个连续聚合?
- javascript - 正则表达式电话号码验证,带有空格和最小数字的字符
- javascript - ngtemplateoutlet 在两个不同的循环中
- regex - 如何使用正则表达式提取化学术语
- php - 使用代理和通过代理建立隧道有什么区别
- php - 如何不缓存动态路由
- python - 在 XML 和 Python 中处理属性条件
- c# - 将两个列表与一个列表作为用户输入进行比较总是触发标志
- java - 从java中的n个字符串用户输入中提取单词