tensorflow2.0 - tf.data.Dataset.from_generator 调用的复杂对象的 output_signature 、 output_types 和 output_shapes 示例
问题描述
我有一个生成器函数,它产生以下元组:yield (transformed_input_array, set_y)
transform_input_array是具有以下形状的 ndarray 列表:(1024, 104), (1024, 142), (1024, 1), (1024, 1), (1024, 1), (1024, 1), (1024, 140) 和以下类型:tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64 set_y是一个形状为1024和类型为int64的 ndarray
我已经用 tf.data.Dataset.from_generator 函数包装了我的生成器,这里是代码:
dataset = tf.data.Dataset.from_generator(
generator,
# output_signature=(
# tf.TensorSpec(shape=(), dtype=(tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64)),
# tf.TensorSpec(shape=1024, dtype=tf.int64))
output_types=(tf.float64, tf.float64, tf.int8, tf.int16, tf.int8, tf.int8, tf.float64, tf.int64),
output_shapes=((1024, 104), (1024, 142), (1024, 1), (1024, 1), (1024, 1), (1024, 1), (1024, 140), 1024)
)
但是当我进行培训时,我收到以下错误:
ValueError: Data is expected to be in format
x
,(x,)
,(x, y)
, or(x, y, sample_weight)
, found: (<tf.Tensor 'IteratorGetNext:0' shape=(1024, 104) dtype=float64>, <tf.Tensor 'IteratorGetNext:1' shape=( 1024, 142) dtype=float64>, <tf.Tensor 'IteratorGetNext:2' shape=(1024, 1) dtype=int8>, <tf.Tensor 'It eratorGetNext:3' shape=(1024, 1) dtype=int16 >, <tf.Tensor 'IteratorGetNext:4' shape=(1024, 1) dtype=int8>, <tf.Tensor 'IteratorGetNext:5' shape=(1024, 1) dtype=int8>, <tf.Tensor 'IteratorGetNext :6' shape=(1024, 140) dtype=float64>, <tf.Tensor 'ExpandDims:0' shape=(1024, 1) dtype=int64>)
如果我尝试使用 output_signature 参数(注释掉的代码)运行,我会收到以下错误:
TypeError:无法将值(tf.float64、tf.float64、tf.int8、tf.int16、tf.int8、tf.int8、tf.float64)转换为 TensorFlow DType。
有人可以提供一个示例,说明我应该如何处理复杂类型(ndarrays 列表)?在 TF 文档中找不到任何示例..
解决方案
推荐阅读
- amazon-web-services - AWS 中 terraform 资源的强制标记
- vue.js - 在 VueJS 组件中将样式应用于 Flatpickr
- macos - Mac 上的 Perl,cpan,不会安装
- tensorflow - 借助 GPU 支持对高维数据进行更快的 Kmeans 聚类
- python-3.x - 替换 HTML 标签不会改变数据框
- python - 按某些列比较两个文本文件然后返回整行?
- css - 具有按行分组的数据的 CSS 网格
- angular - 如何在 Angular 中处理区分大小写。将数据从 FormControl 绑定到模型
- c++ - 正确地将数组传递给函数 c++
- c++ - 如何从 com 端口获取正确的值