首页 > 解决方案 > 使用 tf.cond() 时,Tensorflow 报告“TypeError:预期单个张量时的张量列表”

问题描述

我正在使用 Tensorflow 编写模型。我的条件语句的一部分,例如:

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: tf.constant([1, src_shape[0]]))

并且src_shape是 的结果tf.shape()

它报告TypeError: List of Tensors when single Tensor expected。我知道这是因为tf.constant([1, src_shape[0]])是张量列表,但我不知道如何以合法的方式实现我的代码。

我试图删除tf.constant()喜欢

new_shape = tf.cond(tf.equal(tf.shape(src_shape)[0], 2), lambda: src_shape, lambda: [1, src_shape[0]])

但它报告ValueError: Incompatible return values of true_fn and false_fn: The two structures don't have the same nested structure.

标签: pythontensorflow

解决方案


一种方法是使用 tf.stack,它将 rank-R 张量列表堆叠成一个 rank-(R+1) 张量。

lambda: tf.stack([1, src_shape[0]], axis=0)

另一种解决方案是使用 tf.concat 使用正确的 tf.reshape 命令。


推荐阅读