首页 > 解决方案 > 将 >2GB 数据传递给 tf.estimator

问题描述

我有x_trainnumpyy_train数组,每个数组 >2GB。我想使用 tf.estimator API 训练模型,但出现错误:

ValueError: Cannot create a tensor proto whose content is larger than 2GB

我正在使用以下方式传递数据:

def input_fn(features, labels=None, batch_size=None,
             shuffle=False, repeats=False):
    if labels is not None:
        inputs = (features, labels)
    else:
        inputs = features
    dataset = tf.data.Dataset.from_tensor_slices(inputs)
    if shuffle:
        dataset = dataset.shuffle(shuffle)
    if batch_size:
        dataset = dataset.batch(batch_size)
    if repeats:
        # if False, evaluate after each epoch
        dataset = dataset.repeat(repeats)
    return dataset

train_spec = tf.estimator.TrainSpec(
    lambda : input_fn(x_train, y_train,
                      batch_size=BATCH_SIZE, shuffle=50),
    max_steps=EPOCHS
)

eval_spec = tf.estimator.EvalSpec(lambda : input_fn(x_dev, y_dev))

tf.estimator.train_and_evaluate(model, train_spec, eval_spec)

tf.data 文档提到了这个错误,并提供了使用带有占位符的传统 TenforFlow API 的解决方案。不幸的是,我不知道如何将其转换为 tf.estimator API?

标签: pythontensorflowtensorflow-datasetstensorflow-estimator

解决方案


对我有用的解决方案是使用

tf.estimator.inputs.numpy_input_fn(x_train, y_train, num_epochs=EPOCHS,
                                   batch_size=BATCH_SIZE, shuffle=True)

而不是input_fn. 唯一的问题是会tf.estimator.inputs.numpy_input_fn引发弃用警告,所以不幸的是这也将停止工作。


推荐阅读