首页 > 解决方案 > 在我使用 dataset,concatenate 之后,dataset.map 只作用于原始数据

问题描述

正如标题所说,我连接两个数据集并使用 map 函数来更改值的位置和重新调整值。在我使用map之前,所有Tensor的形状都是匹配的,但是在使用map函数之后,使用for循环迭代数据集打印索引,迭代的断点在两个数据集的关节上。

我在 Colab 中使用 GPU 遇到了这个问题,并使用 Python 3.6、tensorflow-gpu 2.0.0b1

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, (tf.float32, tf.float32),((7,), (48,48,1)))
dataset = dataset.concatenate(dataset_crop)
dataset = dataset.map(lambda label, img_raw: (tf.cast(img_raw, tf.float32)/float(255), label))
for i,(label, img) in enumerate(dataset):
  print(i)

顺便说一句,在连接之前,数据集的总行有 19984
什么连接地狱..

...
19982
19983
---------------------------------------------------------------------------
InvalidArgumentError                      Traceback (most recent call last)
<ipython-input-26-36305ee0e8ef> in <module>()
----> 1 for i,(label, img) in enumerate(dataset):
      2   print(i)

4 frames
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)

InvalidArgumentError: ValueError: Tensor's shape (7,) is not compatible with supplied shape [48, 48, 1]
Traceback (most recent call last):

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/script_ops.py", line 209, in __call__
    ret = func(*args)

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/dataset_ops.py", line 525, in generator_py_func
    values = next(generator_state.get_iterator(iterator_id))

  File "<ipython-input-25-196a9ac04fc0>", line 5, in img_resize_and_crop_genr
    img.set_shape([side_len, side_len,1])

  File "/usr/local/lib/python3.6/dist-packages/tensorflow/python/framework/ops.py", line 981, in set_shape
    (self.shape, shape))

ValueError: Tensor's shape (7,) is not compatible with supplied shape [48, 48, 1]


     [[{{node PyFunc}}]] [Op:IteratorGetNextSync]

标签: pythonpython-3.xtensorflowtensorflow-datasets

解决方案


问题出在你的from_generator功能上。当您传递output_shapes参数时,会进行严格检查以查看输出形状是否与生成的形状完全相同。在你的情况下,你得到一个ValueError声明,它期待 (48, 48, 1) 但 (7,) 形状已生成。

使用以下代码可以生成类似的错误:

dataset = tf.data.Dataset.from_tensor_slices((np.zeros(19984, dtype=np.float32), np.ones(19984, dtype=np.float32)))

def img_resize_and_crop_genr():
    yield np.zeros((7,)), np.ones((48, 48, 1))

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, (tf.float32, tf.float32),((48,48,1), (7,)))
dataset = dataset.concatenate(dataset_crop)
dataset = dataset.map(lambda label, img_raw: (tf.cast(img_raw, tf.float32)/float(255), label))
for i,(label, img) in enumerate(dataset):
  print(i)

输出:

ValueError: `generator` yielded an element of shape (7,) where an element of shape (48, 48, 1) was expected.

我相信你已经交换了你的output_shapes. 如果是这种情况,您可以将更正为:

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, 
                                              (tf.float32, tf.float32),((7,), (48,48,1)))

此外,output_shapes是一个可选参数。您可以通过不传递参数来整体避免问题,如下所示:

dataset_crop = tf.data.Dataset.from_generator(img_resize_and_crop_genr, 
                                              (tf.float32, tf.float32))

推荐阅读