首页 > 解决方案 > Keras 实验性 RandomFlip 和 RandomRotation 不适用于地图

问题描述

这段代码会产生一个我不明白的错误。有人可以解释一下吗?

import tensorflow as tf

def augment(img):
    data_augmentation = tf.keras.Sequential([
              tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
              tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
             ])
    img = tf.expand_dims(img, 0)
    return data_augmentation(img)

# generate 10 images 8x8 RGB
data = np.random.randint(0,255,size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data)

# and augment... -> bug
dataset = dataset.map(augment)

# note that the follwing works
for im in dataset:
   augment(im)

和一个得到

ValueError: Tensor-typed variable initializers must either be wrapped in an init_scope or callable (e.g., `tf.Variable(lambda : tf.truncated_normal([10, 40]))`) when building functions. Please file a feature request if this restriction inconveniences you.

我在 Google Colab 上尝试过,并在我的计算机上安装了 Tensorflow 2.4.1。请注意,通过调整大小或重新缩放它可以工作(就像在这个例子中一样https://www.tensorflow.org/tutorials/images/data_augmentation但他们没有尝试使用 RandomRotate 即使他们在循环中使用它)。

标签: tensorflowkerasdata-augmentation

解决方案


这是答案...

import numpy as np
import tensorflow as tf

data_augmentation = tf.keras.Sequential([
              tf.keras.layers.experimental.preprocessing.RandomFlip('horizontal'),
              tf.keras.layers.experimental.preprocessing.RandomRotation(0.2),
             ])

# generate 10 images 8x8 RGB
data = np.random.randint(0,255,size=(10, 8, 8, 3))
dataset = tf.data.Dataset.from_tensor_slices(data).batch(5)

# and augment... -> bug
dataset = dataset.map(lambda x: data_augmentation(x))

奇怪,如果我们使用 lambda 函数,它就可以工作,如果我们定义一个只调用data_augmentation它的函数就会失败......


推荐阅读