首页 > 解决方案 > tf.argmax() 用于多个索引 Tensorflow

问题描述

在 Tensorflow 中,tf.argmax() 返回数组中最大元素的索引。

但是,对于多标签分类任务,返回数组中 N 个最大元素的函数会非常方便。

predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]

然后将其与地面实况进行比较,一个热编码数组

one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]

有没有这样的功能?或者类似的东西?

谢谢你的帮助

标签: pythontensorflowmultilabel-classification

解决方案


就在这里。它是tf.nn.top_k(从这里)。

您可以将其用作tf.nn.top_k(predicted_array, k=2)


推荐阅读