首页 > 解决方案 > 如何自己实现 tf.argmax?

问题描述

我想使用一个函数,该函数将张量作为输入并返回跨张量轴的最大值的索引。我知道存在一个功能完全相同的函数 tf.argmax(),但是我如何自己实现它(如果实现一些自定义函数,这可能是必要的)?

现在让我们假设该函数仅将一维张量作为输入。因此,该函数需要具有以下签名:

argmax(
    input, #input is a 1D tensor
    name=None
)

我尝试以这种方式实现它:

def argmax(input, name=None):
    maxValue=0
    maxIndex=0
    for i in range(input.get_shape()[0]):
        if input[i]>maxValue:
            maxValue=input[i]
            maxIndex=i
    return maxIndex

但是这不起作用,因为在构建阶段,这些值尚未初始化,因此我无法像在上面的代码中那样比较两个值。那么,有没有一种方法可以让我们写出自定义函数,如 tf.argmax、tf.equal 等?

标签: pythontensorflowmachine-learning

解决方案


Well, one simple way would be this:

idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]

Example:

import tensorflow as tf

with tf.Session() as sess:
    input = tf.constant([1, 3, 4, 2, 1, 2])
    idx = tf.where(tf.equal(input, tf.reduce_max(input)))[0, 0]
    print(sess.run(idx))

Output:

2

推荐阅读