python - 使用 tf.data.Datasets 冻结 Tensorflow 图时确定输入节点
问题描述
我使用 Tensorflow tf.data.Dataset
API 作为输入管道,如下所示:
train_dataset = tf.data.Dataset.from_tensor_slices((trn_X,trn_y))
train_dataset =
train_dataset.map(_trn_parse_function,num_parallel_calls=12)
train_dataset =
train_dataset.shuffle(buffer_size=1000).repeat(args.num_epochs)#
.batch(args.batch_size)
train_dataset = train_dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
train_dataset = train_dataset.prefetch(buffer_size=600)
val_dataset = tf.data.Dataset.from_tensor_slices((val_X,val_y))
val_dataset = val_dataset.map(_val_parse_function,num_parallel_calls=4)
val_dataset = val_dataset.repeat(1)
val_dataset = val_dataset.apply(tf.contrib.data.batch_and_drop_remainder(args.batch_size))
val_dataset = val_dataset.prefetch(buffer_size=200)
handle = tf.placeholder(tf.string, shape=[])
iterator = tf.data.Iterator.from_string_handle(
handle, train_dataset.output_types,
train_dataset.output_shapes)
images,labels = iterator.get_next()
train_iter = train_dataset.make_initializable_iterator()
val_iter = val_dataset.make_initializable_iterator()
然后使用此代码在训练和验证数据集之间切换:
# Define training and validation handlers
training_handle = sess.run(train_iter.string_handle())
validation_handle = sess.run(val_iter.string_handle())
sess.run(train_iter.initializer)
sess.run(val_iter.initializer)
...
loss = sess.run([train_op],feed_dict={handle:training_handle,
is_training:True})
训练后,我保存权重,然后将保存的 checkpoint((.meta) 中的图形冻结为.pb
格式。随后,运行optimize_for_inference.py
tensorflow repo 中提供的工具。此脚本需要input_nodes_names
定义。我无法确定哪个是图的正确输入节点。这是我的图的节点:
['Variable/initial_value',
'Variable',
'Variable/Assign',
'Variable/read',
'increment_global_step/value',
'increment_global_step',
'Placeholder',
'is_training',
'tensors/component_0',
'tensors/component_1',
'num_parallel_calls',
'batch_size',
'count',
'buffer_size',
'OneShotIterator',
'IteratorToStringHandle',
'IteratorGetNext',
....
....
'output/Softmax]
可以很容易地确定输出节点,但不能确定输入节点。
解决方案
handle = tf.placeholder(tf.string, shape=[]) 是您的输入,因此张量很可能是“Placeholder:0”。
但是,这样写会更有意义:
handle = tf.placeholder(tf.string, shape=[], name="input_placeholder")
那你肯定知道。
推荐阅读
- elasticsearch - 是否可以在 Elastic Search 中使用简单的查询字符串查询来设置 TYPE 参数
- jupyter-notebook - jupyter notebook (macOS) 中的 %conda list biopython 错误
- java - 无法加载类“com.intellij.util.Consumer”
- javascript - npm 不支持 Node.js vX.XX(当前)
- php - Paypal 捐赠返回 url 给出了一个空数组
- android-gradle-plugin - Android lint AnnotationProcessorOnCompilePath 与 ButterKnife
- c# - 以编程方式使 UIView 高度包装其内容?
- swift - 对齐文本 Swiftui
- r - 如何从 r 中的箱线图中获取值(例如中位数)?
- sql - Proc oracle 问题 - 调用已执行但没有插入 (ORACLE/PLSQL)