python - tf.linalg.eigh 在 GPU 上非常慢 - 正常吗?
问题描述
所以我刚刚找到了导致我在 GPU 上的代码变慢的罪魁祸首:tf.linalg.eigh()
.
这个想法很简单:我创建 - 比方说 - 87.000 个 4x4 Hermitian 矩阵,并希望获得它们的特征值和特征向量。为此,我有一个matrix
形状为 [87.000,4,4] 的占位符,我将其输入tf.linalg.eigh(matrix)
. 我运行 Session 并将这些矩阵作为输入提供(矩阵的数据类型为 complex64),并希望输出特征值。
这需要 8 核 CPU 不到 0.04 秒,而 GPU 需要 19 秒 - NumPy 大约需要 0.4 秒。
所以我的问题是:为什么tf.linalg.eigh()
GPU 上的速度如此之慢,即使一个提供了大批量。即使一个矩阵的对角化不能有效地并行化,在数千个矩阵的情况下,GPU 仍然应该快得多......
可以以某种方式解决此问题,还是我必须从 GPU 切换到 CPU 才能执行此操作?
到代码:
进口
import numpy as np
from matplotlib.ticker import LinearLocator, FormatStrFormatter
import tensorflow as tf
config = tf.ConfigProto(device_count = {'GPU': 1})
sess = tf.Session(config=config)
import time
tf 部件的构建
matrix=tf.placeholder(tf.complex64,shape[None,87,4,4],name="matrix")
eigenval,eigenvec=tf.linalg.eigh(tf.linalg.adjoint(matrix))
init = tf.global_variables_initializer()
sess.run(init)
complex_matrix=np.ones((10000,87,4,4))+1j*np.ones((batch_net,path_length,num_orbits,num_orbits))
运行操作并测量时间
t1=time.time()
sess.run(eigenvec,feed_dict={matrix: complex_matrix, eigenvalues_true: eigenvalues })
print(time.time()-t1)
解决方案
经过一些试验,我认为在这种情况下,最好将此操作放在 CPU 上。关键是 PCI-GPU 通信是这里的一个瓶颈,所以你根本无法获得良好的 GPU 利用率。尽管可以通过在 GPU 上使用 TF op 生成随机矩阵来减小此开销
with tf.device('/device:GPU:0'):
matrix = tf.random.uniform((87000,4,4), minval=0.1, maxval=0.99, dtype=tf.float32)
eigenval,eigenvec=tf.linalg.eigh(matrix)
它只允许在我的系统上减少大约 40% 的计算时间,这仍然比 CPU 慢得多。您也可以尝试将张量分成相等的块,执行linalg.eigh
和连接结果,但这也几乎没有任何改进
matrix = tf.random.uniform((87000,4,4), minval=0.1, maxval=0.99, dtype=tf.float32)
result = tf.concat([tf.linalg.eigh(x)[1] for x in tf.split(matrix, 1000, axis=0)], axis=0)
我还注意到,linalg.eigh
在 CPU 上执行的缩放是近似对数的,而 GPU 操作似乎是线性的。希望这可以帮助!
一点更新。看起来SelfAdjointEigV2
XLA 编译器甚至不支持操作,所以这段代码
matrix = tf.random.uniform((87000, 4, 4), minval=0.1, maxval=0.99, dtype=tf.float32)
def xla_test(matrix):
eigenval, eigenvec = tf.linalg.eigh(matrix)
return eigenvec
y = xla.compile(xla_test, inputs=[matrix])
抛出“检测到不支持的操作”错误
推荐阅读
- azure - 部署后 Web App 未加载
- .net - 带有发布配置的 SourceLink
- crystal-reports - 您可以在晶体中进行行数以选择显示数据的位置吗?
- scala - 防止 log4j 评估比全局更高级别的跟踪
- signtool - Signtool:自 Windows 10 更新 1803 以来:未找到符合所有给定条件的证书
- node.js - 使用打字稿设置 Firebase 服务帐户时出错
- c# - 如何使用 LINQ 返回包含 List<> 中最新项目的 List<>
- c# - 巨大的内存消耗(MVC 5 + EntityFramework 6.2.0)
- go - go get 的 TLS 问题
- php - PHP父方法覆盖不起作用(使用命名空间)