Why might np.float32 execute slower than np.float64?

Problem description

I am writing my own implementation of an algorithm for the optimal transport problem, based on the Sinkhorn-Knopp implementation in the POT repo on GitHub. The function looks like this:

import numpy as np

# version for the dense matrices
def sinkhorn_knopp(C, reg, a=None, b=None, max_iter=1e3, eps=1e-9,
                   log=False, verbose=False, log_interval=10):

    C = np.asarray(C, dtype=np.float64)

    # if the weights are not specified, assign them uniformly
    if a is None:
        a = np.ones((C.shape[0],), dtype=np.float64) / C.shape[0]
    if b is None:
        b = np.ones((C.shape[1],), dtype=np.float64) / C.shape[1]
    a = np.asarray(a, dtype=np.float64)
    b = np.asarray(b, dtype=np.float64)

    # Faster exponent
    K = np.divide(C, -reg)
    K = np.exp(K)

    # Init data
    dim_a = len(a)
    dim_b = len(b)

    # Set initial values of u and v
    u = np.ones(dim_a) / dim_a
    v = np.ones(dim_b) / dim_b

    r = np.empty_like(b)
    Kp = (1 / a).reshape(-1, 1) * K
    err = 1
    cpt = 0

    if log:
        log = {'err': []}

    while err > eps and cpt < max_iter:
        uprev = u
        vprev = v

        KtransposeU = K.T @ u
        v = np.divide(b, KtransposeU)
        u = 1. / (Kp @ v)

        if (np.any(KtransposeU == 0)
                or np.any(np.isnan(u)) or np.any(np.isnan(v))
                or np.any(np.isinf(u)) or np.any(np.isinf(v))):
            # we have reached the machine precision
            # come back to previous solution and quit loop
            print('Warning: numerical errors at iteration', cpt)
            u = uprev
            v = vprev
            break
        if cpt % log_interval == 0:
            # residual on the iteration
            r = (u @ K) * v
            # violation of the marginal
            err = np.linalg.norm(r - b)

            if log:
                log['err'].append(err)
        cpt += 1

    # return the OT matrix diag(u) @ K @ diag(v), reshaped so it also
    # works for non-square problems
    ot_matrix = u.reshape(-1, 1) * K * v.reshape(1, -1)
    loss = np.sum(C * ot_matrix)
    if log:
        return ot_matrix, loss, log
    else:
        return ot_matrix, loss
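
A minimal call might look like the following sketch (the size n = 500 and reg = 0.1 are arbitrary values chosen only for illustration):

import numpy as np

rng = np.random.default_rng(0)
n = 500                                    # arbitrary problem size
C = rng.random((n, n))                     # random dense cost matrix
ot_matrix, loss = sinkhorn_knopp(C, reg=0.1)
print(loss, ot_matrix.sum())               # the transport plan should sum to ~1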

I have listed the np.float64 version above. However, if one works in np.float32, the algorithm, surprisingly, runs slower. Naively, one would expect the "smaller" floats to be faster, since fewer bits have to be manipulated. But the measurements show the following numbers:

#np.float64 version
63.6 ms ± 7.94 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
#np.float32 version
71.4 ms ± 2.01 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)

#np.float64 version
650 ms ± 12.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
#np.float32 version
2.48 s ± 298 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)

For the larger problem the difference is almost a factor of 4, which looks strange. Why does this happen?
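
(The numbers above come from IPython's %timeit. A plain-timeit sketch of how the two versions are compared might look like this; the size n is hypothetical, and sinkhorn_knopp32 stands for an otherwise identical copy of the function with np.float32 substituted for np.float64 everywhere.)

import timeit
import numpy as np

rng = np.random.default_rng(0)
n = 2000                                   # hypothetical "larger" problem size
C64 = rng.random((n, n))                   # float64 cost matrix
C32 = C64.astype(np.float32)               # the same data in single precision

# sinkhorn_knopp32 is assumed to be a copy of sinkhorn_knopp that casts
# everything to np.float32 instead of np.float64
t64 = timeit.timeit(lambda: sinkhorn_knopp(C64, reg=0.1), number=7)
t32 = timeit.timeit(lambda: sinkhorn_knopp32(C32, reg=0.1), number=7)
print(f"float64: {t64 / 7:.3f} s/call, float32: {t32 / 7:.3f} s/call")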

Tags: python, performance, numpy, linear-algebra

Solution

