Tensorflow Estimator Graph Size Limitation for Large Input Dimensions

Problem Description

I think my entire training data is being stored inside the graph, which is hitting the 2GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API to train my model.

Input function:

def input_fn(X_train, epochs, batch_size):
    '''Input X_train is a scipy sparse matrix of large input dimension (200000) and number of rows = 600000.'''
    # convert_sparse_matrix_to_sparse_tensor and shuffle_to_batch are defined
    # elsewhere in the question's code
    X_train_tf = tf.data.Dataset.from_tensor_slices(
        convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32))
    X_train_tf = X_train_tf.apply(
        tf.data.experimental.shuffle_and_repeat(shuffle_to_batch * batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf
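
The helper convert_sparse_matrix_to_sparse_tensor is not shown in the question. Below is a minimal sketch of what it presumably does, reconstructed from its name and usage (hypothetical, not from the original post). Note that the indices and values handed to tf.SparseTensor are converted into tf.constant() ops, which is exactly how the training data ends up inside the graph:

import numpy as np
import tensorflow as tf

def convert_sparse_matrix_to_sparse_tensor(X, dtype):
    # Hypothetical reconstruction: COO format exposes (row, col, value) triplets.
    coo = X.tocoo()
    indices = np.stack([coo.row, coo.col], axis=1).astype(np.int64)
    values = coo.data.astype(dtype.as_numpy_dtype)
    # indices and values are embedded in the graph as constants here,
    # which is what pushes the serialized GraphDef toward the 2GB limit.
    return tf.SparseTensor(indices, values, coo.shape)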

Error:

Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in <module>
    AE_Regressor.train(lambda: input_fn(X_train, epochs, batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852 bytes) would be larger than the limit (2147483647 bytes)

Tags: python, tensorflow, tensorflow-estimator

Solution


I'm usually against quoting documentation verbatim, but this is explained word for word in the TF docs, and I can't find a better way to put it than they already have:

Note that [using Dataset.from_tensor_slices() on features and labels numpy arrays] will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory (because the contents of the array will be copied multiple times) and can run into the 2GB limit for the tf.GraphDef protocol buffer.
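
(To see this concretely, here is a toy demonstration of my own, not part of the quoted docs; the sizes are illustrative only:)

import numpy as np
import tensorflow as tf

data = np.zeros((1000, 1000), dtype=np.float32)  # about 4 MB of raw data
dataset = tf.data.Dataset.from_tensor_slices(data)
# The array now lives in the default graph as a Const node, so the
# serialized GraphDef is roughly as large as the data itself.
print(tf.get_default_graph().as_graph_def().ByteSize())  # on the order of 4 MB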

As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.

# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
  features = data["features"]
  labels = data["labels"]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})

(Code and text both taken from the link above, with one assert that is irrelevant to the question removed from the code.)
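
For completeness: the quoted snippet assumes a session already exists. Wired up end to end with a plain tf.Session, it would look roughly like the sketch below (array shapes and names are made up for illustration):

import numpy as np
import tensorflow as tf

features = np.random.rand(100, 8).astype(np.float32)  # stand-in training data
labels = np.random.rand(100, 1).astype(np.float32)

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
dataset = dataset.batch(10)
iterator = dataset.make_initializable_iterator()
next_batch = iterator.get_next()

with tf.Session() as sess:
    # The arrays are fed at initialization time, so they never become
    # constants inside the graph.
    sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                              labels_placeholder: labels})
    first_features, first_labels = sess.run(next_batch)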


Update

If you are trying to use this with the Estimator API, you're out of luck. From the same linked page, a couple of sections above the one quoted before:

Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator.

As you pointed out in the comments, this is because the Estimator API hides the sess.run() call where you would need to pass the feed_dict for your iterator.
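
A workaround that is often suggested for this situation (my addition, not part of the quoted answer) is Dataset.from_generator: it streams the data at run time instead of embedding it in the graph, and it works with the one-shot iterator semantics that Estimators expect. A minimal sketch, assuming X_train is the scipy sparse matrix from the question:

import numpy as np
import tensorflow as tf

def input_fn(X_train, epochs, batch_size):
    def generator():
        for row in X_train:  # iterate over the sparse matrix one row at a time
            yield row.todense().A1.astype(np.float32)  # densify a single row only
    dataset = tf.data.Dataset.from_generator(
        generator,
        output_types=tf.float32,
        output_shapes=tf.TensorShape([X_train.shape[1]]))
    return dataset.repeat(epochs).batch(batch_size).prefetch(2)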

