numpy - 出乎意料的tenorflow(1.12)广播效率
问题描述
在使用 tensorflow 实现像带有潜在向量的 FM 这样的模型时,我遇到了意想不到的乘法广播效率问题:相同的数组在 diff 转置后乘法成本非常不同。
但是,使用 numpy,效率差异并不显着。
那么,tensorflow 1.12 和 numpy 之间是否存在一些广播规则差异?
PS:tf1.14 & tf2,工作正常,谁知道哪个重要的更新解决了这个问题?
简单代码:
[b, d, n, 1] * [n, k] # ok
[b, n, d, 1] * [n, 1, k] # slow
[b, n, 1, d] * [n, k, 1] # very very slow
整个代码:
import tensorflow as tf
import numpy as np
from time import time
import timeit
sess = tf.InteractiveSession()
batch_size = 1024
k = 8
d = 32 # emb_size
n = 20 # slot count
input_var = np.random.randn(batch_size, n, d, 1)
v_nk = np.random.randn(n, k)
v_nkd = np.random.randn(n, k, d)
v_nk1 = np.reshape(v_nk, [n, k, 1])
v_n1k = np.reshape(v_nk, [n, 1, k])
input_var_bdn1 = np.transpose(input_var, [0, 2, 1, 3]).copy() # [b, d, n, 1]
input_var_bn1d = np.transpose(input_var, [0, 1, 3, 2]).copy() # [b, n, 1, d]
input_var_b1nd = np.transpose(input_var, [0, 3, 1, 2]).copy() # [b, 1, n, d]
# numpy
print('with numpy: ')
print ("X_nk COST: ",timeit.timeit(lambda: input_var_bdn1 * v_nk, number=100 )) # 3.1s
print ("X_nk1 COST: ",timeit.timeit(lambda: input_var_bn1d * v_nk1, number=100 )) # 2.5s
print ("X_n1k COST: ",timeit.timeit(lambda: input_var * v_n1k, number=100 )) # 3.0s
print ("X_nkd COST: ",timeit.timeit(lambda: input_var_bn1d * v_nkd, number=100 )) # 2.5s
input_var = tf.constant(input_var)
input_var_bdn1 = tf.constant(input_var_bdn1)
input_var_bn1d = tf.constant(input_var_bn1d)
input_var_b1nd = tf.constant(input_var_b1nd)
v_nk = tf.constant(v_nk)
v_nk1 = tf.constant(v_nk1)
v_n1k = tf.constant(v_n1k)
v_nkd = tf.constant(v_nkd)
input_X_nk = input_var_bdn1 * v_nk
input_X_n1k = input_var * v_n1k
input_X_nk1 = input_var_bn1d * v_nk1
input_X_nkd = input_var_bn1d * v_nkd
print()
print('with tf: ')
print ("X_nk COST: ",timeit.timeit(lambda: sess.run(input_X_nk), number=100 )) # 0.2s
print ("X_nk1 COST: ",timeit.timeit(lambda: sess.run(input_X_nk1), number=100 )) # 2.2s
print ("X_n1k COST: ",timeit.timeit(lambda: sess.run(input_X_n1k), number=100 )) # 0.6s
print ("X_nkd COST: ",timeit.timeit(lambda: sess.run(input_X_nkd), number=100 )) # 0.55s
for _ in range(10):
input_X_nk += input_var_bdn1 * v_nk
input_X_n1k += input_var * v_n1k
input_X_nk1 += input_var_bn1d * v_nk1
input_X_nkd += input_var_bn1d * v_nkd
print()
print('with tf straightly: ')
print ("X_nk COST: ",timeit.timeit(lambda: sess.run(input_X_nk), number=1 )) # 0.8s
print ("X_nk1 COST: ",timeit.timeit(lambda: sess.run(input_X_nk1), number=1 )) # 6.1s
print ("X_n1k COST: ",timeit.timeit(lambda: sess.run(input_X_n1k), number=1 )) # 1.7s
print ("X_nkd COST: ",timeit.timeit(lambda: sess.run(input_X_nkd), number=1 )) # 1.6s
测试环境:<br> tensorflow 1.12.0 ( tf1.14/tf2 工作正常)
python 3.6.9 && py 2.7.18
Centos 7.4 && Mac 10.14
解决方案
推荐阅读
- python - ValueError: 错误的输入形状 (1, 4)
- javascript - 如何查看 Vite 项目中的公共目录进行热重载?
- python - 在 Python 中实现均值差异的置换测试(链接书中描述的简单过程)
- python - 如何使用 pyarrow 查询镶木地板文件
- html - 无法向仅使用 CSS 创建的菱形添加线条
- class - 如何设置我的 groovy 类,以便我可以在其他 groovy 脚本中使用它的实例
- firebase - 通过 REST 实现 Firebase 实时数据库 OnSnapShot 功能?
- docker - 如何将机密作为环境变量传递到 Docker Github Action 中?
- arrays - 将 unsigned char 数组拆分为多个 unsigned char 数组的数组
- mysql - 如何知道已使用哪种算法加密表中的现有密码