首页 > 解决方案 > Tensorflow 2:序列化和解码时的形状不匹配

问题描述

我有一个形状为 (300,256,256) 的张量 A。我想序列化 A 以保存为 tfrecord 格式。但我无法将其转换回具有相同形状的张量。

A = tf.convert_to_tensor( *a numpy array with float32 type* )
B = tf.io.serialize_tensor(A)
C = tf.reshape(tf.io.decode_raw(B, out_type=tf.float32),[300,256,256])

如果我运行上面的代码,我会得到一个形状错误:

tensorflow.python.framework.errors_impl.InvalidArgumentError:reshape 的输入是一个具有 19660806 个值的张量,但请求的形状有 19660800 [Op:Reshape]

似乎当我序列化或解码时,添加了 6 个浮点数。(很奇怪)

标签: pythonnumpytensorflowtensorflow2.0

解决方案


尝试使用: tf.io.parse_tensor(),而不是tf.io.decode_raw().

https://www.tensorflow.org/api_docs/python/tf/io/parse_tensor


推荐阅读