python - tf.argmax() 用于多个索引 Tensorflow
问题描述
在 Tensorflow 中,tf.argmax() 返回数组中最大元素的索引。
但是,对于多标签分类任务,返回数组中 N 个最大元素的函数会非常方便。
predicted_array: [0.4, 0.6, 0.7, 0.2, 0.9]
tf.something(predicted_array, N = 2): [2,4]
然后将其与地面实况进行比较,一个热编码数组
one_hot_array: [0, 0, 1, 0, 1]
tf.something(one_hot_array, N = 2): [2,4]
有没有这样的功能?或者类似的东西?
谢谢你的帮助
解决方案
就在这里。它是tf.nn.top_k
(从这里)。
您可以将其用作tf.nn.top_k(predicted_array, k=2)
推荐阅读
- python - 如何在 ubuntu 18.04 64bit 上安装 pycurl 7.43
- android - java.lang.NullPointerException:尝试在空对象引用上调用虚拟方法“void android.view.ViewGroup.addView(android.view.View)”
- javascript - Paypal API 集成错误:使用“details.seller_receivable_breakdown.gross_amount”表示未定义
- java - 在Android 10中请求位置坐标权限的正确方法
- typescript - 如何将用户状态设置为空?
- string - 如何将字符串值转换为数字 - Google 跟踪代码管理器
- html - “[HMR] 等待来自 WDS 的更新信号……”究竟是什么意思
- extjs - 启用按钮时不显示工具提示
- neo4j - 什么时候应该使用推断的关系和节点而不是显式的关系和节点?
- c# - AWS Lambda 不会返回查询