首页 > 解决方案 > 提高稀疏张量收缩的性能

问题描述

我最近一直在研究张量收缩代码,我遇到了非常缓慢的性能时间。要理解的主要事情是考虑的张量是“稀疏的”,因此它的大部分条目都是零。此外,对于所考虑的数组,这是一个 dtype=numpy.complex128。我的张量是 4D + 1D,其中 1D 代表 complex 类型的数组,数组大小为 N~1000。为了避免出现内存问题,我使用“键分配”方法实现了代码,也就是说,张量的每个元素都有一个关联的键( 1,2,3,4 ),同时,索引 1, 2,3,4 是“多值”索引,可以采用三个不同的值(也可以是整数)。例如,非零条目的一个可能键可能是:

                             key = ( 0,1,2,3,4,5,6,7,8,9,10,11 )

所有键都是唯一的,但由于我正在考虑的问题,键值之间的关系保持不变,因此,将所有整数包含在单个键组合中很重要。这意味着可以通过对一对键施加条件来进一步减少元素的数量,例如:

            np.abs( key_1[0] - key_[2] ) <=3 ---> non-zero entries must satisfy this condition (AS AN EXAMPLE, NOT THE ACTUAL CASE)

对于每个键,都关联一个长度为 N 的 Numpy 数组。我的代码需要执行以下求和来收缩张量:

                     A_{ 1,2,3,4 } = sum_{5,6,7,8} B_{ 1,2,7,5 } C_{ 5,6,7,8 } D_{8,6,3,4} 

对于指定的键的单个值,其中 B、C、D 是长度为 N~1000 的 numpy 数组。

记住每个索引 1,2,.. 都是多值的,所以实际上是一个多索引。我实现了以下代码,例如计算左侧 A 的键 (1,2,3,4) 元素:

def contraction_func( data_object, key=(1,2,3,4) ):
   
    A_keys  = data_object.get_left_keys()       # A 2D numpy array containing the keys by rows,i.e. 
                                                # gamma_p_keys = np.array( [ [1,2,3,4,... ],
                                                #                             [4,4,6,7,... ], ...etc
                                                #                            ]  )
                                                # where a single row has 12 integers in my example

    A_array = data_object.get_left_array()   # to each row of left_keys, corresponds a row here
                                                # with N elements along axis=1

    C_dict   = data_object.get_C_dict()       # this gives the object C defined above as a dictionary
                                              # with the keys given by combinations of (5,6,7,8)        

    # since for example, key =(1,2,3,4) --> np.array([ 0, -1, 2, 5, 6, 6, 7,9, -1, 0, 2, -2 ])
    multi_index_1p = key[0:3]
    multi_index_2p = key[3:6]
    multi_index_1  = key[6:9]
    multi_index_2  = key[9:]  

    total_indices  = len( multi_index_1p )        
    aux            = np.concatenate( ( multi_index_1, multi_index_2 )  )   
         
    a = ( A_keys[ :, 2*total_indices:3*total_indices ] == multi_index_1  ).all( axis=1 ) 
    b = ( A_keys[ :, 3*total_indices: ] == multi_index_2  ).all( axis=1 )             
 
    mask = np.multiply( (a,b ) )
    
    A_keys_filtered_right  = A_keys[ mask ]  
    A_array_filtered_right = A_array [ mask ]  

    a = ( A_keys[ :, 0:total_indices ] == multi_index_1p  ).all( axis=1 ) 
    b = ( A_keys[ :, total_indices:2*total_indices ] == multi_index_2p ).all( axis=1 )             
 
    mask = np.multiply( (a,b ) )

    A_keys_filtered_left  = A_keys[ mask ]  
    A_array_filtered_left = A_array [ mask ] 
     
    # the very time consuming part I guess..

    keys_product = np.array( [ np.concatenate( (x[0], x[1] ) ) for x in product( 
                                                                  A_keys_filtered_left  , 
                                                                  A_keys_filtered_right   )                                                                         
                                                                                                       
                                ], dtype=np.int32  )
    
    product_array = np.array( [ np.multiply( x[0], x[1] ) for x in product( 
                                                                   A_array_filtered_left , 
                                                                   A_array_filtered_right ) ] )
        

    
    s_keys = np.concatenate( ( keys_product[ :, 6:9 ], keys_product[:, 12:15] ), axis=1)
    g_keys = np.concatenate( ( keys_product[ :, 9:12 ], keys_product[:,15:18 ], axis=1 )    
    s_g_keys = list( map( tuple, np.concatenate( (g_keys,s_keys), axis = 1 ) ) )

    aux_array           = np.array( [ C_dict[ x ] for x in s_g_keys ] ) 
    aux                 = np.multiply( aux_array, product_array )

    del conv_array, gamma_product_array   

    solution_conv       = aux.sum( axis=0 )    


    return solution_conv

随着复杂性的增加,问题的扩展速度非常快,即,当表示非零条目的键的总数要大得多时。对于(1,2,3,4)的通常〜10 ^ 5个键的大小,我有兴趣这样做。我发现对于 (1,2,3,4) 的单个值,太多的项对收缩的 rhs 有贡献。但我的猜测是查找密钥非常昂贵,关于如何改进这一点的任何想法?

非常感谢!

标签: pythonperformancetensor

解决方案


推荐阅读