python-3.x - 使用gather_nd在循环中获取结果-所有输入的形状必须匹配
问题描述
我在 numpy 中有这个例子:
import numpy as np
import tensorflow as tf
a = np.array([[1, 2, 3],
[4, 5, 6],
[7, 8, 9],
[10, 11 , 12],
[13, 14, 15]])
res = np.zeros((5, 2), dtype=object)
for idx in range(0, len(a)-2, 2):
a0 = a[idx]
a1 = a[idx + 1]
a2 = a[idx + 2]
c = a0 + a1 + a2
res[idx:idx + 2] = ([idx, c])
res
array([[0, array([12, 15, 18])],
[0, array([12, 15, 18])],
[2, array([30, 33, 36])],
[2, array([30, 33, 36])],
[0, 0]], dtype=object)
我想在张量流中做到这一点:
a_tf = tf.convert_to_tensor(a)
res_tf = tf.zeros((5, 2), dtype=object)
for idx in range(0, a.shape[0]-2, 2):
a0 = tf.gather_nd(a, [idx])
a1 = tf.gather_nd(a, [idx + 1])
a2 = tf.gather_nd(a, [idx + 2])
c = a0 + a1 + a2
res = tf.gather_nd([idx, c], [idx:idx +2])
直到与计算一致为止c
。
在最后一行 ( res
) 它给了我:
res = tf.gather_nd([idx, c], [idx:idx +2])
^
SyntaxError: invalid syntax
我不确定如何收到结果。
更新
基本上,问题在于它[idx, c]
的类型是 list 并试图做 : tf.convert_to_tensor([idx, c]
,给出:
InvalidArgumentError: Shapes of all inputs must match: values[0].shape = [] != values[1].shape = [3] [Op:Pack] name: packed/
解决方案
res = tf.gather_nd([idx, c], [idx:idx +2])
语法不正确。如果你想提取索引,它应该是
res = tf.gather_nd([idx, c], range(idx, idx +2))
后者也可能会引发错误。中的索引range(idx, idx +2)
高于列表的索引[idx, c]
。
res
此外,除非使用,否则无法创建形状为 的张量ragged tensors
。这是您尝试执行的操作的可能解决方法
a_tf = tf.convert_to_tensor(a)
res_tf = tf.zeros((5, 2), dtype=object)
l = []
for idx in range(0, a.shape[0]-2, 2):
a0 = tf.gather_nd(a, [idx])
a1 = tf.gather_nd(a, [idx + 1])
a2 = tf.gather_nd(a, [idx + 2])
c = a0 + a1 + a2
helper = [idx]
helper.extend(c.numpy().tolist())
l.append(helper)
print(tf.constant(l))
推荐阅读
- c# - 将字节从一个数组复制到另一个数组
- ruby-on-rails - 有没有办法将几个 where 子句与 Ruby on Rails 中的内部 if 语句结合起来?
- javascript - nuxt 通用模式,console.log 输出文件名
- delphi - Windows 防火墙添加带有远程 IP 的路由?
- javascript - 如何使用 Reactjs 或 Javascript 转换 wav 文件中的音频块?
- regex - 捕获组的负后视
- html - Angular 网页需要刷新才能完全加载
- regex - 如何找到双字母并用三字母替换它们?
- php - woocommerce会话密钥重复
- node.js - 如何操作正则表达式命名的捕获组