Conditional statements in the TensorFlow Dataset API

Problem description

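I built a data pipeline with the TensorFlow Dataset API, but I want some operations (such as shuffling) to depend on whether I am iterating over the training dataset or the test dataset. Is there a way to use conditional statements inside a Dataset API pipeline? I tried the code below, but it fails with an error saying that an object of type ShuffleDataset cannot be converted to a Tensor.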

import tensorflow as tf

# This is the placeholder I feed with the proper file names depending on whether I'm training or testing
filenames_placeholder = tf.placeholder(tf.string, shape=(None), name='filenames_placeholder')

# This is the placeholder I would like to feed with True/False to influence shuffling
shuffle = tf.placeholder(tf.bool, shape=(None), name='shuffle')

dataset = tf.data.TFRecordDataset(filenames_placeholder)
dataset = dataset.map(lambda x: parse(x), num_parallel_calls=4)

# The following does not work: tf.cond expects both branches to return
# tensors, and a Dataset object cannot be converted to a tensor
def shuffle_true():
    return dataset.shuffle(buffer_size=1024)
def shuffle_false():
    return dataset
dataset = tf.cond(shuffle, shuffle_true, shuffle_false)

Tags: tensorflow

Solution


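You can define a function that builds the dataset and decides whether to shuffle with an ordinary Python `if`: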

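def tr_input_fn(filename, mode):
    dataset = tf.data.TFRecordDataset(filename)
    if mode == 'Train':
        # shuffle() requires a buffer size; 1024 matches the question's code
        dataset = dataset.shuffle(buffer_size=1024)
    # map_func is the record-parsing function, applied in both modes
    dataset = dataset.map(map_func)
    return dataset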
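To try the same idea without TFRecord files, here is a minimal sketch that assumes TensorFlow 2.x with eager execution. The names `make_dataset` and `training`, and the doubling `map` step, are illustrative stand-ins rather than part of the original answer. The branch is taken in plain Python while the pipeline is being built, so no tensor-valued condition is involved.

```python
import tensorflow as tf

def make_dataset(values, training, buffer_size=1024, seed=None):
    """Build a small in-memory pipeline; shuffle only when training (illustrative helper)."""
    dataset = tf.data.Dataset.from_tensor_slices(values)
    if training:
        # Plain Python branch, evaluated once at pipeline-construction time
        dataset = dataset.shuffle(buffer_size=buffer_size, seed=seed)
    # Stand-in for the real parsing function: applied in both modes
    dataset = dataset.map(lambda x: x * 2)
    return dataset

train_ds = make_dataset(list(range(10)), training=True, seed=0)
test_ds = make_dataset(list(range(10)), training=False)

# The test pipeline keeps the input order; the training pipeline only permutes it
test_values = [int(v) for v in test_ds.as_numpy_iterator()]
train_values = sorted(int(v) for v in train_ds.as_numpy_iterator())
```

The trade-off is that the mode is fixed when the dataset is constructed, so training and testing each need their own call to the builder.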

As far as I know, the Dataset API currently has no explicit conditional statement. See https://www.tensorflow.org/api_docs/python/tf/data/Dataset
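If the switch really has to be a tensor value (for example a boolean fed at run time), one possible workaround is to make the shuffle buffer size conditional instead of the dataset itself: a buffer of size 1 leaves the element order unchanged. The sketch below assumes TensorFlow 2.x with eager execution; `maybe_shuffle` is a hypothetical helper name, and this technique is not part of the original answer.

```python
import tensorflow as tf

def maybe_shuffle(dataset, shuffle, buffer_size=1024, seed=None):
    """Shuffle `dataset` only when `shuffle` is true (hypothetical helper).

    `shuffle` may be a Python bool or a scalar tf.bool tensor.
    """
    # Choose the buffer size as a tensor; a buffer of 1 preserves the order
    size = tf.where(shuffle,
                    tf.constant(buffer_size, dtype=tf.int64),
                    tf.constant(1, dtype=tf.int64))
    return dataset.shuffle(buffer_size=size, seed=seed)

base = tf.data.Dataset.range(10)
ordered = [int(v) for v in maybe_shuffle(base, tf.constant(False)).as_numpy_iterator()]
shuffled = [int(v) for v in maybe_shuffle(base, tf.constant(True), seed=0).as_numpy_iterator()]
```

Here `ordered` keeps the original sequence, while `shuffled` contains the same elements in a permuted order. The shuffle op is still present in the pipeline in both cases, so this is a convenience rather than an optimization.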

