python - 使用 tf.while_loop 对张量进行切片
问题描述
只要仍有一些列使用tf.while_loop
.
注意:我使用这种方式是因为我无法在图形构建时(没有会话)循环占位符中的值,该值被视为张量而不是整数。
[ 5 7 8 ]
[ 7 4 1 ] =>
[5 7 ] [ 7 8 ]
[7 4 ] [ 4 1 ]
这是我的代码:
i = tf.constant(0)
result = tf.subtract(tf.shape(f)[1],1)
c = lambda result : tf.greater(result, 0)
b = lambda i: [i+1, tf.slice(x, [0,i],[2, 3])]
o= tf.while_loop(c, b,[i])
with tf.Session() as sess:
print (sess.run(o))
但是,我收到此错误:
ValueError: The two structures don't have the same nested structure.
First structure: type=list str=[<tf.Tensor 'while_2/Identity:0' shape=() dtype=int32>]
Second structure: type=list str=[<tf.Tensor 'while_2/add:0' shape=() dtype=int32>, <tf.Tensor 'while_2/Slice:0' shape=(2, 3) dtype=int32>]
我想每次都返回子张量
解决方案
您的代码有几个问题:
您没有传递任何结构/张量来接收
tf.slice(...)
. 您lambda
b
应该有一个签名,例如lambda i, res : i+1, ...
通过 a 编辑的张量
tf.while_loop
应该具有固定的形状。如果你想建立一个循环来收集切片,那么你应该首先res
用适当的形状初始化张量以包含所有切片值,例如res = tf.zeros([result, 2, 2])
.
注意:关于您的特定应用程序(收集所有相邻列对),这可以在没有tf.while_loop
:
import tensorflow as tf
x = tf.convert_to_tensor([[ 5, 7, 8, 9 ],
[ 7, 4, 1, 0 ]])
num_rows, num_cols = tf.shape(x)[0], tf.shape(x)[1]
# Building tuples of neighbor column indices:
n = 2 # or 5 cf. comment
idx_neighbor_cols = [tf.range(i, num_cols - n + i) for i in range(n)]
idx_neighbor_cols = tf.stack(idx_neighbor_cols, axis=-1)
# Finally gathering the column pairs accordingly:
res = tf.transpose(tf.gather(x, idx_neighbor_cols, axis=1), [1, 0, 2])
with tf.Session() as sess:
print(sess.run(res))
# [[[5 7]
# [7 4]]
# [[7 8]
# [4 1]]
# [[8 9]
# [1 0]]]
推荐阅读
- postgresql - 如何在 CentOS 上升级 PostgreSQL 的次要版本
- python - 如何找到没有孔的 3D 对象的整个外壳?
- grafana - 如何使用 grafana 显示所有动态指标?
- php - 这似乎是一种允许用户公开共享数据的安全方式吗?
- c++ - Linux 上的 .NET Core - 元帅结构
- sql - 如何在sql中获取最近3天的数据
- sql-server - FreeTDS 和 unixodbc 安装和配置
- angular - 如何在 TypeScript 中使用附加属性扩展 fabric.Object
- javascript - 类型上不存在属性“defaultProps”
- python - 替换 Pandas 数据框中的值