首页 > 解决方案 > 逐批馈送 tf.estimator.Estimator.predict

问题描述

我有一个训练有素的估计器模型,我需要获取一个不适合内存的非常大数据集的预测向量,处理这些预测向量并保存它们。到目前为止,我的代码看起来像这样:

def hist(predictions):
    ...
    return histograms

def input_fn(feat, batch_size=100):
    dataset = tf.data.Dataset.from_tensor_slices((feat))
    dataset = dataset.batch(batch_size)
    dataset = dataset.map(lambda x:...)
    return dataset

super_batch = 100
splits = data.shape[0]//super_batch

for s in range(splits):
    pred = list(classifier.predict(lambda: input_fn(data[s*super_batch:(s+1)*super_batch])))
    pred_cls = [p["classes"] for p in pred]
    hist_vec = hist(pred_cls)
    save hist_vec

我知道这不是正确的方法,因为它使 GPU 长时间处于空闲状态,并且由于每次调用分类器时都会加载模型。预测需要很长时间才能运行。有什么方法可以使用带有估计器的 feed 函数并加快这个过程?

标签: pythontensorflowtensorflow-datasetstensorflow-estimator

解决方案


我假设问题出在tf.data.Dataset.from_tensor_slices()


如果您在 Tensorflow 1.0 中使用numpy 数组并禁用了急切模式,它会将值作为一个或多个操作tf.data.Dataset.from_tensor_slices()嵌入到图中。tf.constant你的数据集越大,图表就越大。这是非常低效的,您可能会遇到ValueError: GraphDef cannot be larger than 2GB错误。

使用tf.estimator,您有 2 个解决方案:

  • 使用tf.data.Dataset.from_generator. 只需转换feat为 Python 生成器即可。from_tensor_slices由于tf.data图形的速度受到 Python 运行时的限制,因此性能要差一些。

  • tf.data.Dataset.from_tensor_slices()与 TensorFlow 占位符一起使用。这是更复杂但最有效的一种。请在此处查看我的答案以获取更多信息。它的要点是,您需要创建一个特定的钩子来在估计器内创建会话后初始化占位符。


推荐阅读