首页 > 解决方案 > 在tensorflow中,如何枚举训练数据(对比pytorch)

问题描述

在 pytorch 中,这就是我枚举训练数据的方式。

for epoch in range(0, args.epoches):
    for i, batch in enumerate(train_data):
        model.update(batch)

train_data包含多个batches 并且批次正在被枚举并更新模型,这对我来说非常清楚。


我认为这是 tensorflow 如何处理批次的基本示例。

for step in range(num_steps):
    batch_data, batch_labels = generate_batch(batch_size, num_skips, skip_window)
    feed_dict = {train_dataset : batch_data, train_labels : batch_labels}
    _, l = session.run([optimizer, loss], feed_dict=feed_dict)

也许这是一个非常明显的问题,但我不清楚session.run在 tensorflow 中如何处理枚举训练批次。我找不到批次在代码​​中循环。我所看到的是feed_dict,我假设它处理循环。

有人可以对此有所了解吗?

标签: tensorflowpytorch

解决方案


TensorFlow 有一个History用于此目的的对象。您将History对象作为model.fit()方法的返回。

History对象及其属性是连续时期的History.history训练损失值和度量值的记录,以及验证损失值和验证度量值(如果适用)。

希望这是你需要的。


推荐阅读