python - 从使用 Tensorflow 数据集 API 训练的模型推断新输入
问题描述
我从 Dataset API 训练一个 tensorflow (1.7) 模型,如下所示:
features_data_ph = tf.placeholder(tf.int32, [None, None, max_sent_len], 'features_data_ph')
mode_ph = tf.placeholder(tf.int32, name='mode_ph')
labels_data_ph = tf.placeholder(tf.int32, [None, num_classes], 'labels_data_ph')
train_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
train_dataset = train_dataset.shuffle(buffer_size=100000).batch(batch_size)
train_iterator = train_dataset.make_initializable_iterator()
val_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
val_iterator = val_dataset.make_initializable_iterator()
input_tensor, labels_tensor = tf.case(
{
tf.equal(mode_ph, 0): train_iter.get_next,
tf.equal(mode_ph, 1): val_iter.get_next,
}
)
logits = model(input_tensor)
loss = get_loss(logits, labels_tensor)
...
# start of training epoch
session.run(train_iterator.initializer, feed_dict={
features_data_ph: train_features,
labels_data_ph: train_labels
})
...
# new validation after some steps
session.run(val_iterator.initializer, feed_dict={
features_data_ph: val_features,
labels_data_ph: val_labels
})
现在如您所见,input_tensor
取决于数据集。所以我不能只提供一个新的 numpy 数组来推断不在数据集中的数据。
到目前为止我所做的是创建第三个数据集,用于保存推理数据(并添加tf.equal(mode_ph, 2): infer_iter.get_next
到tf.case
)
有没有更好的方法来推断现有数据集中没有的数据?使用val_dataset
会覆盖它包含的数据
解决方案
推荐阅读
- ms-access - 如何将子表单中一个字段的行值保存到主表单记录?- 使用权
- reactjs - React CSS 模块 - 无效的优先顺序
- reactjs - localStorage 获取未定义状态
- sql - 在 Google Big Query 中执行简单分组
- c++ - 字符串/(字符串向量)匹配的快速算法
- c# - 无法将子条目分配给数据库表中的父条目
- javascript - 主函数一个接一个地调用其他函数
- c++ - 不能打开文件; 输出未更改错误
- python - 如何从 Python 中的字典中获取每个 URL?
- node.js - 一次只处理 500 行/行 createReadStream