首页 > 解决方案 > 如何构造与 `tf.one_hot` 返回的相同类型的对象?

问题描述

我有一个input_preprocess在数据管道中使用的函数:

def input_preprocess(image, label):
    if label == 1:
        return tf.zeros(NUM_CLASSES)
    else:
        label = tf.one_hot(label, NUM_CLASSES)
    return image, label

问题是tf.one_hot返回的东西是 asequencetf.zeros返回的不是。

我收到以下错误:

 The two structures don't have the same nested structure.
    
    First structure: type=Tensor str=Tensor("cond/zeros_like:0", shape=(28,), dtype=float32)
    
    Second structure: type=tuple str=(<tf.Tensor 'args_0:0' shape=(224, 224, 3) dtype=float32>, <tf.Tensor 'cond/one_hot:0' shape=(28,) dtype=float32>)
    
    More specifically: Substructure "type=tuple str=(<tf.Tensor 'args_0:0' shape=(224, 224, 3) dtype=float32>, <tf.Tensor 'cond/one_hot:0' shape=(28,) dtype=float32>)" is a sequence, while substructure "type=Tensor str=Tensor("cond/zeros_like:0", shape=(28,), dtype=float32)" is not
    Entire first structure:
    .
    Entire second structure:
    (., .)

我如何手动构建可以代表tf.one_hot回报的东西?

标签: pythontensorflowtensorflow2.0tensorflow-datasets

解决方案


tf.one_hot返回一个EagerTensor就像tf.zeros

a = tf.one_hot(1,2)
print(type(a))
b = tf.zeros(2)
print(type(b))
# tensorflow.python.framework.ops.EagerTensor
# tensorflow.python.framework.ops.EagerTensor

我认为您的问题是您的函数 在(第一次返回)时input_preprocess返回单个值,否则返回元组(第二次返回)。label == 1


推荐阅读