python - 仅获取张量一部分的对角线元素
问题描述
我有一个形状为 (?, a, a, b) 的张量元素。我想将其转换为形状 (?, a, b) 的张量,其中:
output[ i , j , k ] = input[ i , j , j , k ].
这在 numpy 中很简单,因为我可以通过循环 i、j、k 来分配元素。但是,所有操作都必须保留为 Tensorflow 中的张量,因为它需要评估成本函数和训练模型。
我已经看过tf.diag_part()但据我了解,这不能在特定轴上指定,必须对整个张量进行。
解决方案
因为,就像你说的那样,tf.diag_part
不允许轴,它在这里似乎没有用。这是一种可能的解决方案tf.gather_nd
:
import tensorflow as tf
import numpy as np
# Input data
inp = tf.placeholder(tf.int32, [None, None, None, None])
# Read dimensions
s = tf.shape(inp)
a, b, c = s[0], s[1], s[3]
# Make indices for gathering
ii, jj, kk = tf.meshgrid(tf.range(a), tf.range(b), tf.range(c), indexing='ij')
idx = tf.stack([ii, jj, jj, kk], axis=-1)
# Gather result
out = tf.gather_nd(inp, idx)
# Test
with tf.Session() as sess:
inp_val = np.arange(36).reshape(2, 3, 3, 2)
print(inp_val)
# [[[[ 0 1]
# [ 2 3]
# [ 4 5]]
#
# [[ 6 7]
# [ 8 9]
# [10 11]]
#
# [[12 13]
# [14 15]
# [16 17]]]
#
#
# [[[18 19]
# [20 21]
# [22 23]]
#
# [[24 25]
# [26 27]
# [28 29]]
#
# [[30 31]
# [32 33]
# [34 35]]]]
print(sess.run(out, feed_dict={inp: inp_val}))
# [[[ 0 1]
# [ 8 9]
# [16 17]]
#
# [[18 19]
# [26 27]
# [34 35]]]
这里有几个替代版本。一种使用张量代数。
inp = tf.placeholder(tf.int32, [None, None, None, None])
b = tf.shape(inp)[1]
eye = tf.eye(b, dtype=inp.dtype)
inp_masked = inp * tf.expand_dims(eye, 2)
out = tf.tensordot(inp_masked, tf.ones(b, inp.dtype), [[2], [0]])
还有一个使用布尔掩码:
inp = tf.placeholder(tf.int32, [None, None, None, None])
s = tf.shape(inp)
a, b, c = s[0], s[1], s[3]
mask = tf.eye(b, dtype=tf.bool)
inp_mask = tf.boolean_mask(inp, tf.tile(tf.expand_dims(mask, 0), [a, 1, 1]))
out = tf.reshape(inp_mask, [a, b, c])
编辑:我对这三种方法进行了一些时间测量:
import tensorflow as tf
import numpy as np
def f1(inp):
s = tf.shape(inp)
a, b, c = s[0], s[1], s[3]
ii, jj, kk = tf.meshgrid(tf.range(a), tf.range(b), tf.range(c), indexing='ij')
idx = tf.stack([ii, jj, jj, kk], axis=-1)
return tf.gather_nd(inp, idx)
def f2(inp):
b = tf.shape(inp)[1]
eye = tf.eye(b, dtype=inp.dtype)
inp_masked = inp * tf.expand_dims(eye, 2)
return tf.tensordot(inp_masked, tf.ones(b, inp.dtype), [[2], [0]])
def f3(inp):
s = tf.shape(inp)
a, b, c = s[0], s[1], s[3]
mask = tf.eye(b, dtype=tf.bool)
inp_mask = tf.boolean_mask(inp, tf.tile(tf.expand_dims(mask, 0), [a, 1, 1]))
return tf.reshape(inp_mask, [a, b, c])
with tf.Graph().as_default():
inp = tf.constant(np.arange(100 * 300 * 300 * 10).reshape(100, 300, 300, 10))
out1 = f1(inp)
out2 = f2(inp)
out3 = f3(inp)
with tf.Session() as sess:
v1, v2, v3 = sess.run((out1, out2, out3))
print(np.all(v1 == v2) and np.all(v1 == v3))
# True
%timeit sess.run(out1)
# CPU: 1 ms ± 138 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# GPU: 1.04 ms ± 93.7 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sess.run(out2)
# CPU: 1.17 ms ± 150 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# GPU: 734 ms ± 17.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
%timeit sess.run(out3)
# CPU: 1.11 ms ± 172 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)
# GPU: 1.41 ms ± 1.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
似乎所有三个在 CPU 上都相似,但第二个在我的 GPU 中由于某种原因速度较慢。虽然不确定浮点值的结果是什么。此外,您可以尝试用 替换tf.tensordot
,tf.einsum
例如。关于第一个和第二个,它们看起来都很好,尽管如果您通过这些操作进行反向传播,计算梯度的成本可能会有所不同。
推荐阅读
- mysql - 获取 mysql.connector.errors.ProgrammingError: 1045 (28000): Access denied for user 'vsearch'@'localhost' (using password: YES) with python
- javascript - JS数组合并求和
- javascript - 有没有一种方便的方法来引用 Svelte 组件中的 DOM 元素?
- javascript - 如何显示用户选择上传的图像
- java - Java 中循环技巧的 Helloworld 示例
- c++ - 从 txt 文件中删除注释
- ios - 如何在 UITableView 中执行网络调用
- spring - 如何使用 Spring Boot 资源服务器拥有未经身份验证的端点
- python - 如何在 Python 中读取和写入 config.ini 文件?
- java - 使用 java 和 jsoup从网站上的标签中提取 src 值