首页 > 解决方案 > tensorflow One Hot Encodings 的使用

问题描述

我收到此功能的 AssertionError ......我该如何解决这个问题

def one_hot_matrix(label, depth=6):
     one_hot = tf.one_hot(label, depth, axis = 0)
     one_hot = tf.reshape(one_hot, (-1,1))
     return one_hot
def one_hot_matrix_test(target):
    label = tf.constant(1)
    depth = 4
    result = target(label, depth)
    print("Test 1:",result)
    assert result.shape[0] == depth, "Use the parameter depth"
    assert np.allclose(result, [0., 1. ,0., 0.] ), "Wrong output. Use tf.one_hot"
     label_2 = [2]
    result = target(label_2, depth)
    print("Test 2:", result)
    assert result.shape[0] == depth, "Use the parameter depth"
    assert np.allclose(result, [0., 0. ,1., 0.] ), "Wrong output. Use tf.reshape as instructed" 
    print("\033[92mAll test passed")

​</p>

标签: pythontensorflowone-hot-encoding

解决方案


大多数情况下,根据结果的形状,您会遇到断言错误。
为此,您使用

one_hot = tf.reshape(one_hot, (depth,))

推荐阅读