python - 如何在 TensorFlow 中编写 argmax 函数?
问题描述
我说的不是tf.argmax
数学意义上的 argmax,例如,给定一组离散的值,一个函数找到最大化它的值。我目前有这样的东西
input = tf.placeholder(name='input')
Qhat = do_stuff_to(input) # e.g. tf.add(input, 3)
现在我想定义另一个 TensorFlow 节点 ,max_Qhat
它将以张量数组作为其参数。它将把这些张量中的每一个都输入Qhat
并返回产生最大值的那个。我该怎么做呢?(注意我不想运行Qhat,所以不session.run
。我只想定义一个评估它的函数。)
我到目前为止的代码:
inputs = tf.placeholder(name='inputs')
max_Qhat = ???
解决方案
创建一个输入张量,比如维度[k, ...]
(这是我们“最大化”的输入张量的“数组” k
),然后计算输入张量的第一个轴/维度上的 Qhat 操作(例如使用tf.map_fn
),使其返回一个维度的函数值的张量,[k]
然后确定此返回张量的最大值。
import tensorflow as tf;
sess = tf.InteractiveSession();
# define inputs
inputs = tf.constant([
[1, 2, 3, 4, 5],
[4, 3, 6, 2, 1],
[9, 9, 9, 9, 9],
[0, 1, 3, 5, 2]
], shape = (4, 5));
# the function/op that we want to compute for each input tensor
def some_op(t):
return tf.reduce_sum(t);
# compute the function values
q = tf.map_fn(some_op, inputs);
# determine the index of the input tensor that maximizes the function
index = sess.run(tf.argmax(q, axis = 0));
maximizer = sess.run(inputs[index]);
推荐阅读
- php - php foreach 循环中的 HTML 表格,其中仅当元素等于表格列标题名称时才会填充单元格数据
- react-native - React Native - 运行并行 Axios / 获取请求
- fenics - 如何在 GMSH GUI 中创建构面?
- jquery - 绑定选择和单选按钮的问题
- pyqt5 - 如何使用 Bluez DBus API 和 Qt DBus 构建 BLE 广告
- python - 如何从数据框的列创建多行
- javascript - 我什么时候应该使用 async/await,什么时候应该使用回调 [js]?
- c# - Worker 服务,如何让 MQTT 事件订阅不中断?
- .net - 在 WSDL 中重命名 PortType(Interface name) 会导致兼容性问题?
- python - 对 ModelForm 中的同一字段使用 django-autocomplete-light 和 django-addanother