首页 > 解决方案 > 如何覆盖 NumPy 的 ndarray 隐式布尔值?(即 __bool__ 函数)

问题描述

在争论开始之前,我将解释一个用例。tl; dr:一个只读库检查隐式布尔值。

我已经覆盖了 Keras 的 train_step 只接受一个参数(对于 GAN)。这对于训练数据很好。但是,当我添加验证数据时,Keras 有一个检查:

if validation_data:
  val_x, val_y, val_sample_weight = (
    data_adapter.unpack_x_y_sample_weight(validation_data))

NumPy 当然会提高

ValueError: The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()

通常,validation_data 是一个列表、元组等,这个检查工作正常。我不想更改库的源代码,所以我最好的想法是覆盖 numpy.ndarray 的隐式布尔值

但是,子类化

class np_array_falsifiable(np.ndarray):
    def __bool__(self):
        return self.size

提高

ValueError: maximum supported dimension for an ndarray is 32, found 6301

我将不胜感激并接受任何解决此问题的答案。

标签: pythonnumpytensorflowkerasnumpy-ndarray

解决方案


推荐阅读