python - Tensorflow 2 throwing ValueError: as_list() is not defined on an unknown TensorShape
Problem description
I'm trying to train a U-Net model in TensorFlow 2.0 that takes images and segmentation masks as input, but I get a ValueError: as_list() is not defined on an unknown TensorShape. The stack trace shows that the problem occurs in _get_input_from_iterator(inputs):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in _prepare_feed_values(model, inputs, mode)
110 for inputs will always be wrapped in lists.
111 """
--> 112 inputs, targets, sample_weights = _get_input_from_iterator(inputs)
113
114 # When the inputs are dict, then we want to flatten it in the same order as
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in _get_input_from_iterator(iterator)
147 # Validate that all the elements in x and y are of the same type and shape.
148 dist_utils.validate_distributed_dataset_inputs(
--> 149 distribution_strategy_context.get_strategy(), x, y, sample_weights)
150 return x, y, sample_weights
151
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_distributed_dataset_inputs(distribution_strategy, x, y, sample_weights)
309
310 if y is not None:
--> 311 y_values_list = validate_per_replica_inputs(distribution_strategy, y)
312 else:
313 y_values_list = None
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_per_replica_inputs(distribution_strategy, x)
354 if not context.executing_eagerly():
355 # Validate that the shape and dtype of all the elements in x are the same.
--> 356 validate_all_tensor_shapes(x, x_values)
357 validate_all_tensor_types(x, x_values)
358
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/distribute/distributed_training_utils.py in validate_all_tensor_shapes(x, x_values)
371 def validate_all_tensor_shapes(x, x_values):
372 # Validate that the shape of all the elements in x have the same shape
--> 373 x_shape = x_values[0].shape.as_list()
374 for i in range(1, len(x_values)):
375 if x_shape != x_values[i].shape.as_list():
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/tensor_shape.py in as_list(self)
1169 """
1170 if self._dims is None:
-> 1171 raise ValueError("as_list() is not defined on an unknown TensorShape.")
1172 return [dim.value for dim in self._dims]
1173
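I've looked at several other Stack Overflow posts with this error (here and here), but in my case I think the problem lies in the map function I pass to my Datasets. I use the process_path function defined below as the map function of a TensorFlow Dataset. It takes the path to an image and constructs the path to the corresponding segmentation mask, which is a NumPy file. The (256, 256) array in the NumPy file is then converted to a (256, 256, 10) array with kerasUtils.to_categorical, where the 10 channels represent the classes (one channel per class). I used the check_shape function to confirm that the tensor shapes are correct, but when I call model.fit the shape still cannot be inferred.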
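The mask conversion described above can be reproduced with plain NumPy, which is handy for checking a mask file outside of tf.data. This is a sketch assuming the .npy files hold integer class ids in [0, 10); one_hot_mask is a hypothetical helper, not part of Keras, and it is only meant to match what to_categorical returns for a 2-D label map. Note that it confirms the shape of the NumPy array only; it says nothing about the static shape TensorFlow assigns to the output of tf.py_function.

```python
import numpy as np

NUM_CLASSES = 10


def one_hot_mask(labels, num_classes=NUM_CLASSES):
    """Turn an (H, W) array of integer class ids into an (H, W, num_classes) array.

    Hypothetical helper, intended to match kerasUtils.to_categorical(labels, num_classes)
    for a 2-D label map: float32 one-hot values along a new last axis.
    """
    labels = np.asarray(labels, dtype=np.int64)
    if labels.min() < 0 or labels.max() >= num_classes:
        raise ValueError("class ids must lie in [0, num_classes)")
    # Indexing the identity matrix with the label map selects one one-hot row
    # per pixel, which appends a trailing class axis.
    return np.eye(num_classes, dtype=np.float32)[labels]


rng = np.random.default_rng(0)
labels = rng.integers(0, NUM_CLASSES, size=(256, 256))
mask = one_hot_mask(labels)
print(mask.shape)  # (256, 256, 10)
```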
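import numpy as np
import tensorflow as tf
from tensorflow.keras import utils as kerasUtils  # assumed: the question does not show where kerasUtils comes from

# --------------------------------------------------------------------------------------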
# DECODE A NUMPY .NPY FILE INTO THE REQUIRED FORMAT FOR TRAINING
# --------------------------------------------------------------------------------------
def decode_npy(npy):
    filename = npy.numpy()
    data = np.load(filename)
    data = kerasUtils.to_categorical(data, 10)
    return data

def check_shape(image, mask):
    print('shape of image: ', image.get_shape())
    print('shape of mask: ', mask.get_shape())
    return 0.0
# --------------------------------------------------------------------------------------
# DECODE AN IMAGE (PNG) FILE INTO THE REQUIRED FORMAT FOR TRAINING
# --------------------------------------------------------------------------------------
def decode_img(img):
    # convert the compressed string to a 3D uint8 tensor
    img = tf.image.decode_png(img, channels=3)
    # Use `convert_image_dtype` to convert to floats in the [0,1] range.
    return tf.image.convert_image_dtype(img, tf.float32)
# --------------------------------------------------------------------------------------
# PROCESS A FILE PATH FOR THE DATASET
# input - path to an image file
# output - an input image and output mask
# --------------------------------------------------------------------------------------
def process_path(filePath):
    parts = tf.strings.split(filePath, '/')
    fileName = parts[-1]
    parts = tf.strings.split(fileName, '.')
    prefix = tf.convert_to_tensor(convertedMaskDir, dtype=tf.string)
    suffix = tf.convert_to_tensor("-mask.npy", dtype=tf.string)
    maskFileName = tf.strings.join((parts[-2], suffix))
    maskPath = tf.strings.join((prefix, maskFileName), separator='/')
    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.py_function(decode_npy, [maskPath], tf.float32)
    return img, mask
# --------------------------------------------------------------------------------------
# CREATE A TRAINING and VALIDATION DATASETS
# --------------------------------------------------------------------------------------
trainSize = int(0.7 * DATASET_SIZE)
validSize = int(0.3 * DATASET_SIZE)
allDataSet = tf.data.Dataset.list_files(str(imageDir + "/*"))
# allDataSet = allDataSet.map(process_path, num_parallel_calls=AUTOTUNE)
# allDataSet = allDataSet.map(process_path)
trainDataSet = allDataSet.take(trainSize)
trainDataSet = trainDataSet.map(process_path).batch(64)
validDataSet = allDataSet.skip(trainSize)
validDataSet = validDataSet.map(process_path).batch(64)
...
# this code throws the error!
model_history = model.fit(trainDataSet, epochs=EPOCHS,
                          steps_per_epoch=stepsPerEpoch,
                          validation_steps=validationSteps,
                          validation_data=validDataSet,
                          callbacks=callbacks)
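For sanity-checking the file naming, the string logic in process_path can be mirrored in plain Python. This is a hypothetical helper (mask_path_for, with mask_dir standing in for convertedMaskDir), not code from the question. One thing it makes visible: parts[-2] is only the segment immediately before the last dot, so a filename such as scan.v2.png would map to v2-mask.npy rather than scan.v2-mask.npy.

```python
def mask_path_for(image_path, mask_dir):
    """Plain-Python mirror of the tf.strings path logic in process_path (hypothetical helper)."""
    file_name = image_path.split("/")[-1]  # same as parts[-1] after splitting on '/'
    stem = file_name.split(".")[-2]        # same as parts[-2]: the segment before the last dot
    return "/".join((mask_dir, stem + "-mask.npy"))


print(mask_path_for("data/images/img_001.png", "data/masks"))
# data/masks/img_001-mask.npy
```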
Solution
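I ran into the same problem as you with images and masks, and solved it by setting their shapes manually in the preprocessing function, specifically where tf.py_function is called inside Dataset.map. tf.py_function cannot infer the shape of whatever the wrapped Python function returns, so its output has an unknown shape, and Keras fails when it calls as_list() on that shape while validating the targets.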
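def process_path(filePath):
    ...
    # load the raw data from the file as a string
    img = tf.io.read_file(filePath)
    img = decode_img(img)
    mask = tf.py_function(decode_npy, [maskPath], tf.float32)
    # tf.py_function loses all shape information, so declare the static shapes here.
    # Replace these with your own dimensions: the mask shape matches the
    # (256, 256, 10) masks from the question; the image shape assumes
    # 256x256 RGB input.
    img.set_shape([256, 256, 3])
    mask.set_shape([256, 256, 10])
    return img, mask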
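A small self-contained reproduction shows why this works. The sketch below replaces decode_npy with a hypothetical fake_decode_npy that returns a tiny (4, 4, 10) one-hot mask instead of reading a file, and compares the dataset's element_spec with and without set_shape. Without it, the batched mask spec has unknown rank, which is exactly the shape that as_list() rejects in the traceback.

```python
import numpy as np
import tensorflow as tf

NUM_CLASSES = 10
HEIGHT, WIDTH = 4, 4  # tiny stand-in size; the question uses 256x256


def fake_decode_npy(path):
    # Hypothetical stand-in for decode_npy: ignores the path and returns
    # a one-hot mask instead of loading a real .npy file.
    labels = np.zeros((HEIGHT, WIDTH), dtype=np.int64)
    return np.eye(NUM_CLASSES, dtype=np.float32)[labels]


def load_without_shape(path):
    # TensorFlow cannot see what the Python function returns,
    # so the static shape of this output is completely unknown.
    return tf.py_function(fake_decode_npy, [path], tf.float32)


def load_with_shape(path):
    mask = tf.py_function(fake_decode_npy, [path], tf.float32)
    # Declare the static shape that Keras needs when validating targets.
    mask.set_shape([HEIGHT, WIDTH, NUM_CLASSES])
    return mask


paths = tf.data.Dataset.from_tensor_slices(["a-mask.npy", "b-mask.npy"])
broken = paths.map(load_without_shape).batch(2)
fixed = paths.map(load_with_shape).batch(2)

print(broken.element_spec.shape)  # unknown rank: .as_list() on it raises ValueError
print(fixed.element_spec.shape)   # known except for the batch dimension
```

An alternative to set_shape is tf.ensure_shape(mask, [HEIGHT, WIDTH, NUM_CLASSES]), which sets the same static shape and additionally checks at runtime that each loaded mask really has it, so a malformed .npy file fails loudly instead of being silently mislabeled.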