首页 > 解决方案 > Numpy 4d 数组到 tf.data.dataset

问题描述

我正在关注本教程https://www.tensorflow.org/tutorials/generation/pix2pix但我正在尝试制作自己的输入管道。我有一个 4d numpy 数组(Num samples、Height、Width、Channels),我用它ds = tf.data.Dataset.from_tensor_slices()来创建我的数据集。但是,当我调用ds.take(1)它时,它没有批量大小的维度。我可以通过在必要时插入来解决此问题,tf.expand_dims()但我觉得应该有一种方法可以在数据集中执行此操作。

标签: tensorflowtensorflow-datasets

解决方案


你可以试试:

for image in ds.batch(1).take(1):
    assert image.shape[0] == 1
    # do something with the image

推荐阅读