python - 使用 tf.roll 增加移位值和堆栈
问题描述
我有以下情况。我有一个i
坐标列表,(x, y, z)
必须计算截止球内的所有三元组,使得r_ij
和r_ik
小于截止值。因此,我正在计算一个r_ij
包含所有距离的矩阵。要计算三元组,我的想法是构造一个r_ijk
矩阵。
我已经通过对元素数量的循环来完成此i
操作
import tensorflow as tf
r_ij = tf.reshape(tf.range(4*4), (4, 4))
r_ijk = []
for i in range(len(x)):
r_ijk.append(tf.roll(r_ij, shift=-i, axis=1))
tf.stack(r_ijk)
由于两个问题,我想改进此代码。主要是因为我假设它可以完全矢量化。但也要在我的模型中使用它,我需要更改它:
@tf.function
def get_triplets(full_r_ij, r_cut):
r_ij = tf.norm(full_r_ij, axis=-1) # Shape of full_r_ij is (n_timesteps, n_atoms, n_atoms, 3)
n_atoms = tf.shape(r_ij)[1]
r_ijk = r_ij[None]
for atom in range(1, n_atoms):
tf.autograph.experimental.set_loop_options(
shape_invariants=[(r_ijk, tf.TensorShape([None, None, None, None]))]
)
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = tf.concat([r_ijk, tmp[None]], axis=0) # shape is (n_atoms, n_timesteps, n_atoms, n_atoms)
r_ijk = tf.transpose(r_ijk, perm=(1, 0, 2, 3))
r_ijk = tf.where(r_ijk == 0, tf.ones_like(r_ijk) * r_cut, r_ijk)
intermediate_indices = tf.where(
tf.math.logical_and(r_ijk[:, 0, None] == 3.0, r_ijk[:, 1:] == 3.0)
)
n_atoms = tf.cast(n_atoms, dtype=tf.int64)
t, n, i, j = tf.unstack(intermediate_indices, axis=1)
k = j + n + 1
k = tf.where(k >= n_atoms, k - n_atoms, k)
triples = tf.stack([t, i, j, k], axis=1)
return triples
并使用tf.autograph.experimental.set_loop_options
,因为我有点循环 r_ij 张量。有没有办法改进第一个代码示例(或第二个)?
解决方案
我测试了两个进一步的实现tf.vectorized_mad
,tf.map_fn
它们的性能都比我写的初始函数差。所有测试均使用r_ij = tf.random.normal((32, 150, 150))
@tf.function
def roll_loop(r_ij, n_atoms):
r_ijk = r_ij[None]
for atom in range(1, n_atoms):
tf.autograph.experimental.set_loop_options(
shape_invariants=[(r_ijk, tf.TensorShape([None, None, None, None]))]
)
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = tf.concat([r_ijk, tmp[None]], axis=0) # shape is (n_atoms, n_timesteps, n_atoms, n_atoms)
return r_ijk
花了129 ms ± 1.98 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
@tf.function
def roll_vect(r_ij, n_atoms):
r_ijk = tf.repeat(r_ij[None], repeats=n_atoms, axis=0)
def roll(args):
x, shift = args
return tf.roll(x, shift=shift, axis=2)
return tf.vectorized_map(roll, [r_ijk, tf.range(n_atoms)])
花了225 ms ± 15.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
@tf.function
def roll_map(r_ij, n_atoms):
r_ijk = tf.repeat(r_ij[None], repeats=n_atoms, axis=0)
def roll(args):
x, shift = args
return tf.roll(x, shift=shift, axis=2)
return tf.map_fn(roll, (r_ijk, tf.range(n_atoms)), fn_output_signature=tf.float32)
花了327 ms ± 18 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
因此,似乎tf.function
使用 python for 循环是最快的(到目前为止)。所有函数都在测试前编译。
编辑:使用tf.TensorArray
似乎是完成这项任务的最佳方式。我用几个不同的输入对其进行了测试,它的性能与tf.autograph.experimental.set_loop_options
@tf.function
def roll_loop(r_ij, n_atoms):
r_ijk = tf.TensorArray(tf.float32, size=n_atoms)
for atom in range(0, n_atoms):
tmp = tf.roll(r_ij, shift=-atom, axis=2)
r_ijk = r_ijk.write(atom, tmp)
return r_ijk.stack()
花了128 ms ± 1.33 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
推荐阅读
- c - 错误:“int”类型的参数与“uint8_t *”类型的参数不兼容
- .net-core - HttpClient.PutAsync 返回 400 Bad Request (.Net Core / Blazor)
- python - Python 组列表到子列表的列表是单调的,元素之间的差异相等
- python - 从 Arduino 绘制实时串行数据,除非手动关闭,否则绘图不会更新?
- python - numpy数组的唯一标识符?
- php - 如何在管道到 PHP 进程后保持输入流打开?
- mongodb - 用于 Mongodb 慢查询的 Logstash Grok 过滤器
- mysql - 如果 id 在另一个表中,Mysql 用值更新列
- reactjs - tsdx 和 babel 无法使用意外令牌构建典型的 TypeScript
- android - 如何在离子+反应+电容器项目中设置最低android版本支持?