tensorflow - 应用地图后如何保持 tf.dataset 的形状?
问题描述
我正在创建一个包含 2 个图像作为输入和一个掩码作为目标的 tf.dataset 对象。它们都是 3D 的。应用自定义地图后,对象的形状从 <RepeatDataset shapes: (((), ()), ()), types: ((tf.string, tf.string), tf.string)>
变为<PrefetchDataset shapes: (<unknown>, <unknown>, <unknown>), types: (tf.float32, tf.float32, tf.int32)>
,当我拟合数据时,我的模型会抛出错误,因为它只检测到一个输入而不是 2。这是我正在做的事情:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)
我无法分别映射图像和蒙版,因为它们需要相同的种子来进行随机裁剪和翻转。是否可以在地图之后更改形状,以便我可以将其提供给我的 2 输入模型?
编辑 :
我的 random_crop_flip 函数如下:
def random_crop_flip(images, mask, width=128, height=128, depth=128):
img_bl, img_fu = images
img_bl = img_bl.numpy()
img_fu = img_fu.numpy()
mask = mask.numpy()
x_rand = random.randint(0, img_bl.shape[2] - width)
y_rand = random.randint(0, img_bl.shape[1] - height)
z_rand = random.randint(0, img_bl.shape[3] - depth)
img_bl_f = img_bl[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
img_fu_f = img_fu[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
mask_f = mask[:, y_rand:y_rand + height, x_rand:x_rand + width, z_rand:z_rand + depth, :]
flip_x = random.choice([True, False])
flip_y = random.choice([True, False])
flip_z = random.choice([True, False])
if flip_x:
img_bl_f = np.flip(img_bl_f, axis=2)
img_fu_f = np.flip(img_fu_f, axis=2)
mask_f = np.flip(mask_f, axis=2)
if flip_y:
img_bl_f = np.flip(img_bl_f, axis=1)
img_fu_f = np.flip(img_fu_f, axis=1)
mask_f = np.flip(mask_f, axis=1)
if flip_z:
img_bl_f = np.flip(img_bl_f, axis=3)
img_fu_f = np.flip(img_fu_f, axis=3)
mask_f = np.flip(mask_f, axis=3)
images = zip(img_bl_f, img_fu_f)
return images, mask_f
拉链没有解决我的问题。是否可以修改返回以获得我想要的输出?
解决方案
我已经设法通过“展平”(消除括号)random_crop_flip 的返回,并在它们之上应用另一个地图来解决这个问题,我在其中指定了形状并返回了我想要的结构(x,y),z:
def _set_shapes(img_bl, img_fu, mask):
img_bl.set_shape([128, 128, 128, 1])
img_fu.set_shape([128, 128, 128, 1])
mask.set_shape([128, 128, 128, 1])
return (img_bl, img_fu), mask
然后我的代码如下所示:
x, y = get_filenames(train_data_path, img_type='FLAIR')
z = get_filenames(train_data_path, img_type='mask')
path_dataset = tf.data.Dataset.from_tensor_slices((x, y))
mask_dataset = tf.data.Dataset.from_tensor_slices(z)
dataset = tf.data.Dataset.zip((path_dataset, mask_dataset)).shuffle(50).repeat(10)
ds = dataset. \
map(lambda xx, zz: ((tf.py_function(load, [xx], [tf.float32, tf.float32])),
tf.py_function(load_mask, [zz], [tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(lambda xx, zz: (tf.py_function(random_crop_flip, [xx, zz],
[tf.float32, tf.float32, tf.int32])),
num_parallel_calls=tf.data.AUTOTUNE)
ds = ds.map(_set_shapes)
ds = ds.batch(2)
ds = ds.prefetch(tf.data.AUTOTUNE)