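python - Tensorflow Estimator Graph Size Limitation for Large Input Dimensions
Question
I think my entire training data is being stored in the graph, which then hits the 2 GB limit. How can I use feed_dict with the Estimator API? FYI, I am using the TensorFlow Estimator API to train my model.
Input function: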
def input_fn(X_train, epochs, batch_size):
    ''' input X_train is the scipy sparse matrix of large input dimensions(200000) and number of rows=600000'''
    X_train_tf = tf.data.Dataset.from_tensor_slices((convert_sparse_matrix_to_sparse_tensor(X_train, tf.float32)))
    X_train_tf = X_train_tf.apply(tf.data.experimental.shuffle_and_repeat(shuffle_to_batch*batch_size, epochs))
    X_train_tf = X_train_tf.batch(batch_size).prefetch(2)
    return X_train_tf
Error:
Traceback (most recent call last):
  File "/tmp/apprunner/.working/runtime/app/ae_python_tf.py", line 259, in <module>
    AE_Regressor.train(lambda: input_fn(X_train,epochs,batch_size), hooks=[time_hist, logging_hook])
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 354, in train
    loss = self._train_model(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1205, in _train_model
    return self._train_model_distributed(input_fn, hooks, saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1352, in _train_model_distributed
    saving_listeners)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/estimator/estimator.py", line 1468, in _train_with_estimator_spec
    log_step_count_steps=log_step_count_steps) as mon_sess:
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 504, in MonitoredTrainingSession
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 921, in __init__
    stop_grace_period_secs=stop_grace_period_secs)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/monitored_session.py", line 631, in __init__
    h.begin()
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/training/basic_session_run_hooks.py", line 543, in begin
    self._summary_writer = SummaryWriterCache.get(self._checkpoint_dir)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer_cache.py", line 63, in get
    logdir, graph=ops.get_default_graph())
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 367, in __init__
    super(FileWriter, self).__init__(event_writer, graph, graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 83, in __init__
    self.add_graph(graph=graph, graph_def=graph_def)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/summary/writer/writer.py", line 193, in add_graph
    true_graph_def = graph.as_graph_def(add_shapes=True)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3124, in as_graph_def
    result, _ = self._as_graph_def(from_version, add_shapes)
  File "/tmp/apprunner/.working/runtime/env/lib/python3.5/site-packages/tensorflow/python/framework/ops.py", line 3082, in _as_graph_def
    c_api.TF_GraphToGraphDef(self._c_graph, buf)
tensorflow.python.framework.errors_impl.InvalidArgumentError: Cannot serialize protocol buffer of type tensorflow.GraphDef as the serialized size (2838040852 bytes) would be larger than the limit (2147483647 bytes)
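To put the two figures in the error message in perspective, here is a small arithmetic sketch. It uses only the numbers reported above; the variable names are illustrative and not part of any TensorFlow API.

```python
# Figures copied from the InvalidArgumentError message above.
serialized_size = 2838040852  # bytes TensorFlow tried to serialize
limit = 2147483647            # bytes the protocol buffer allows

# The limit is the largest signed 32-bit integer, i.e. just under 2 GiB.
assert limit == 2**31 - 1

overage = serialized_size - limit   # how far past the limit the graph is
ratio = serialized_size / limit     # graph size relative to the limit

print("over the limit by", overage, "bytes")
print("graph is roughly {:.2f}x the allowed size".format(ratio))
```

So the graph is not marginally too large: trimming a few ops would not help, and the data itself has to be kept out of the graph.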
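Answer
I am usually against quoting documentation verbatim, but this is explained word for word in the TF documentation, and I cannot find a better way to put it than they already did:
Note that [using Dataset.from_tensor_slices() on features and labels numpy arrays] will embed the features and labels arrays in your TensorFlow graph as tf.constant() operations. This works well for a small dataset, but wastes memory (because the contents of the array will be copied multiple times) and can run into the 2GB limit for the tf.GraphDef protocol buffer. As an alternative, you can define the Dataset in terms of tf.placeholder() tensors, and feed the NumPy arrays when you initialize an Iterator over the dataset.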
# Load the training data into two NumPy arrays, for example using `np.load()`.
with np.load("/var/data/training_data.npy") as data:
    features = data["features"]
    labels = data["labels"]

features_placeholder = tf.placeholder(features.dtype, features.shape)
labels_placeholder = tf.placeholder(labels.dtype, labels.shape)

dataset = tf.data.Dataset.from_tensor_slices((features_placeholder, labels_placeholder))
# [Other transformations on `dataset`...]
dataset = ...
iterator = dataset.make_initializable_iterator()

sess.run(iterator.initializer, feed_dict={features_placeholder: features,
                                          labels_placeholder: labels})
(Both the code and the text are taken from the link above, with an assert removed because it is irrelevant to the question.)
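As a rough way to judge in advance whether an array is small enough to embed as a constant, the sketch below compares its raw byte size with the 2 GB limit. It rests on assumptions: the helper names are hypothetical, a float32 element is taken as 4 bytes, and the raw buffer size is treated only as a lower bound on the serialized size (protocol buffer overhead and extra copies are ignored). The dense shape is just the row and column counts from the question; the asker's data is actually sparse.

```python
GRAPHDEF_LIMIT_BYTES = 2**31 - 1  # limit reported in the error message


def raw_bytes(shape, itemsize):
    """Raw buffer size of a dense array with the given shape and element size."""
    count = 1
    for dim in shape:
        count *= dim
    return count * itemsize


def may_fit_as_constant(shape, itemsize):
    """False when the raw data alone already reaches the GraphDef limit.

    True is only a necessary condition, not a guarantee, because the
    serialized graph is at least as large as the raw data.
    """
    return raw_bytes(shape, itemsize) < GRAPHDEF_LIMIT_BYTES


# Dimensions from the question, stored densely as float32 (4 bytes each).
dense_size = raw_bytes((600000, 200000), 4)
print(dense_size, may_fit_as_constant((600000, 200000), 4))
```

A dense copy of the question's matrix would be far beyond the limit, which is consistent with the sparse representation still overflowing it.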
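Update
If you are trying to use this with the Estimator API, you are out of luck. From the same linked page, a few sections above the one quoted:
Note: Currently, one-shot iterators are the only type that is easily usable with an Estimator.
As you pointed out in the comments, this is because the Estimator API hides the sess.run() calls to which you would need to pass the feed_dict for the iterator.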