pytorch - Pytorch模型比纯python函数慢10倍
问题描述
我正在尝试将 Stillinger-Weber 势函数实现为一个 PyTorch 层。我知道这样的层对于 PyTorch 这类机器学习框架来说可能并不“理想”,但无论如何,我希望它的性能不会比纯 Python 或 Python/NumPy 的实现更差。
当我把该层作为继承自 PyTorch nn.Module 的类运行时,单次运行大约需要 27.977 +- 4.8341 秒(10 次运行的平均值)。而一个简单的 Python 类执行完全相同的代码只需 2.4446 +- 0.052723 秒(10 次运行的平均值)。
完整的脚本可以在 GitHub 上找到:https://github.com/ipcamit/temp_4_pytorch
我究竟做错了什么?
以下是pytorch的实现功能
# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines (PyTorch gets fussy if function are part of class)
@torch.jit.script
def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    """
    Two-body Stillinger-Weber term for a single pair distance.

    Inside the cutoff (``rij < cutoff``) this evaluates

        E2 = A * (B * (sigma/rij)**p - (sigma/rij)**q) * exp(sigma / (rij - cutoff))

    together with ``F``, the derivative of ``E2`` with respect to ``rij``.

    Args:
        A, B, p, q, sigma: pair parameters (tensors; TorchScript treats
            unannotated arguments as tensors).
        cutoff: pair cutoff distance.
        rij: pair distance.

    Returns:
        Tuple ``(E2, F)``. Both are zero tensors when ``rij`` is not
        strictly below ``cutoff``.
    """
    if rij < cutoff:
        sig_r = sigma / rij
        one_by_delta_r = 1.0 / (rij - cutoff)
        Bpq = (B * sig_r ** p - sig_r ** q)
        exp_sigma = torch.exp(sigma * one_by_delta_r)
        E2 = A * Bpq * exp_sigma
        # d/dr of the polynomial part plus d/dr of the exponential part,
        # both expressed with a common 1/sigma factor applied below.
        F = (q * sig_r ** (q + 1)) - p * B * sig_r ** (p + 1) - Bpq * (sigma * one_by_delta_r) ** 2
        F = F * (1. / sigma) * A * exp_sigma
        return E2, F
    # Outside the cutoff: return zeros that follow the dtype/device of the
    # input instead of a hard-coded default-dtype CPU scalar.
    return torch.zeros_like(rij), torch.zeros_like(rij)
@torch.jit.script
def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
               cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    """
    Three-body Stillinger-Weber term for one (i, j, k) triplet.

    Inside the cutoffs this evaluates

        E3 = lam * exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))
                 * (cos_beta_ikj - cos_beta0)**2

    where ``cos_beta_ikj`` comes from the law of cosines applied to the
    three distances.

    Args:
        lam, cos_beta0, gamma_ij, gamma_ik: three-body parameters.
        cutoff_ij, cutoff_ik, cutoff_jk: cutoff distances for each leg.
        rij, rik, rjk: the three pair distances of the triplet.
        dE3_dr: output buffer; its first three entries are overwritten
            in place with dE3/drij, dE3/drik and dE3/drjk.

    Returns:
        ``E3`` as a tensor; zero (and ``dE3_dr[0:3]`` zeroed) when the
        triplet is outside the cutoffs.
    """
    # The ij/ik legs use >= (not >): at r == cutoff the exponent
    # gamma/(r - cutoff) divides by zero. The term tends to zero as r
    # approaches the cutoff from below, so the boundary belongs to the
    # zero branch, matching calc_d_sw2's `rij < cutoff` convention.
    # The jk leg has no such singular term and keeps its original `>`.
    if ((rij >= cutoff_ij) or
            (rik >= cutoff_ik) or
            (rjk > cutoff_jk)):
        dE3_dr[0] = 0.0
        dE3_dr[1] = 0.0
        dE3_dr[2] = 0.0
        return torch.zeros_like(rij)

    cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
    cos_diff = cos_beta_ikj - cos_beta0
    exp_ij_ik = torch.exp(gamma_ij / (rij - cutoff_ij) + gamma_ik / (rik - cutoff_ik))

    # Derivatives of the exponent with respect to rij and rik.
    dij = - gamma_ij / (rij - cutoff_ij)**2
    dik = - gamma_ik / (rik - cutoff_ik)**2

    E3 = lam * exp_ij_ik * cos_diff ** 2

    # Derivatives of cos_beta_ikj with respect to each distance.
    dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
    dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
    dcos_drjk = (- rjk) / (rij * rik)

    dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
    dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
    dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3
@torch.jit.script
def energy_and_forces(
    nl: List[List[int]],
    elements_nl: List[List[int]],
    coords_all,
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    p: List[torch.Tensor],
    q: List[torch.Tensor],
    sigma: List[torch.Tensor],
    gamma: List[torch.Tensor],
    cutoff: List[torch.Tensor],
    lam: List[torch.Tensor],
    cos_beta0: List[torch.Tensor],
    cutoff_jk: List[torch.Tensor]
):
    """
    Accumulate Stillinger-Weber energy and per-atom force contributions.

    Each entry of ``nl`` is a list of indices into ``coords_all`` whose
    first element is the query atom and whose remaining elements are its
    neighbours; ``elements_nl`` holds the matching species codes.

    Args:
        nl: neighbour lists, one per query atom.
        elements_nl: species codes parallel to ``nl``.
        coords_all: coordinates indexed by the entries of ``nl``.
        A, B, p, q, sigma, gamma, cutoff: two-body parameter tables,
            indexed by ``elem_i + elem_j``.
        lam, cos_beta0, cutoff_jk: three-body parameter tables, indexed by
            ``2 - (elem_i + elem_j + elem_k)``.

    Returns:
        ``(energy, F)`` where ``energy`` is the total over all pairs
        (weighted by 0.5) and triplets, and ``F`` has the shape of
        ``coords_all`` and holds the accumulated contributions.
    """
    energy = torch.tensor(0.0)
    F3 = torch.zeros(3)  # scratch buffer filled in place by calc_d_sw3
    F = torch.zeros_like(coords_all)

    for i, (nli, elements) in enumerate(zip(nl, elements_nl)):
        num_elem = len(nli)
        # NOTE(review): the central atom's coordinates come from nli[0] but
        # its row in F is addressed by the loop index i; this presumably
        # relies on nl[i][0] == i. Confirm against the caller.
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = torch.norm(rij)
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum],
                                sigma[ij_sum], cutoff[ij_sum], norm_rij)
            # Accumulate (the original assigned, discarding every earlier
            # pair and triplet contribution to the total energy).
            energy = energy + 0.5 * E2
            F_comp = 0.5 * F2 / norm_rij * rij
            F[i, :] = F[i, :] + F_comp
            F[nli[j], :] = F[nli[j], :] - F_comp

            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]
            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                # Three-body term only when j and k share a species that
                # differs from the central atom's.
                if (elem_i != elem_j) and (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = torch.norm(rik)
                    rjk = xyz_k - xyz_j
                    norm_rjk = torch.norm(rjk)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]

                    E3 = calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum],
                                    norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3

                    # Distribute dE3/drij along the i-j direction.
                    F_comp = F3[0] / norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    # Distribute dE3/drik along the i-k direction.
                    F_comp = F3[1] / norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    # Distribute dE3/drjk along the j-k direction.
                    F_comp = F3[2] / norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F
# =============================================================================
class StillingerWeberLayer(nn.Module):
    """
    Stillinger-Weber single species layer for Mo and S atom for use in PyTorch model.

    ``forward`` evaluates the energy and per-atom contributions via
    ``energy_and_forces`` and folds contributions computed for padding
    atoms back onto the atoms they map to.
    """

    def __init__(self):
        super().__init__()
        self.elements = elements
        # The parameter attributes read by forward (self.A, self.B, self.p,
        # self.q, self.sigma, self.gamma, self.cutoff, self.lam,
        # self.cos_beta0, self.cutoff_jk) are not set in the visible code.
        ...

    def forward(self,
                elements: List[List[int]],
                coords: torch.Tensor,
                nl: List[List[int]],
                padding: List[int]
                ):
        """
        Compute the configuration energy and per-atom force array.

        Args:
            elements: species codes per neighbour list (parallel to ``nl``).
            coords: coordinates; rows past ``len(nl)`` are treated as
                padding atoms.
            nl: neighbour lists, one per real atom.
            padding: for each padding row, the index of the real atom it
                maps back to.

        Returns:
            ``(total_conf_energy, F)`` with ``F`` of shape ``(len(nl), 3)``.
        """
        n_atom = len(nl)
        F = torch.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A,
                                                      self.B, self.p, self.q, self.sigma, self.gamma,
                                                      self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]
        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)
            if n_atom < n_padding:
                # `padding` is a Python list; comparing a list with an int
                # yields a plain bool, so build an index tensor first to get
                # an element-wise mask.
                padding_idx = torch.as_tensor(padding, dtype=torch.long)
                for i in range(n_atom):
                    indices = torch.where(padding_idx == i)
                    F[i] = F[i] + torch.sum(pad_forces[indices], 0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================
纯python版本的实现完全相同
# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines
def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    """
    Two-body Stillinger-Weber term for a single pair distance.

    Inside the cutoff (``rij < cutoff``) this evaluates

        E2 = A * (B * (sigma/rij)**p - (sigma/rij)**q) * exp(sigma / (rij - cutoff))

    and the derivative of ``E2`` with respect to ``rij``.

    Returns:
        Tuple ``(E2, dE2/drij)``; ``(0.0, 0.0)`` when ``rij`` is not
        strictly below ``cutoff``.
    """
    # Guard clause: anything not strictly inside the cutoff contributes nothing.
    if not (rij < cutoff):
        return 0.0, 0.0

    ratio = sigma / rij
    inv_gap = 1.0 / (rij - cutoff)
    pair_poly = (B * ratio ** p - ratio ** q)
    damping = np.exp(sigma * inv_gap)

    pair_energy = A * pair_poly * damping

    # Polynomial-part derivative plus exponential-part derivative, sharing
    # the 1/sigma factor applied on the next line.
    pair_deriv = (q * ratio ** (q + 1)) - p * B * ratio ** (p + 1) - pair_poly * (sigma * inv_gap) ** 2
    pair_deriv = pair_deriv * (1. / sigma) * A * damping
    return pair_energy, pair_deriv
def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
               cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    """
    Three-body Stillinger-Weber term for one (i, j, k) triplet.

    Inside the cutoffs this evaluates

        E3 = lam * exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))
                 * (cos_beta_ikj - cos_beta0)**2

    where ``cos_beta_ikj`` comes from the law of cosines applied to the
    three distances.

    Args:
        lam, cos_beta0, gamma_ij, gamma_ik: three-body parameters.
        cutoff_ij, cutoff_ik, cutoff_jk: cutoff distances for each leg.
        rij, rik, rjk: the three pair distances of the triplet.
        dE3_dr: mutable sequence; entries 0..2 are overwritten in place
            with dE3/drij, dE3/drik and dE3/drjk.

    Returns:
        ``E3``; ``0.0`` (with ``dE3_dr[0:3]`` zeroed) when the triplet is
        outside the cutoffs.
    """
    # The ij/ik legs use >= (not >): at r == cutoff the exponent
    # gamma/(r - cutoff) divides by zero (ZeroDivisionError for Python
    # floats, inf/nan for NumPy scalars). The term tends to zero as r
    # approaches the cutoff from below, so the boundary belongs to the
    # zero branch, matching calc_d_sw2's `rij < cutoff` convention.
    # The jk leg has no such singular term and keeps its original `>`.
    if ((rij >= cutoff_ij) or
            (rik >= cutoff_ik) or
            (rjk > cutoff_jk)):
        dE3_dr[0] = 0.0
        dE3_dr[1] = 0.0
        dE3_dr[2] = 0.0
        return 0.0

    cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
    cos_diff = cos_beta_ikj - cos_beta0
    exp_ij_ik = np.exp(gamma_ij / (rij - cutoff_ij) + gamma_ik / (rik - cutoff_ik))

    # Derivatives of the exponent with respect to rij and rik.
    dij = - gamma_ij / (rij - cutoff_ij)**2
    dik = - gamma_ik / (rik - cutoff_ik)**2

    E3 = lam * exp_ij_ik * cos_diff ** 2

    # Derivatives of cos_beta_ikj with respect to each distance.
    dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
    dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
    dcos_drjk = (- rjk) / (rij * rik)

    dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
    dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
    dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3
def energy_and_forces(nl, elements_nl, coords_all, A, B, p, q, sigma, gamma,
                      cutoff, lam, cos_beta0, cutoff_jk):
    """
    Accumulate Stillinger-Weber energy and per-atom force contributions.

    Each entry of ``nl`` is a list of indices into ``coords_all`` whose
    first element is the query atom and whose remaining elements are its
    neighbours; ``elements_nl`` holds the matching species codes.

    Args:
        nl: neighbour lists, one per query atom.
        elements_nl: species codes parallel to ``nl``.
        coords_all: coordinate array indexed by the entries of ``nl``.
        A, B, p, q, sigma, gamma, cutoff: two-body parameter tables,
            indexed by ``elem_i + elem_j``.
        lam, cos_beta0, cutoff_jk: three-body parameter tables, indexed by
            ``2 - (elem_i + elem_j + elem_k)``.

    Returns:
        ``(energy, F)`` where ``energy`` is the total over all pairs
        (weighted by 0.5) and triplets, and ``F`` has the shape of
        ``coords_all`` and holds the accumulated contributions.
    """
    energy = 0.0
    F3 = np.zeros(3)  # scratch buffer filled in place by calc_d_sw3
    F = np.zeros_like(coords_all)

    for i, (nli, elements) in enumerate(zip(nl, elements_nl)):
        num_elem = len(nli)
        # NOTE(review): the central atom's coordinates come from nli[0] but
        # its row in F is addressed by the loop index i; this presumably
        # relies on nl[i][0] == i. Confirm against the caller.
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = np.sqrt(rij[0]**2 + rij[1]**2 + rij[2]**2)
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum],
                                sigma[ij_sum], cutoff[ij_sum], norm_rij)
            # Accumulate (the original assigned, discarding every earlier
            # pair and triplet contribution to the total energy).
            energy = energy + 0.5 * E2
            F_comp = 0.5 * F2 / norm_rij * rij
            F[i, :] = F[i, :] + F_comp
            F[nli[j], :] = F[nli[j], :] - F_comp

            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]
            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                # Three-body term only when j and k share a species that
                # differs from the central atom's.
                if (elem_i != elem_j) and (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = np.sqrt(rik[0]**2 + rik[1]**2 + rik[2]**2)
                    rjk = xyz_k - xyz_j
                    norm_rjk = np.sqrt(rjk[0]**2 + rjk[1]**2 + rjk[2]**2)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]

                    E3 = calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum],
                                    norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3

                    # Distribute dE3/drij along the i-j direction.
                    F_comp = F3[0] / norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    # Distribute dE3/drik along the i-k direction.
                    F_comp = F3[1] / norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    # Distribute dE3/drjk along the j-k direction.
                    F_comp = F3[2] / norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F
# =============================================================================
class StillingerWeberLayer():
    """
    Stillinger-Weber single species layer for Mo and S atom (plain Python/NumPy version).

    Calling the instance evaluates the energy and per-atom contributions
    via ``energy_and_forces`` and folds contributions computed for padding
    atoms back onto the atoms they map to.
    """

    def __init__(self):
        super().__init__()
        self.elements = elements
        # The parameter attributes read by __call__ (self.A, self.B, self.p,
        # self.q, self.sigma, self.gamma, self.cutoff, self.lam,
        # self.cos_beta0, self.cutoff_jk) are not set in the visible code.
        ...

    def __call__(self, elements, coords, nl, padding):
        """
        Compute the configuration energy and per-atom force array.

        Args:
            elements: species codes per neighbour list (parallel to ``nl``).
            coords: coordinate array; rows past ``len(nl)`` are treated as
                padding atoms.
            nl: neighbour lists, one per real atom.
            padding: for each padding row, the index of the real atom it
                maps back to (list or array).

        Returns:
            ``(total_conf_energy, F)`` with ``F`` of shape ``(len(nl), 3)``.
        """
        n_atom = len(nl)
        F = np.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A,
                                                      self.B, self.p, self.q, self.sigma, self.gamma,
                                                      self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]
        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)
            if n_atom < n_padding:
                # With a plain list, `padding == i` is just False and
                # np.where would select nothing, silently dropping all
                # padding contributions; compare element-wise on an array.
                padding_idx = np.asarray(padding)
                for i in range(n_atom):
                    indices = np.where(padding_idx == i)
                    F[i] = F[i] + np.sum(pad_forces[indices], axis=0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================
解决方案
推荐阅读
- python-3.x - 在 selenium 上使用 BeautifulSoup 和 Geckodriver 有什么区别?
- c# - 访问属性属性中的类成员
- javascript - 如何为 jQuery 事件附加一些数据以进行冒泡?
- python - 为什么压缩函数和可迭代的结果不起作用?
- javascript - 无法清除 Gatsby 静态站点缓存
- c - 如何摆脱我从给定代码中得到的分段错误?
- c# - 在 C# 中以字节存储位的更优雅方式?
- c# - 将控件绑定到可观察集合列表中的特定项目
- java - IntelliJ Idea 完成错误
- java - 如何更新列表视图的 Java(Android) 数组中的值(简单的购物清单应用程序)