首页 > 解决方案 > MATLAB 转换为 Python 进行循环转换

问题描述

我正在尝试将一些代码从 MATLAB 转换为 python,但遇到了问题。MATLAB 中的 for 循环生成 (n^grid,2,2) 张量的元素,但在转换为 python 后,不会生成任何元素,并且返回一个填充为零的张量。代码如下:

import scipy as sp
import numpy as np
import scipy.special as scl
import matplotlib.pyplot as plt
import numpy.matlib

###
#Constants
###
omega  = np.array([2,2]) 
a = np.array([3,0]) 
couple = 0.05*np.array([0,3]) 
dE12 = 0.0 
###
#Derived Parameters
ws = omega**2
a11 = np.array([0,0])
a22 = a
ac=  np.array([0,0])
ac[0] = dE12/ws[0]/a[0]+a[0]/2
ac[1] = 0.0
###
#Parameters(nbas<= ngrid+1)
nbas = 54
ngrid  = 56
###
#Gauss-Hermite Quadrature
###
#Gauss-Hermite-Quadrature
#Hermite matrix
def hermipol(n):
    p = np.zeros((n+1,n+1))
    p[0][0] = 1
    if n == 0:
        p = np.array([[1,0],[2,0]], dtype = float)
    if n > 0:
        p[1][range(0,2)] = np.array([2,0])
        if n >=1:
            for k in range(2,n+1):
                p[k][range(n)] = 2*p[k-1][range(0,n)]
                p[k][range(2,n+1)] += -2*(k-1)*p[k-2][range(0,n-1)]
    for i in range(0,n+1):
        p[i,:] /= np.sqrt(np.sqrt(np.pi)*2**(i)*scl.factorial(i))

    return(p)
pp =hermipol(ngrid)
#Generation Gauss-Hermite Quadrature nodes and weights
def ghquad(n):
    return(np.polynomial.hermite.hermgauss(n))
[nodes,weights]  = ghquad(ngrid)

nodes = nodes.reshape(56,1)
weights = weights.reshape(56,1)
###
y = nodes/np.sqrt(omega[1])
x = nodes/np.sqrt(omega[0])


#Potential Matrices
xx= np.matlib.repmat(x,1,ngrid)
yy = np.matlib.repmat(y.T,ngrid,1)

V11 = (ws[0]*(xx-a11[0])**2+ws[1]*(yy-a11[1])**2)/2
V22 = (ws[0]*(xx-a22[0])**2+ws[1]*(yy-a22[1])**2)/2 + dE12
V12 = couple[0]*(xx-ac[0]) + couple[1]*(yy-ac[1])

dia2adi = np.zeros((ngrid**2,2,2))
V11 = V11.reshape(ngrid**2,1)
V22 = V22.reshape(ngrid**2,1)
V12 = V12.reshape(ngrid**2,1)

for ii in range(ngrid**2):
    mm = np.array([[V11[ii],V12[ii]],[V12[ii],V22[ii]]]).reshape(2,2)
    [e,u] = np.linalg.eigh(mm)
    e = e.reshape(2,1)
    ind = e.argsort(axis = 0)
    dia2adi[ii:ii,:,:] = u[ii:ii, ind]
print(dia2adi)

代码有点长,但重要的是最后的 for 循环。我试图重新创建的 MATLAB for 循环是:

dia2adi = zeros(2,2, ngrid^2);
for ii  = 1:ngrid^2
    mm = [V11(ii), V12(ii); V12(ii), V22(ii)];
    [u, e] = eig(mm);
    e = diag(e); [e, ind] = sort(e);
    dia2adi(:,:, ii) = u(:, ind);

非常感谢任何指导,谢谢。

标签: pythonmatlab

解决方案


推荐阅读