python - 尝试比较数组但得到“ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()”
问题描述
有人可以帮我理解和纠正标题中错误的短代码吗?
import matplotlib.pyplot as plt
import numpy as np
x = np.linspace(0, 1, 100)
a = 2
def f(x):
if x<0.5:
return a*x
elif x>=0.5:
return a*(1-x)
plt.plot(x, f(x))
plt.show()
6
7 def f(x):
----> 8 if x<0.5:
9 return a*x
10 elif x>=0.5:
编辑:
该图应该看起来像一个倒置的 V。尝试 np.piecewise():
x = np.linspace(0, 1, 100)
a = 2
y = np.piecewise(x, [x < 0.5, x >= 0.5], [a*x, a*(1-x)])
plt.plot(x, y)
plt.show()
我收到错误消息:NumPy 布尔数组索引分配无法将 100 个输入值分配给掩码为真的 50 个输出值。
解决方案
和x
2 个派生数组:
In [249]: x = np.linspace(0,1,11)
In [250]: x
Out[250]: array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])
In [251]: 2*x
Out[251]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 1.2, 1.4, 1.6, 1.8, 2. ])
In [252]: 2*(1-x)
Out[252]: array([2. , 1.8, 1.6, 1.4, 1.2, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
将两者合成的一种方法是从一个数组开始,然后用另一个数组替换值的子集:
In [253]: res = 2*(1-x)
In [254]: res[x<.5] = 2*x[x<.5]
In [255]: res
Out[255]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
np.where
可以用来做同样的事情:
In [256]: np.where(x<.5, 2*x, 2*(1-x))
Out[256]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
我们可以使用您的f
, 但带有标量参数:
In [261]: np.array([f(i) for i in x])
Out[261]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
要使用piecewise
我们必须提供函数,而不是整个数组:
In [266]: np.piecewise(x, [x < 0.5, x >= 0.5], [lambda i:a*i, lambda i:a*(1-i)])
Out[266]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
至于你原来的错误:
In [267]: x<.5
Out[267]:
array([ True, True, True, True, True, False, False, False, False,
False, False])
In [268]: if (x<.5): print(1)
Traceback (most recent call last):
File "<ipython-input-268-f1475f5112fc>", line 1, in <module>
if (x<.5): print(1)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
if
是一个普通的python表达式。没有隐含的元素迭代x
。 if (x<.5).all():...
运行,但仅测试 的所有元素是否x
满足条件。那不是你想要的。同样对于any
。我一开始做的那种x<.5
掩蔽可能是一个更常见的解决方案ambiguity
。
更直接的转换转换f(x)
是计算mask
并应用它,以及它的补码以创建结果的两个部分:
In [269]: mask = x<.5
In [270]: res = np.zeros_like(x)
In [271]: res
Out[271]: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [272]: res[mask] = 2*x[mask]
In [273]: res[~mask] = 2*(1-x[~mask])
In [274]: res
Out[274]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
或者
In [275]: np.hstack((2*x[mask], 2*(1-x[~mask])))
Out[275]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])
推荐阅读
- mysql - MySQL - SSL 是必需的,但服务器不支持它
- r - 使用 xelatex 在 PDF 中包含图形
- vb.net - 一个图片框是否有可能检测到另一个图片框中的颜色并避免它?
- java - 注册表单在提交时引发内部服务器错误
- r - 如何使用 R 中的多元回归预测新随机生成数据集的新变量?
- containers - 输入在 slurm 作业中运行的奇异容器
- oracle - Oracle SQL Developer 调试器自动提交问题
- bbedit - BBEdit 的排水沟区域中的那些横向“L”是什么意思?
- python - 不适合矢量化的循环。我们可以让它更快吗?
- php - PHP - 清理主题标签的输入,允许阿拉伯语、希伯来语、日语等和表情符号?