首页 > 解决方案 > 尝试比较数组但得到“ValueError:具有多个元素的数组的真值不明确。使用 a.any() 或 a.all()”

问题描述

有人可以帮我理解和纠正标题中错误的短代码吗?

import matplotlib.pyplot as plt
import numpy as np

x = np.linspace(0, 1, 100)
a = 2

def f(x):
    if x<0.5:
        return a*x
    elif x>=0.5:
        return a*(1-x)

plt.plot(x, f(x))
plt.show()
      6 
      7 def f(x):
----> 8     if x<0.5:
      9         return a*x
     10     elif x>=0.5:

编辑:

该图应该看起来像一个倒置的 V。尝试 np.piecewise():

x = np.linspace(0, 1, 100)
a = 2

y = np.piecewise(x, [x < 0.5, x >= 0.5], [a*x, a*(1-x)])

plt.plot(x, y)
plt.show()

我收到错误消息:NumPy 布尔数组索引分配无法将 100 个输入值分配给掩码为真的 50 个输出值。

标签: pythonnumpymatplotlibplot

解决方案


x2 个派生数组:

In [249]: x = np.linspace(0,1,11)
In [250]: x
Out[250]: array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ])
In [251]: 2*x
Out[251]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 1.2, 1.4, 1.6, 1.8, 2. ])
In [252]: 2*(1-x)
Out[252]: array([2. , 1.8, 1.6, 1.4, 1.2, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

将两者合成的一种方法是从一个数组开始,然后用另一个数组替换值的子集:

In [253]: res = 2*(1-x)
In [254]: res[x<.5] = 2*x[x<.5]
In [255]: res
Out[255]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

np.where可以用来做同样的事情:

In [256]: np.where(x<.5, 2*x, 2*(1-x))
Out[256]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

我们可以使用您的f, 但带有标量参数:

In [261]: np.array([f(i) for i in x])
Out[261]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

要使用piecewise我们必须提供函数,而不是整个数组:

In [266]: np.piecewise(x, [x < 0.5, x >= 0.5], [lambda i:a*i, lambda i:a*(1-i)])
Out[266]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

至于你原来的错误:

In [267]: x<.5
Out[267]: 
array([ True,  True,  True,  True,  True, False, False, False, False,
       False, False])
In [268]: if (x<.5): print(1)
Traceback (most recent call last):
  File "<ipython-input-268-f1475f5112fc>", line 1, in <module>
    if (x<.5): print(1)
ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

if是一个普通的python表达式。没有隐含的元素迭代xif (x<.5).all():...运行,但仅测试 的所有元素是否x满足条件。那不是你想要的。同样对于any。我一开始做的那种x<.5掩蔽可能是一个更常见的解决方案ambiguity

更直接的转换转换f(x)是计算mask并应用它,以及它的补码以创建结果的两个部分:

In [269]: mask = x<.5
In [270]: res = np.zeros_like(x)
In [271]: res
Out[271]: array([0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.])
In [272]: res[mask] = 2*x[mask]
In [273]: res[~mask] = 2*(1-x[~mask])
In [274]: res
Out[274]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

或者

In [275]: np.hstack((2*x[mask], 2*(1-x[~mask])))
Out[275]: array([0. , 0.2, 0.4, 0.6, 0.8, 1. , 0.8, 0.6, 0.4, 0.2, 0. ])

推荐阅读