python-3.x - 我希望我的 Python 函数接受 numpy ndarrays
问题描述
我正在编写一个脚本来绘制一些热力学属性的 3D 表示,Z = f(pr, Tr)。pr 和 Tr 是用创建的[numpy.]arange()
,然后用[numpy.]meshgrid()
以下方式映射它们:
Tr = arange(1.0, 2.6, 0.10)
pr = arange(0.5, 9.0, 0.25)
Xpr, YTr = meshgrid(pr, Tr)
然后将 Xpr 和 YTr 传递给计算上述属性的函数:
Z = function(Xpr, YTr)
(“函数”只是一个通用名称,稍后将替换为实际函数名称)。
最终绘制存储在 Z 中的值:
fig = plt.figure(1, figsize=(7, 6))
ax = fig.add_subplot(projection='3d')
surf = ax.plot_surface(Xpr, YTr, Z, cmap=cm.jet, linewidth=0, antialiased=False)
当“功能”非常简单时,一切正常,例如:
def zshell(pr_, Tr_):
A = -0.101 - 0.36*Tr_ + 1.3868*sqrt(Tr_ - 0.919)
B = 0.021 + 0.04275/(Tr_ - 0.65)
E = 0.6222 - 0.224*Tr_
F = 0.0657/(Tr_ - 0.85) - 0.037
G = 0.32*exp(-19.53*(Tr_ - 1.0))
D = 0.122*exp(-11.3*(Tr_ - 1.0))
C = pr_*(E + F*pr_ + G*pr_**4)
z = A + B*pr_ + (1.0 - A)*exp(-C) - D*(pr_/10.0)**4
return z
但是当函数是这样的时候它会失败:
def zvdw(pr_, Tr_):
A = 0.421875*pr_/Tr_**2 # 0.421875 = 27.0/64.0
B = 0.125*pr_/Tr_ # 0.125 = 1.0/8.0
z = 9.5e-01
erro = 1.0
while erro >= 1.0e-06:
c2 = -(B + 1.0)
c1 = A
c0 = -A*B
f = z**3 + c2*z**2 + c1*z + c0
df = 3.0e0*z**2 + 2.0e0*c2*z + c1
zf = z - f/df
erro = abs((zf - z)/z)
z = zf
return z
我强烈怀疑失败是由函数内部的迭代方法引起的 zvdw(pr_, Tr_)
(zvdw
之前已经过测试,并且在将浮点参数传递给它时工作得很好)。那是我得到的错误信息:
Traceback (most recent call last):
File "/home/fausto/.../TestMesh.py", line 81, in <module>
Z = zvdw(Xpr, YTr)
File "/home/fausto/.../TestMesh.py", line 63, in zvdw
while erro >= 1.0e-06:
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
然而,错误消息似乎(对我来说)与while
声明没有直接关系。
有任何想法吗?
解决方案
Tl;博士
有一个numpy
功能。你只需更换
def zvdw(pr_, Tr_):
经过
@np.vectorize
def zvdw(pr_, Tr_):
它有效。
更快获得
不幸的是,生成的图片看起来很丑,因为你的网格很稀疏。TR
我用和替换了你的步长pr
。不幸的是,我们在这里遇到了np.vectorize
. 从 numpy 文档
提供矢量化功能主要是为了方便,而不是为了性能。该实现本质上是一个 for 循环。
即它非常缓慢。已经在 0.0010 和 0.0025 时花了 17.5 秒。因此,每小 10 倍是不现实的,因为这将花费大约 100 倍的时间。幸运的是,您的代码很简单,我可以@numba.vectorize
在我的机器上使用它,它的速度提高了约 23 倍。
请注意,这仅适用于某些 python 代码。Numba 将 python 代码编译为 llvm 代码,因此它可以快速运行。但它无法为任意 python 代码做到这一点。见https://numba.pydata.org/numba-doc/dev/user/vectorize.html
代码似乎不适用于np.vectorize/numba.vectorize
这很奇怪。这是在我的机器上工作的代码的文字复制粘贴:
import numpy as np
import matplotlib.pyplot as plt
Tr = np.arange(1.0, 2.6, 0.10)
pr = np.arange(0.5, 9.0, 0.25)
Xpr, YTr = np.meshgrid(pr, Tr)
def zshell(pr_, Tr_):
A = -0.101 - 0.36*Tr_ + 1.3868*sqrt(Tr_ - 0.919)
B = 0.021 + 0.04275/(Tr_ - 0.65)
E = 0.6222 - 0.224*Tr_
F = 0.0657/(Tr_ - 0.85) - 0.037
G = 0.32*exp(-19.53*(Tr_ - 1.0))
D = 0.122*exp(-11.3*(Tr_ - 1.0))
C = pr_*(E + F*pr_ + G*pr_**4)
z = A + B*pr_ + (1.0 - A)*exp(-C) - D*(pr_/10.0)**4
return z
@np.vectorize
def zvdw(pr_, Tr_):
A = 0.421875*pr_/Tr_**2 # 0.421875 = 27.0/64.0
B = 0.125*pr_/Tr_ # 0.125 = 1.0/8.0
z = 9.5e-01
erro = 1.0
while erro >= 1.0e-06:
c2 = -(B + 1.0)
c1 = A
c0 = -A*B
f = z**3 + c2*z**2 + c1*z + c0
df = 3.0e0*z**2 + 2.0e0*c2*z + c1
zf = z - f/df
erro = abs((zf - z)/z)
z = zf
return z
Z = zvdw(Xpr, YTr)
fig = plt.figure(1, figsize=(7, 6))
ax = fig.add_subplot(projection='3d')
surf = ax.plot_surface(Xpr, YTr, Z, cmap=plt.cm.jet, linewidth=0, antialiased=False)
推荐阅读
- html - Thymeleaf:未找到并应用 css 路径
- opengl - OpenGL的无损纹理压缩
- ios - SearchBar 没有将我的搜索返回到联系人
- angular - 使用 URL 或路径模式托管子域?
- django - 如何修复heroku“应用程序错误”,其中控制台错误“code = H14 desc =“没有运行Web进程”
- c - 在 C 中打印一个反转的字符串/数组
- python - 生成具有年增长率的未来数据框
- node.js - 批处理文件执行后如何自动关闭cmd窗口永久运行js后台?
- c# - 检索复合键表数据的 Linq 查询不成功
- javascript - 如何从 HTML 元素中获取换行符