Numba: how to parse arbitrary logic string into sequence of jitclassed instances in a loop

Problem description

TL;DR. If I were to explain the problem in short:

  1. I have signals:
np.random.seed(42)
x = np.random.randn(1000)
y = np.random.randn(1000)
z = np.random.randn(1000)
  2. and human-readable string-tuple logic on those signals, such as:
entry_sig_ = ((x,y,'crossup',False),)
exit_sig_ = ((x,z,'crossup',False), 'or_',(x,y,'crossdown',False))

where:

  3. the output is generated by:
@njit
def run(x, entry_sig, exit_sig):
    '''
    x: np.array
    entry_sig, exit_sig: homogeneous tuples of tuple signals
    Returns: sequence of 0 and 1 satisfying entry and exit sigs
    ''' 
    L = x.shape[0]
    out = np.empty(L)
    out[0] = 0.0
    out[-1] = 0.0
    i = 1
    trade = True
    while i < L-1:
        out[i] = 0.0
        if reduce_sig(entry_sig,i) and i<L-1:
            out[i] = 1.0
            trade = True
            while trade and i<L-2:
                i += 1
                out[i] = 1.0
                if reduce_sig(exit_sig,i):
                    trade = False
        i+= 1
    return out

reduce_sig(sig, i) is a function (see the definition below) that parses the tuple and returns the resulting output at a given point in time.

Question:

So far, SingleSig class objects are instantiated from scratch in the for loop for every given point in time; there is therefore no "memory", which completely negates the advantage of having a class in the first place, and a bare function would do just as well. Is there a workaround (a different class template, a different approach, etc.) so that:

  1. the value of the combined tuple signal can be queried at a specific point in time i, and
  2. the "memory" can be reset; i.e., e.g., MultiSig(sig_tuple).memory_field can be set to 0 at the level of the constituent signals.

Tags: python, numba

Solution


The code below adds a memory to the signals. It can be reset with MultiSig.reset(), which sets the count of all signals back to 0, and it can be queried with MultiSig.query_memory(key), which returns the number of hits for that signal at that point in time.

For the memory to work I had to add unique keys to the signals to identify them.

from numba import njit, int64, float64, types
from numba.types import Array, string, boolean
from numba.experimental import jitclass  # jitclass moved to numba.experimental in newer Numba releases
import numpy as np

np.random.seed(42)
x = np.random.randn(1000000)
y = np.random.randn(1000000)
z = np.random.randn(1000000)

# Example of "human-readable" signals
entry_sig_ = ((x,y,'crossup',False),)
exit_sig_ = ((x,z,'crossup',False), 'or_',(x,y,'crossdown',False))

# Turn signals into homogeneous tuple
#entry_sig_
entry_sig = (((x,y,'crossup',False),'NOP','1'),)
#exit_sig_
exit_sig = (((x,z,'crossup',False),'or_','2'),((x,y,'crossdown',False),'NOP','3'))
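# Each element is ((x, y, how, acc), logic, key): `logic` ('or_'/'and_'/'NOP')
# joins this signal's result with the next one's, and `key` is the unique
# string id MultiSig uses to count hits per signal.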

@njit
def cross(x, y, i):
    '''
    x,y: np.array
    i: int - point in time
    Returns: 1 or 0 when condition is met
    '''
    if (x[i - 1] - y[i - 1])*(x[i] - y[i]) < 0:
        out = 1
    else:
        out = 0
    return out


kv_ty = (types.string,types.int64)

spec = [
    ('memory', types.DictType(*kv_ty)),
]

@njit
def single_signal(x, y, how, acc, i):
    '''
    i: int - point in time
    Returns either signal or accumulator
    '''
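    # `acc` is the fourth element of each signal tuple (False above); it is
    # unpacked along with the rest but not used in this function.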
    if cross(x, y, i):
        if x[i] < y[i] and how == 'crossdown':
            out = 1
        elif x[i] > y[i] and how == "crossup":
            out = 1
        else:
            out = 0
    else:
        out = 0
    return out
    
@jitclass(spec)
class MultiSig:
    def __init__(self,entry,exit):
        '''
        initialize memory at single signal level
        '''
        memory_dict = {}
        for i in entry:
            memory_dict[str(i[2])] = 0
        
        for i in exit:
            memory_dict[str(i[2])] = 0
        
        self.memory = memory_dict
        
    def reduce_sig(self, sig, i):
        '''
        Parses multisignal
        sig: homogeneous tuple of tuples ("human-readable" signal definition)
        i: int - point in time
        Returns: resulting value of multisignal
        '''
        L = len(sig)
        out = single_signal(*sig[0][0],i)
        logic = sig[0][1]
        if out:
            self.update_memory(sig[0][2])
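        # Fold the remaining signals left to right; `logic` always holds the
        # operator declared on the previous tuple, joining it with the current result.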
        for cnt in range(1, L):
            s = single_signal(*sig[cnt][0],i)
            if s:
                self.update_memory(sig[cnt][2])
            out = out | s if logic == 'or_' else out & s
            logic = sig[cnt][1]
        return out
    
    def update_memory(self, key):
        '''
        update memory
        '''
        self.memory[str(key)] += 1
    
    def reset(self):
        '''
        reset memory
        '''
        dicti = {}
        for i in self.memory:
            dicti[i] = 0
        self.memory = dicti
        
    def query_memory(self, key):
        '''
        return number of hits on signal
        '''
        return self.memory[str(key)]

@njit
def run(x, entry_sig, exit_sig):
    '''
    x: np.array
    entry_sig, exit_sig: homogeneous tuples of tuples
    Returns: sequence of 0 and 1 satisfying entry and exit sigs
    '''
    L = x.shape[0]
    out = np.empty(L)
    out[0] = 0.0
    out[-1] = 0.0
    i = 1
    multi = MultiSig(entry_sig,exit_sig)
    while i < L-1:
        out[i] = 0.0
        if multi.reduce_sig(entry_sig,i) and i<L-1:
            out[i] = 1.0
            trade = True
            while trade and i<L-2:
                i += 1
                out[i] = 1.0
                if multi.reduce_sig(exit_sig,i):
                    trade = False
        i+= 1
    return out

run(x, entry_sig, exit_sig)
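
As a minimal usage sketch (assuming the definitions above compile as posted), the memory can also be driven outside of run(): build a MultiSig, feed it a few time steps through reduce_sig, then query or reset the per-key counters. The key '1' is the one attached to the entry signal above.

multi = MultiSig(entry_sig, exit_sig)
for i in range(1, 1000):
    multi.reduce_sig(entry_sig, i)   # updates the count behind key '1'
    multi.reduce_sig(exit_sig, i)    # updates the counts behind keys '2' and '3'
print(multi.query_memory('1'))       # number of entry-signal hits so far
multi.reset()                        # every counter back to 0
print(multi.query_memory('1'))       # 0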

To reiterate what I said in the comments: | and & are bitwise operators, not logical operators. 1 & 2 evaluates to 0/False, which is not what I believe you want it to evaluate to, so I made sure out and s can only ever be 0/1 in order to produce the expected output.
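
A quick plain-Python illustration of the difference (nothing here is specific to Numba):

print(1 & 2)          # 0 -> bitwise AND of 0b01 and 0b10 shares no bits
print(bool(1 and 2))  # True -> the logical operator only checks truthiness
print(1 | 2)          # 3 -> why the operands are kept strictly to 0/1 above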

Are you aware that, because of:

out = out | s if logic == 'or_' else out & s

the order of the time series inside entry_sig and exit_sig matters?

Let (output, logic) be tuples where output is 0 or 1, depending on how the crossup or crossdown of the passed tuple evaluates, and logic is either or_ or and_:

tuples = ((0,'or_'),(1,'or_'),(0,'and_'))

out = tuples[0][0]
logic = tuples[0][1]
for i in range(1,len(tuples)):
    s = tuples[i][0]
    out = out | s if logic == 'or_' else out & s
    logic = tuples[i][1]

print(out)
1

Changing the order of the tuples produces a different signal:

tuples = ((0,'or_'),(0,'and_'),(1,'or_'))

out = tuples[0][0]
logic = tuples[0][1]
for i in range(1,len(tuples)):
    s = tuples[i][0]
    out = out | s if logic == 'or_' else out & s
    logic = tuples[i][1]

print(out)
0

Performance depends on how many times the counts need to be updated. With n = 1,000,000 for all three time series, your code has an average runtime of 0.6 seconds on my machine and mine takes 0.63 seconds.

I then changed the cross logic slightly to cut down on the number of if/else checks, so that the nested if/else is only triggered when the time series actually cross, which can be verified with a single comparison. This halved the difference in runtime again, so the code above now takes about 2.5% longer to run than the original.
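
For contrast, a hypothetical unguarded variant (an assumption for illustration, not code taken from the question) would run the string comparisons at every time step, whereas single_signal above bails out on the single multiplication test inside cross whenever no crossing occurred:

@njit
def single_signal_unguarded(x, y, how, acc, i):
    # Hypothetical pre-optimisation shape: every branch test runs at every i,
    # even when x and y did not cross between i-1 and i.
    if how == 'crossup' and x[i - 1] < y[i - 1] and x[i] > y[i]:
        return 1
    elif how == 'crossdown' and x[i - 1] > y[i - 1] and x[i] < y[i]:
        return 1
    else:
        return 0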

