首页 > 解决方案 > Pytorch模型比纯python函数慢10倍

问题描述

我正在尝试将 Stillinger-Weber 势(Stillinger-Weber potential)实现为一个 PyTorch 层。我知道这样的层对于 PyTorch 这类机器学习框架来说可能并不"理想",但无论如何,我希望它的性能不会比纯 Python 或 Python/NumPy 的实现更差。

当我把该层实现为继承自 PyTorch nn.Module 的类并运行时,单次运行大约需要 27.977 ± 4.8341 秒(10 次运行的平均值)。而一个简单的 Python 类执行完全相同的代码只需 2.4446 ± 0.052723 秒(10 次运行的平均值)。

完整的脚本可以在 GitHub 上找到: https://github.com/ipcamit/temp_4_pytorch

我究竟做错了什么?

以下是 PyTorch 版本的实现:

# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines (PyTorch gets fussy if functions are part of a class)
@torch.jit.script
def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    """
    Two-body Stillinger-Weber term for a single pair distance.

    Returns a pair ``(E2, F)``: the pair energy and its derivative with
    respect to ``rij``. Both are freshly created zero tensors when ``rij``
    is not strictly below ``cutoff``.
    """
    if rij < cutoff:
        # Reciprocal of the (negative) gap between the distance and the cutoff.
        inv_gap = 1.0 / (rij - cutoff)
        ratio = sigma / rij
        pair_poly = (B * ratio ** p - ratio ** q)
        decay = torch.exp(sigma * inv_gap)
        radial = (q * ratio ** (q + 1)) - p * B * ratio ** (p + 1) - pair_poly * (sigma * inv_gap) ** 2
        return A * pair_poly * decay, radial * (1./sigma) * A * decay
    return torch.tensor(0.0), torch.tensor(0.0)


@torch.jit.script
def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
               cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    """
    Three-body Stillinger-Weber term for one (i, j, k) triplet.

    ``rij``, ``rik`` and ``rjk`` are the three side lengths of the triplet.
    The energy is returned; its derivatives with respect to ``rij``, ``rik``
    and ``rjk`` are written in place into ``dE3_dr[0]``, ``dE3_dr[1]`` and
    ``dE3_dr[2]``.

    When any distance is at or beyond its cutoff the energy is a zero tensor
    and the three derivative slots are set to zero.
    """
    # ``>=`` rather than ``>``: a distance exactly equal to its cutoff would
    # otherwise reach gamma / (r - cutoff) below and divide by zero. This also
    # matches calc_d_sw2, which only contributes for rij strictly < cutoff.
    if ((rij >= cutoff_ij) or
        (rik >= cutoff_ik) or
        (rjk >= cutoff_jk)):
        dE3_dr[0] = 0.0
        dE3_dr[1] = 0.0
        dE3_dr[2] = 0.0
        return torch.tensor(0.0)

    # Law of cosines: angle at atom i between the i-j and i-k bonds.
    cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
    cos_diff = cos_beta_ikj - cos_beta0

    exp_ij_ik = torch.exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))

    # Derivatives of the exponent with respect to rij and rik.
    dij = - gamma_ij/(rij - cutoff_ij)**2
    dik = - gamma_ik/(rik - cutoff_ik)**2

    E3 = lam * exp_ij_ik * cos_diff ** 2

    # Derivatives of cos_beta_ikj with respect to each side length.
    dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
    dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
    dcos_drjk = (- rjk) / (rij * rik)

    dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
    dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
    dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3


@torch.jit.script
def energy_and_forces(
    nl: List[List[int]],
    elements_nl: List[List[int]],
    coords_all,
    A: List[torch.Tensor],
    B: List[torch.Tensor],
    p: List[torch.Tensor],
    q: List[torch.Tensor],
    sigma: List[torch.Tensor],
    gamma: List[torch.Tensor],
    cutoff: List[torch.Tensor],
    lam: List[torch.Tensor],
    cos_beta0: List[torch.Tensor],
    cutoff_jk: List[torch.Tensor]
    ):
    """
    Calculate the total energy and per-atom forces from neighbour lists.

    Each entry of ``nl`` is a list of row indices into ``coords_all``; the
    first index is the query atom i and the remaining ones are its
    neighbours. ``elements_nl`` holds the integer species code of every
    index in the matching ``nl`` entry.

    Pair parameters (``A``, ``B``, ``p``, ``q``, ``sigma``, ``gamma``,
    ``cutoff``) are looked up at ``elem_i + elem_j``; three-body parameters
    (``lam``, ``cos_beta0``, ``cutoff_jk``) at
    ``2 - (elem_i + elem_j + elem_k)``.

    Returns ``(energy, F)`` where ``F`` has the same shape as ``coords_all``.
    """
    energy = torch.tensor(0.0)
    F = torch.zeros_like(coords_all)
    # Scratch buffer that calc_d_sw3 fills with dE3/d(rij), d(rik), d(rjk).
    F3 = torch.zeros(3)
    for i, (nli, elements) in enumerate(zip(nl, elements_nl)):
        num_elem = len(nli)
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = torch.norm(rij)
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum],
                                sigma[ij_sum], cutoff[ij_sum], norm_rij)
            # Accumulate: the previous ``energy = 0.5 * E2`` discarded every
            # contribution gathered so far.
            energy = energy + 0.5 * E2
            F_comp = 0.5 * F2/norm_rij * rij
            # NOTE(review): the query atom's force row is addressed with the
            # loop index ``i`` while its position comes from ``nli[0]``;
            # this presumably relies on nl[i][0] == i -- confirm with caller.
            F[i, :] = F[i, :] + F_comp
            F[nli[j], :] = F[nli[j], :] - F_comp
            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]

            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                # Three-body term only when j and k share a species that
                # differs from the species of i.
                if (elem_i != elem_j) and \
                   (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = torch.norm(rik)
                    rjk = xyz_k - xyz_j
                    norm_rjk = torch.norm(rjk)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]
                    E3 = calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum],
                                    norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3
                    # Project each radial derivative onto its bond direction
                    # and apply it with opposite signs to the two end atoms.
                    F_comp[:] = F3[0]/norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    F_comp[:] = F3[1]/norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    F_comp[:] = F3[2]/norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F


# =============================================================================
class StillingerWeberLayer(nn.Module):
    """
    Stillinger-Weber single species layer for Mo and S atom for use in PyTorch model 
    """
    def __init__(self):
        super().__init__()
        self.elements = elements
...
    def forward(self,
                elements: List[List[int]],
                coords: torch.Tensor,
                nl: List[List[int]],
                padding: List[int]
                ):
        """
        Evaluate energy and forces for one configuration.

        ``energy_and_forces`` is run on the neighbour lists; the first
        ``len(nl)`` force rows are copied into the result. Any remaining rows
        are treated as forces on padding atoms and are added onto the row
        given by the matching entry of ``padding``.

        Returns ``(total_conf_energy, F)`` with ``F`` of shape
        ``(len(nl), 3)``.
        """
        n_atom = len(nl)
        F = torch.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A,
                                            self.B, self.p, self.q, self.sigma, self.gamma,
                                            self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]

        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)

            if n_atom < n_padding:
                # ``padding`` is annotated as a plain list, for which
                # ``padding == i`` is a single Python bool; convert it to a
                # tensor so the comparison is element-wise.
                pad_index = torch.as_tensor(padding, dtype=torch.long,
                                            device=pad_forces.device)
                for i in range(n_atom):
                    F[i] = F[i] + torch.sum(pad_forces[pad_index == i], 0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================

纯 Python 版本的实现与之完全相同:

# =============================================================================
# StillingerWeber Model
# =============================================================================
# SW subroutines

def calc_d_sw2(A, B, p, q, sigma, cutoff, rij):
    """
    Two-body Stillinger-Weber term for a single pair distance.

    Returns a pair ``(E2, F)``: the pair energy and its derivative with
    respect to ``rij``. Both are ``0.0`` when ``rij`` is not strictly below
    ``cutoff``.
    """
    if not rij < cutoff:
        return 0.0, 0.0

    # Reciprocal of the (negative) gap between the distance and the cutoff.
    inv_gap = 1.0 / (rij - cutoff)
    ratio = sigma / rij
    pair_poly = (B * ratio ** p - ratio ** q)
    decay = np.exp(sigma * inv_gap)
    radial = (q * ratio ** (q + 1)) - p * B * ratio ** (p + 1) - pair_poly * (sigma * inv_gap) ** 2
    return A * pair_poly * decay, radial * (1./sigma) * A * decay


def calc_d_sw3(lam, cos_beta0, gamma_ij, gamma_ik,
               cutoff_ij, cutoff_ik, cutoff_jk, rij, rik, rjk, dE3_dr):
    """
    Three-body Stillinger-Weber term for one (i, j, k) triplet.

    ``rij``, ``rik`` and ``rjk`` are the three side lengths of the triplet.
    The energy is returned; its derivatives with respect to ``rij``, ``rik``
    and ``rjk`` are written in place into ``dE3_dr[0]``, ``dE3_dr[1]`` and
    ``dE3_dr[2]``.

    When any distance is at or beyond its cutoff the energy is ``0.0`` and
    the three derivative slots are set to zero.
    """
    # ``>=`` rather than ``>``: a distance exactly equal to its cutoff would
    # otherwise reach gamma / (r - cutoff) below and divide by zero. This also
    # matches calc_d_sw2, which only contributes for rij strictly < cutoff.
    if ((rij >= cutoff_ij) or
        (rik >= cutoff_ik) or
        (rjk >= cutoff_jk)):
        dE3_dr[0] = 0.0
        dE3_dr[1] = 0.0
        dE3_dr[2] = 0.0
        return 0.0

    # Law of cosines: angle at atom i between the i-j and i-k bonds.
    cos_beta_ikj = (rij**2 + rik**2 - rjk**2) / (2 * rij * rik)
    cos_diff = cos_beta_ikj - cos_beta0

    exp_ij_ik = np.exp(gamma_ij/(rij - cutoff_ij) + gamma_ik/(rik - cutoff_ik))

    # Derivatives of the exponent with respect to rij and rik.
    dij = - gamma_ij/(rij - cutoff_ij)**2
    dik = - gamma_ik/(rik - cutoff_ik)**2

    E3 = lam * exp_ij_ik * cos_diff ** 2

    # Derivatives of cos_beta_ikj with respect to each side length.
    dcos_drij = (rij**2 - rik**2 + rjk**2) / (2 * rij**2 * rik)
    dcos_drik = (rik**2 - rij**2 + rjk**2) / (2 * rik**2 * rij)
    dcos_drjk = (- rjk) / (rij * rik)

    dE3_dr[0] = lam * cos_diff * exp_ij_ik * (dij * cos_diff + 2 * dcos_drij)
    dE3_dr[1] = lam * cos_diff * exp_ij_ik * (dik * cos_diff + 2 * dcos_drik)
    dE3_dr[2] = lam * cos_diff * exp_ij_ik * 2 * dcos_drjk
    return E3


def energy_and_forces(nl, elements_nl, coords_all, A, B, p, q, sigma, gamma,
                      cutoff, lam, cos_beta0, cutoff_jk):
    """
    Calculate the total energy and per-atom forces from neighbour lists.

    Each entry of ``nl`` is a list of row indices into ``coords_all``; the
    first index is the query atom i and the remaining ones are its
    neighbours. ``elements_nl`` holds the integer species code of every
    index in the matching ``nl`` entry.

    Pair parameters (``A``, ``B``, ``p``, ``q``, ``sigma``, ``gamma``,
    ``cutoff``) are looked up at ``elem_i + elem_j``; three-body parameters
    (``lam``, ``cos_beta0``, ``cutoff_jk``) at
    ``2 - (elem_i + elem_j + elem_k)``.

    Returns ``(energy, F)`` where ``F`` has the same shape as ``coords_all``.
    """
    energy = 0.0
    F = np.zeros_like(coords_all)
    # Scratch buffer that calc_d_sw3 fills with dE3/d(rij), d(rik), d(rjk).
    F3 = np.zeros(3)
    for i, (nli, elements) in enumerate(zip(nl, elements_nl)):
        num_elem = len(nli)
        xyz_i = coords_all[nli[0]]
        elem_i = elements[0]
        for j in range(1, num_elem):
            elem_j = elements[j]
            xyz_j = coords_all[nli[j]]
            rij = xyz_j - xyz_i
            norm_rij = np.sqrt(rij[0]**2 + rij[1]**2 + rij[2]**2)
            ij_sum = elem_i + elem_j

            E2, F2 = calc_d_sw2(A[ij_sum], B[ij_sum], p[ij_sum], q[ij_sum],
                                sigma[ij_sum], cutoff[ij_sum], norm_rij)
            # Accumulate: the previous ``energy = 0.5 * E2`` discarded every
            # contribution gathered so far.
            energy = energy + 0.5 * E2
            F_comp = 0.5 * F2/norm_rij * rij
            # NOTE(review): the query atom's force row is addressed with the
            # loop index ``i`` while its position comes from ``nli[0]``;
            # this presumably relies on nl[i][0] == i -- confirm with caller.
            F[i, :] = F[i, :] + F_comp
            F[nli[j], :] = F[nli[j], :] - F_comp
            gamma_ij = gamma[ij_sum]
            cutoff_ij = cutoff[ij_sum]

            for k in range(j + 1, num_elem):
                elem_k = elements[k]
                # Three-body term only when j and k share a species that
                # differs from the species of i.
                if (elem_i != elem_j) and \
                   (elem_j == elem_k):
                    ijk_sum = 2 + -1 * (elem_i + elem_j + elem_k)
                    ik_sum = elem_i + elem_k
                    xyz_k = coords_all[nli[k]]
                    rik = xyz_k - xyz_i
                    norm_rik = np.sqrt(rik[0]**2 + rik[1]**2 + rik[2]**2)
                    rjk = xyz_k - xyz_j
                    norm_rjk = np.sqrt(rjk[0]**2 + rjk[1]**2 + rjk[2]**2)
                    gamma_ik = gamma[ik_sum]
                    cutoff_ik = cutoff[ik_sum]
                    E3 = calc_d_sw3(lam[ijk_sum], cos_beta0[ijk_sum], gamma_ij, gamma_ik,
                                    cutoff_ij, cutoff_ik, cutoff_jk[ijk_sum],
                                    norm_rij, norm_rik, norm_rjk, F3)
                    energy = energy + E3
                    # Project each radial derivative onto its bond direction
                    # and apply it with opposite signs to the two end atoms.
                    F_comp[:] = F3[0]/norm_rij * rij
                    F[i, :] = F[i, :] + F_comp
                    F[nli[j], :] = F[nli[j], :] - F_comp
                    F_comp[:] = F3[1]/norm_rik * rik
                    F[i, :] = F[i, :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
                    F_comp[:] = F3[2]/norm_rjk * rjk
                    F[nli[j], :] = F[nli[j], :] + F_comp
                    F[nli[k], :] = F[nli[k], :] - F_comp
    return energy, F


# =============================================================================
class StillingerWeberLayer():
    """
    Stillinger-Weber single species layer for Mo and S atom for use in PyTorch model
    """
    def __init__(self):
        super().__init__()
        self.elements = elements
...
    def __call__(self, elements, coords, nl, padding):
        """
        Evaluate energy and forces for one configuration.

        ``energy_and_forces`` is run on the neighbour lists; the first
        ``len(nl)`` force rows are copied into the result. Any remaining rows
        are treated as forces on padding atoms and are added onto the row
        given by the matching entry of ``padding``.

        Returns ``(total_conf_energy, F)`` with ``F`` of shape
        ``(len(nl), 3)``.
        """
        n_atom = len(nl)
        F = np.zeros((n_atom, 3))
        total_conf_energy, forces = energy_and_forces(nl, elements, coords, self.A,
                                            self.B, self.p, self.q, self.sigma, self.gamma,
                                            self.cutoff, self.lam, self.cos_beta0, self.cutoff_jk)
        F[:n_atom] = forces[:n_atom]

        if len(padding) != 0:
            pad_forces = forces[n_atom:]
            n_padding = len(pad_forces)

            if n_atom < n_padding:
                # For a plain list ``padding == i`` is a single Python bool,
                # which would silently select nothing; compare element-wise
                # on an array instead.
                pad_index = np.asarray(padding)
                for i in range(n_atom):
                    F[i] = F[i] + np.sum(pad_forces[pad_index == i], axis=0)
            else:
                for f, org_index in zip(pad_forces, padding):
                    F[org_index] = F[org_index] + f
        return total_conf_energy, F
# ==========================================================================================================

标签: pytorch

解决方案


推荐阅读