tensorflow - TensorFlow Dataset API 中的条件语句
问题描述
我已经使用 Tensorflow 数据集 API 构建了一个数据管道,但我希望一些操作(如洗牌)取决于我是在迭代训练数据集还是测试数据集。我想知道是否有办法在数据集 API 管道中使用条件语句?我尝试了以下代码,但它说它无法将类型对象转换ShuffleDataset
为张量。
# This is the placeholder I feed with proper file name depending on whether I'm training or testing
filenames_placeholder = tf.placeholder(tf.string, shape = (None), name = 'filenames_placeholder')
# This it the placeholder I would like to feed with True/False to influence shuffling
shuffle = tf.placeholder(tf.bool, shape = (None), name = 'shuffle')
dataset = tf.data.TFRecordDataset(self.filenames_placeholder)
dataset = dataset.map(lambda x: parse(x), num_parallel_calls = 4)
# The following does not work
def shuffle_true():
return dataset.shuffle(buffer_size = 1024)
def shuffle_false():
return dataset
dataset = tf.cond(self.shuffle, shuffle_true, shuffle_false)
解决方案
你可以定义一个函数
def tr_input_fn(filename, mode):
dataset = tf.data.TFRecordDataset(filename)
if mode == 'Train':
dataset = dataset.shuffle()
dataset = dataset.map(map_func)
return dataset
return dataset
据我所知,数据集 api 中现在有明确的条件语句。 https://www.tensorflow.org/api_docs/python/tf/data/Dataset
推荐阅读
- php - Cannot copy text correctly due to repeating Ajax call
- ansible - Ansible 连接失败:未知类型
- finance - 无法使用带有 zipline 数据框的 pandas excel 编写器?
- mallet - Mallet 超参数优化
- python - 如何从 Python 中的双/三元组的输出中删除列表特殊字符(“()”、“'”、“”)
- .net - 在进行 CI 时,我可以检测到 Nuget 包的来源是否发生了变化?
- jquery - 提交后用jQuery编辑的值为空
- mysql - MySql insert into select query 复制1亿行太慢
- tmux - ^R 在 iTerm 中使用 tmux 和 zshell 杀戮窗格,不显示历史记录
- java - javafx i18n:“未指定资源”异常