首页 > 解决方案 > 如何在特定条件下将 nan 替换为 1

问题描述

在我的程序中,我有一些数组,我使用一个简单的公式来计算一个值。我正在使用的代码

from itertools import combinations
import numpy as np

res = [
    np.array([[12.99632095], [29.60571445], [-1.85595153], [68.78926787], [2.75185088], [2.75204384]]),
    np.array([[15.66458062], [0], [-3.75927882], [0], [2.30128711], [197.45459974]]),
    np.array([[10.66458062], [0], [0], [-2.65954113], [-2.30128711], [197.45459974]]),
]


def cal():
    pairs = combinations(res, 2)
    results = []
    for pair in pairs:
        r = np.concatenate(pair, axis=1)
        r1 = r[:, 0]
        r2 = r[:, 1]
        sign = np.sign(r1 * r2)
        result = np.multiply(sign, np.min(np.abs(r), axis=1) / np.max(np.abs(r), axis=1))

        results.append(result)
    return results

我得到的输出是

RuntimeWarning: invalid value encountered in true_divide
  result = np.multiply(sign, np.min(np.abs(r), axis=1) / np.max(np.abs(r), axis=1))
[array([0.82966287, 0.        , 0.49369882, 0.        , 0.83626883,
       0.0139376 ]), array([ 0.82058458,  0.        ,  0.        , -0.03866215, -0.83626883,
        0.0139376 ]), array([ 0.68080856,         nan,  0.        ,  0.        , -1.        ,
        1.        ])]

我在这里,我正在获取nan第三个输出数组。我明白了,我是nan因为0/0.

由于数组的大小或位置0不固定。所以,我想以这种方式更改代码,如果我得到0/0, 在这里,而不是nan我想保存1.

你能告诉我我该如何处理nan吗?

标签: python-3.xnumpy

解决方案


可能的解决方案之一:

import itertools as it

def cal():
    pairs = it.combinations(res, 2)
    rv = []
    for pair in pairs:
        r = np.concatenate(pair, axis=1)
        sign = np.sign(np.prod(r, axis=1))
        t1 = np.min(np.abs(r), axis=1)
        t2 = np.max(np.abs(r), axis=1)
        ratio = np.full(shape=r.shape[0], fill_value=1.)
        np.divide(t1, t2, out=ratio, where=np.not_equal(t2, 0.))
        wrk = np.full(shape=r.shape[0], fill_value=1.)
        np.multiply(sign, ratio, out=wrk, where=np.not_equal(t2, 0.))
        rv.append(wrk)
    return rv

而不是np.sign(r1 * r2)我用np.sign(np.prod(r, axis=1)).

然后设置默认值而不是NaN的技巧是,我创建了一个填充了这个默认值的数组,并调用了np.dividepass outwhere只在除数不为0的情况下进行除法。

最后一步是np.multiply,与where条件相同。

要测试此代码并漂亮地打印结果,您可以运行:

with np.printoptions(formatter={'float': '{: 9.5f}'.format}):
    result = cal()
    for tbl in result:
        print(tbl)

结果是:

[  0.82966   0.00000   0.49370   0.00000   0.83627   0.01394]
[  0.82058   0.00000   0.00000  -0.03866  -0.83627   0.01394]
[  0.68081   1.00000   0.00000   0.00000  -1.00000   1.00000]

如您所见,在最后一种情况下,有 2 个1.0值,对应于第二个和第三个源数组中的相同值。


推荐阅读