首页 > 解决方案 > 如何从张量中获取前 3 个最大数字

问题描述

我如何从中获得前 3 个最大数字y_classe = tf.argmax(preds, axis=1, output_type=tf.int32)

标签: pythontensorflowtensor

解决方案


您可以使用tf.math.top_k

import tensorflow as tf

y_pred = [[-18.6, 0.51, 2.94, -12.8]]

max_entries = 3

values, indices = tf.math.top_k(y_pred, k=max_entries)
print(values)
print(indices)
tf.Tensor([[  2.94   0.51 -12.8 ]], shape=(1, 3), dtype=float32)
tf.Tensor([[2 1 3]], shape=(1, 3), dtype=int32)

推荐阅读