tensorflow - 数据管道中的批量大小和 midel.fit() 中的批量大小有什么区别?
问题描述
这两个是相同的批量大小,还是具有不同的含义?
BATCH_SIZE=10
dataset = tf.data.Dataset.from_tensor_slices((filenames, labels))
dataset = dataset.batch(BATCH_SIZE)
第二
history = model.fit(train_ds,
epochs=EPOCHS,
validation_data=create_dataset(X_valid, y_valid_bin),
max_queue_size=1,
workers=1,
batch_size=10,
use_multiprocessing=False)
我遇到了 Ram 无法运行的问题... 训练图像示例 333000 Ram 30GB 12 GB GPU 批量大小应该是多少?
解决方案
数据集(批量大小)
批量大小仅意味着有多少数据将通过您定义的管道。在 Dateset 的情况下,批量大小表示在一次迭代中将有多少数据传递给模型。例如,如果您形成一个数据生成器并将批量大小设置为 8。现在在每个迭代数据生成器上提供 8 个数据记录。
Model.fit(批量大小)
在model.fit中,当我们设置批量大小时,这意味着模型将在传递等于批量大小的数据记录后计算损失。如果您了解深度学习模型,它们将计算前馈的特定损失,而不是通过反向传播,它们会自我改进。现在,如果您在 model.fit 中设置批量大小 8,则现在将 8 条数据记录传递给模型,并根据这 8 条数据记录计算损失,然后模型从该损失中得到改进。
例子:
现在,如果您将 dateset batch-size 设置为 4 并将 model.fit batch-size 设置为 8。现在您的 dateset 生成器必须迭代 2 次才能为模型提供 8 张图像,而 model.fit 仅执行 1 次迭代来计算损失。
公羊问题
你的图像尺寸是多少?尝试减少 batch_size,因为每个 epoch 的步长与 ram 无关,但批量大小是。因为如果您提供 10 个批量大小,则必须将 10 个图像加载到 ram 上进行处理,并且您的 ram 无法同时加载 10 个图像。尝试给批量大小 4 或 2。这可能对你有帮助
推荐阅读
- javascript - 使用页面上的 javascript 从 Google Sheet 捕获 ContentService 返回
- python - TypeError:不能将序列乘以“numpy.float64”类型的非整数(复数)
- php - Laravel 队列有时会重复一个作业
- gcc - RISC-V 工具链错误:未知伪操作:`.insn'
- java - 使用 @ResponseBody 将 setter 方法序列化为布尔值的 Spring 序列化
- azure - 使用 Azure 数据工厂基于列合并存储帐户中的两个或多个文件
- swift - 在 Swift 中将 CChar 数组转换为 Hexa 字符串
- macos - 将项目代码从 Ubuntu 主机共享到 macOS 来宾时,错误源文件不是有效的 UTF-8
- javascript - 在 Worker 线程中渲染 Canvas 上下文以实现流畅播放
- java - 为什么不获取映射表数据?