python - 如何避免 Numpy 类型转换?
问题描述
是否可以避免或发出从整数和32 bit float arrays
到的自动 Numpy 类型转换的警告64 bit float arrays
?
我的用例是我正在开发一个大型分析包(20k 行 Python 和 Numpy),目前混合了 float 32 和 64 以及一些 int dtypes,很可能导致次优性能和浪费内存,基本上我想在任何地方都使用 float32 。
我知道在Tensorflow中组合两个不同 dtype 的数组会产生错误——正是因为隐式转换为 float64 会导致性能下降,并且在所有计算的张量上都是“传染性的”,如果隐式完成,很难找到它的引入位置。
寻找 Numpy 中的选项或猴子修补 Numpy 的方法,使其在这方面的行为类似于 Tensorflow,即在诸如 等操作的隐式类型转换时发出错误np.add
,np.mul
或者甚至更好地发出带有打印回溯的警告,所以执行继续,但我看到它发生在哪里。可能的?
解决方案
免责声明:我没有以任何认真的方式对此进行测试,但这似乎是一条有前途的路线。
操纵 ufunc 行为的一种相对轻松的方式似乎是子类ndarray
化和覆盖__array_ufunc__
。例如,如果您满足于捕捉任何产生的东西float64
class no64(np.ndarray):
def __array_ufunc__(self, ufunc, method, *inputs, **kwds):
ret = getattr(ufunc, method)(*map(np.asarray,inputs), **kwds)
# some ufuncs return multiple arrays:
if isinstance(ret,tuple):
if any(x.dtype == np.float64 for x in ret):
raise ValueError
return (*(x.view(no64) for x in ret),)
if ret.dtype == np.float64:
raise ValueError
return ret.view(no64)
x = np.arange(6,dtype=np.float32).view(no64)
现在让我们看看我们的类可以做什么:
x*x
# no64([ 0., 1., 4., 9., 16., 25.], dtype=float32)
np.sin(x)
# no64([ 0. , 0.84147096, 0.9092974 , 0.14112 , -0.7568025 ,
# -0.9589243 ], dtype=float32)
np.frexp(x)
# (no64([0. , 0.5 , 0.5 , 0.75 , 0.5 , 0.625], dtype=float32), no64([0, 1, 2, 2, 3, 3], dtype=int32))
现在让我们将它与 64 位参数配对:
x + np.arange(6)
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "<stdin>", line 9, in __array_ufunc__
# ValueError
np.multiply.outer(x, np.arange(2.0))
# Traceback (most recent call last):
# File "<stdin>", line 1, in <module>
# File "<stdin>", line 9, in __array_ufunc__
# ValueError
什么不起作用(我相信还有更多)
np.outer(x, np.arange(2.0)) # not a ufunc, so slips through
# array([[0., 0.],
# [0., 1.],
# [0., 2.],
# [0., 3.],
# [0., 4.],
# [0., 5.]])
__array_function__
似乎是什么抓住了那些。
推荐阅读
- c++ - C++ 中有没有办法保存一个对象,然后重新加载它,尽管它有一个指针作为属性?
- c++ - Doxygen 关于 C++ 代码中重载运算符&= 的警告
- scala - 如何在 GraphX 中使用 2 步连接计算入度
- r - 剂量反应曲线 - 错误的 x 线间距
- php - 从另一个 URL 返回后,选择选项不会保留
- google-apps-script - 在 Googles Sheet 宏语言中移动“活动”单元格
- r - Shiny 可以打印/添加/绘制从输入中提取的值到绘图(nomogam)吗?
- c# - 带索引的 OpenGL / OpenTK 绘图:尝试读取或写入受保护的内存问题
- javascript - 如何使用 D3 在现有时间序列图中将线的形状更改为闭合条
- swift - 在 CoreData 中排序(swift)