首页 > 解决方案 > 从使用 Tensorflow 数据集 API 训练的模型推断新输入

问题描述

我从 Dataset API 训练一个 tensorflow (1.7) 模型,如下所示:

features_data_ph = tf.placeholder(tf.int32, [None, None, max_sent_len], 'features_data_ph')

mode_ph = tf.placeholder(tf.int32, name='mode_ph')

labels_data_ph = tf.placeholder(tf.int32, [None, num_classes], 'labels_data_ph')

train_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
train_dataset = train_dataset.shuffle(buffer_size=100000).batch(batch_size)
train_iterator = train_dataset.make_initializable_iterator()

val_dataset = tf.data.Dataset.from_tensor_slices((features_data_ph, labels_data_ph))
val_iterator = val_dataset.make_initializable_iterator()

input_tensor, labels_tensor = tf.case(
                {
                    tf.equal(mode_ph, 0): train_iter.get_next,
                    tf.equal(mode_ph, 1): val_iter.get_next,
                }
            )

logits = model(input_tensor)
loss = get_loss(logits, labels_tensor)
...
# start of training epoch
session.run(train_iterator.initializer, feed_dict={
    features_data_ph: train_features,
    labels_data_ph: train_labels
})
...
# new validation after some steps
session.run(val_iterator.initializer, feed_dict={
    features_data_ph: val_features,
    labels_data_ph: val_labels
})

现在如您所见,input_tensor取决于数据集。所以我不能只提供一个新的 numpy 数组来推断不在数据集中的数据。

到目前为止我所做的是创建第三个数据集,用于保存推理数据(并添加tf.equal(mode_ph, 2): infer_iter.get_nexttf.case

有没有更好的方法来推断现有数据集中没有的数据?使用val_dataset会覆盖它包含的数据

标签: pythontensorflow

解决方案


推荐阅读