首页 > 解决方案 > 数据管道中的批量大小和 midel.fit() 中的批量大小有什么区别?

问题描述

这两个是相同的批量大小,还是具有不同的含义?

BATCH_SIZE=10
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.batch(BATCH_SIZE)

第二

history = model.fit(train_ds,
  epochs=EPOCHS,                    
  validation_data=create_dataset(X_valid, y_valid_bin),
  max_queue_size=1,
  workers=1,
  batch_size=10,
  use_multiprocessing=False)

我遇到了 Ram 无法运行的问题... 训练图像示例 333000 Ram 30GB 12 GB GPU 批量大小应该是多少?

完整代码在这里

标签: tensorflowmachine-learningneural-networkdata-sciencebatchsize

解决方案


数据集(批量大小)

批量大小仅意味着有多少数据将通过您定义的管道。在 Dateset 的情况下,批量大小表示在一次迭代中将有多少数据传递给模型。例如,如果您形成一个数据生成器并将批量大小设置为 8。现在在每个迭代数据生成器上提供 8 个数据记录。

Model.fit(批量大小)

在model.fit中,当我们设置批量大小时,这意味着模型将在传递等于批量大小的数据记录后计算损失。如果您了解深度学习模型,它们将计算前馈的特定损失,而不是通过反向传播,它们会自我改进。现在,如果您在 model.fit 中设置批量大小 8,则现在将 8 条数据记录传递给模型,并根据这 8 条数据记录计算损失,然后模型从该损失中得到改进。

例子:

现在,如果您将 dateset batch-size 设置为 4 并将 model.fit batch-size 设置为 8。现在您的 dateset 生成器必须迭代 2 次才能为模型提供 8 张图像,而 model.fit 仅执行 1 次迭代来计算损失。

公羊问题

你的图像尺寸是多少?尝试减少 batch_size,因为每个 epoch 的步长与 ram 无关,但批量大小是。因为如果您提供 10 个批量大小,则必须将 10 个图像加载到 ram 上进行处理,并且您的 ram 无法同时加载 10 个图像。尝试给批量大小 4 或 2。这可能对你有帮助


推荐阅读