首页 > 解决方案 > tf.data.Dataset.from_generator 调用的复杂对象的 output_signature 、 output_types 和 output_shapes 示例

问题描述

我有一个生成器函数,它产生以下元组:yield (transformed_input_array, set_y)

transform_input_array是具有以下形状的 ndarray 列表:(1024, 104), (1024, 142), (1024, 1), (1024, 1), (1024, 1), (1024, 1), (1024, 140) 和以下类型:tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64 set_y是一个形状为1024和类型为int64的 ndarray

我已经用 tf.data.Dataset.from_generator 函数包装了我的生成器,这里是代码:

dataset = tf.data.Dataset.from_generator(
    generator,
    # output_signature=(
    #     tf.TensorSpec(shape=(), dtype=(tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64)),
    #     tf.TensorSpec(shape=1024, dtype=tf.int64))
    output_types=(tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64, tf.int64),
    output_shapes=((1024, 104), (1024, 142), (1024, 1), (1024, 1), (1024, 1), (1024, 1), (1024, 140), 1024)
)

但是当我进行培训时,我收到以下错误:

ValueError: Data is expected to be in format x, (x,), (x, y), or (x, y, sample_weight), found: (<tf.Tensor 'IteratorGetNext:0' shape=(1024, 104) dtype=float64>, <tf.Tensor 'IteratorGetNext:1' shape=( 1024, 142) dtype=float64>, <tf.Tensor 'IteratorGetNext:2' shape=(1024, 1) dtype=int8>, <tf.Tensor 'It eratorGetNext:3' shape=(1024, 1) dtype=int16 >, <tf.Tensor 'IteratorGetNext:4' shape=(1024, 1) dtype=int8>, <tf.Tensor 'IteratorGetNext:5' shape=(1024, 1) dtype=int8>, <tf.Tensor 'IteratorGetNext :6' shape=(1024, 140) dtype=float64>, <tf.Tensor 'ExpandDims:0' shape=(1024, 1) dtype=int64>)

如果我尝试使用 output_signature 参数(注释掉的代码)运行,我会收到以下错误:

TypeError:无法将值(tf.float64、tf.float64、tf.int8、tf.int16、tf.int8、tf.int8、tf.float64)转换为 TensorFlow DType。

有人可以提供一个示例,说明我应该如何处理复杂类型(ndarrays 列表)?在 TF 文档中找不到任何示例..

标签: tensorflow2.0tensorflow-datasets

解决方案


推荐阅读