python - 我必须使用哪些数据类型在 Cython 中编写此函数?
问题描述
我发现了这段不错的代码,我想用 Cython 重写它来学习 Cython。此外,我想绘制 python for loop、numpy 和 cython 之间的速度比较增益。这就是我目前对大约 40k 行的比较结果:
def sliding_window_slicing(a, no_items, item_type=0):
"""This method perfoms sliding window slicing of numpy arrays
Parameters
----------
a : numpy
An array to be slided in subarrays
no_items : int
Number of sliced arrays or elements in sliced arrays
item_type: int
Indicates if no_items is number of sliced arrays (item_type=0) or
number of elements in sliced array (item_type=1), by default 0
Return
------
numpy
Sliced numpy array
"""
if item_type == 0:
no_slices = no_items
no_elements = len(a) + 1 - no_slices
if no_elements <=0:
raise ValueError('Sliding slicing not possible, no_items is larger than ' + str(len(a)))
else:
no_elements = no_items
no_slices = len(a) - no_elements + 1
if no_slices <=0:
raise ValueError('Sliding slicing not possible, no_items is larger than ' + str(len(a)))
subarray_shape = a.shape[1:]
shape_cfg = (no_slices, no_elements) + subarray_shape
strides_cfg = (a.strides[0],) + a.strides
as_strided = np.lib.stride_tricks.as_strided #shorthand
return as_strided(a, shape=shape_cfg, strides=strides_cfg)
我试过这个,但我无法决定我必须使用的数据类型:
%%cython -a
#cython: boundscheck=False, wraparound=False, cdivision=True, nonecheck=False
import numpy as np
cimport numpy as np
cimport cython
cpdef np.ndarray cython_sliding_window_slicing(long[:,:] a, long no_items, int item_type=0):
cdef:
long no_slices, no_elements
long subarray_shape
(long,long) shape_cfg
np.ndarray strides_cfg
if item_type == 0:
no_slices = no_items
no_elements = len(a) + 1 - no_slices
else:
no_elements = no_items
no_slices = len(a) - no_elements + 1
subarray_shape = a.shape[1]
shape_cfg = (no_slices, no_elements) + (subarray_shape,0)
strides_cfg = (a.strides[0],) + a.strides
return np.lib.stride_tricks.as_strided(a, shape=shape_cfg, strides=strides_cfg)
在我收到错误的那一刻:Expected a tuple of size 2, got tuple
我在函数上使用的数组是这样的:example_array = np.random.randint(1000,size=(40, 8))
我希望有人可以帮助并向我解释我需要更改什么以及如何调试脚本(使用 jupyter notebook?)。
谢谢你的帮助!
解决方案
推荐阅读
- php - 终端中的作业处理编号 - laravel 中的队列
- awk - 如何让 awk 打印变量而不是完整匹配的行?
- python - ValueError:int() 的无效文字,基数为 10:'D1'
- json - JSON中的Powershell嵌套文件夹循环
- c++ - C++ DLL 行为在 py3.6 和 py3.8 之间有所不同:SIGSEGV "__gnu_cxx::__exchange_and_add" 与 Python 3.6.6
- c++ - 使用 cmake 构建 C++ Protobuf 时出错
- algorithm - 了解屏幕上像素的显示
- google-cloud-platform - 我们如何在google cloud deploy中使用cloud build privatePool
- c - 如何创建 n 个线程,每个线程创建 n - 1 个线程
- spring-boot - Spring WebClient - 如果在 doOnError 中引发异常,则停止重试