首页 > 解决方案 > 如何在 tf.Tensor 中找到最大值?

问题描述

如何找到每个元素中的最大值以便得到 2、4、6、8?

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

我尝试了以下代码:

tf.reduce_max(a, keepdims=True)

但这只是给了我 8 作为输出,而忽略了其余的。

标签: pythontensorflowtensorflow2.0

解决方案


您必须axis像这样将参数更改为 -1:

import tensorflow as tf
a = tf.constant([
  [[1, 2]], [[3, 4]],
  [[5, 6]], [[7, 8]]])

print(tf.reduce_max(a, axis=-1, keepdims=False))

'''
tf.Tensor(
[[2]
 [4]
 [6]
 [8]], shape=(4, 1), dtype=int32)
'''

因为你有一个 3D 张量并且想要访问最后一个维度。


推荐阅读