tensorflow - Keras fit() 与 train_on_batch() 损失差异
问题描述
我计划用多个数组训练一个神经网络,我知道这是不可能的fit()
。因此,我尝试fit()
使用train_on_batch()
. 在训练了一个非常简单的 Keras 顺序模型之后,我观察到比较fit()
和我自己的train_on_batch()
循环时损失的巨大差异。这是epochs
和batch_size
等于1,我认为这是最简单的测试用例。
在下面的示例中,ab
是一个 ndarray,其中包含来自上一次运行的 float32 幅度数据rfft()
。labels
是一个包含整数的 ndarray,0 或 1 用于二进制分类。
ab = np.load(absName)
ab2 = np.transpose(ab)
labels = np.load(labelName)
l = ab.shape[1]
fftSize = ab.shape[0]
assert l == len(labels)
print("ab.shape =", ab.shape)
print("ab2.shape =", ab2.shape)
print("labels.shape =", labels.shape)
print("l =", l)
print("fftSize =", fftSize)
model = Sequential()
model.add(Dense(1024, input_shape=(fftSize,)))
model.add(Activation("relu"))
model.add(Dense(1))
model.add(Activation("sigmoid"))
model.compile(
optimizer="rmsprop",
loss="binary_crossentropy",
metrics=["accuracy"])
epochs = 1
batchSize = 1
useFit = True
if useFit:
model.fit(ab2, labels, batch_size=batchSize, epochs=epochs, verbose=0)
else:
for i in range(0, epochs):
j = 0
while (j + batchSize) < l:
model.train_on_batch(ab2[j:j+batchSize], labels[j:j+batchSize])
j += batchSize
if j < l:
model.train_on_batch(ab2[j:l], labels[j:l])
score = model.evaluate(ab2, labels, verbose=1)
for i in range(0, len(score)):
print("score[" + model.metrics_names[i] + "] = " + str(score[i]))
所有情况下的打印输出:
ab.shape = (513, 168)
ab2.shape = (168, 513)
labels.shape = (168,)
l = 168
fftSize = 513
如果useFit
为 True,则打印的分数为:
score[loss] = 0.36022053304172696
score[acc] = 0.8809523809523809
如果useFit
为 False,则打印的分数为:
score[loss] = 0.49978475148479146
score[acc] = 0.8809523809523809
这是损失的很大差异。如果我尝试 10 个 epoch 和 32 个批量大小,它们都会产生约 35 的损失。不过,我不确定这是一个可靠的实验,因为我知道在不同时期之间会发生洗牌。
我的印象是我的自定义训练循环完全fit()
可以做(忽略改组为epochs=1
),但我一定遗漏了一些东西。有任何想法吗?
解决方案
推荐阅读
- python - 如果 dropout rate > 0 使用 tf.cond() 尝试仅构造 DropoutWrapper 操作时出错
- haskell - 分段可变状态的单子
- javascript - 调整高度由 ng-if 指令动态生成的标签
- javascript - 分离 var app = new Vue({}); 到另一个 app.js
- java - 使用 BottomNavigationView 在它们之间切换时保留片段状态
- amazon-web-services - AWS 中托管的两个应用程序
- google-oauth - 与 Google Actions 帐户关联的自定义错误消息?
- objective-c - 如何使用 completionHandlers 等待多个函数?
- c# - C# WPF 将带有复选框的 json 数据绑定到列表框中
- angular - 导航到第一个孩子时触发父路由 OnInit