首页 > 解决方案 > fortran 中 blas 的包装器

问题描述

对评论的回应

有几个人指出我试图违反 Fortran 标准。情况并非如此,我希望拥有完全可移植且符合标准的代码。下面是我处理这个问题的幼稚做法,显然它行不通。因此,我不知道解决方案,但我对所需的功能非常明确。具体用什么方法实现并不重要。我并不局限于 Fortran,尽管使用 Fortran 更可取。

原帖

我最近发现,BLAS 库中的 ZGEMM 在复数-实数矩阵乘法上的性能并不理想。用两次 DGEMM 调用替换一次复数-复数的 ZGEMM 调用,可以将运行时间缩短 3 倍(而不是像天真预期的那样缩短 1.5 倍)!

module Mmatmul
  implicit none
contains
  !> Multiply a double complex matrix by a double precision matrix, c = a*b,
  !> using two real dgemm calls (one for the real part of a, one for the
  !> imaginary part) instead of a complex-complex zgemm.
  !>
  !> n      : order passed to dgemm for all three matrix dimensions and as the
  !>          leading dimension of T1, b and T2.
  !> a      : complex left operand.
  !> b      : real right operand.
  !> c      : complex result.
  !> T1, T2 : caller-supplied real work arrays; overwritten on return.
  !>
  !> a, c, T1 and T2 must have the same shape because they are combined in
  !> whole-array assignments.
  subroutine zd_matmul_module(n,a,b,c,T1,T2)
    integer, parameter :: dp = kind(1.0d0)
    integer,intent(in)::n
    double complex,intent(in)::a(:,:)
    double precision,intent(in)::b(:,:)
    double complex,intent(out)::c(:,:)
    double precision,intent(inout)::T1(:,:),T2(:,:)
    external :: dgemm

    ! real part: T2 = Re(a) * b
    T1 = real(a, kind=dp)
    call dgemm ('N', 'N', n, n, n, 1.0_dp, T1, n, b, n, 0.0_dp, T2, n)
    ! put result on hold in the real part of c (T2 is reused below)
    c = cmplx(T2, 0.0_dp, kind=dp)

    ! imaginary part: T2 = Im(a) * b
    T1 = aimag(a)
    call dgemm ('N', 'N', n, n, n, 1.0_dp, T1, n, b, n, 0.0_dp, T2, n)

    ! put real and imaginary parts in place
    c = cmplx(real(c, kind=dp), T2, kind=dp)
  end subroutine zd_matmul_module
end module Mmatmul

这个子程序非常直观,但是它不如 BLAS 子程序灵活。例如,两个临时矩阵 T1、T2 需要具有正确的类型和形状,这很不方便。例如考虑这样一种情况:我需要执行 3 次复数-双精度实数乘法,结果分别是 rank-2、rank-3 和 rank-4 的数组。这意味着要编写 3 个专用子程序,并拥有 6 个不同的临时矩阵!

另一方面,只要传入的数组大小正确,BLAS 子例程并不关心其参数的类型和形状。这对我来说是一个非常理想的特性。

因此,我想在 Fortran 中为 BLAS 子例程编写一个包装器,使它不做类型和形状检查,就像 BLAS 本身那样。我正在考虑使用一个外部子程序,这将帮助我绕过 Fortran 的类型和秩(rank)检查。我是有意识地想这样做,因为这确实是我代码的瓶颈。

我想出了这个实现

!> External (implicit-interface) variant of the complex-by-real product.
!>
!> All array dummies are assumed-size double precision, so no type or rank
!> checking happens at the call site. a and c are accessed with stride 2
!> starting at elements 1 and 2, i.e. they are treated as interleaved
!> (real, imaginary) pairs.
!> NOTE(review): the caller shown in this file passes double complex actuals
!> for a and c; associating them with double precision dummies relies on that
!> storage layout and is not type-checked by the compiler - confirm this is
!> acceptable for the target compilers.
!>
!> n      : matrix order (dgemm m, n, k and all leading dimensions).
!> a      : left operand, 2*n*n reals read as interleaved complex values.
!> b      : real right operand, n*n values.
!> c      : result, 2*n*n reals written as interleaved complex values.
!> T1, T2 : real work arrays of at least n*n elements; overwritten.
subroutine zd_matmul_external(n,a,b,c,T1,T2)
  implicit none
  integer,intent(in)::n
  double precision,intent(in)::a(*)
  double precision,intent(in)::b(*)
  double precision,intent(out)::c(*)
  double precision,intent(inout)::T1(*),T2(*)
  external :: dcopy, dgemm

  ! odd elements of a (real parts) -> T1, multiply, scatter into odd elements of c
  call dcopy (n*n, a(1), 2, T1, 1)
  call dgemm ('N', 'N', n, n, n, 1.0D0, T1, n, b, n, 0.0D0, T2, n)
  call dcopy (n*n, T2, 1, c(1), 2)

  ! even elements of a (imaginary parts) -> T1, multiply, scatter into even elements of c
  call dcopy (n*n, a(2), 2, T1, 1)
  call dgemm ('N', 'N', n, n, n, 1.0D0, T1, n, b, n, 0.0D0, T2, n)
  call dcopy (n*n, T2, 1, c(2), 2)
end subroutine zd_matmul_external

它可以编译,但会给出一个运行时错误,指示内存问题。由于它可以编译,并且它的结构与BLAS实现相似,我相信它在语法上没有错误。但问题是什么?

为了完整起见,这是主程序

 ! Driver comparing the module and external complex-by-real products against
 ! the intrinsic matmul on random n-by-n matrices. Prints the Frobenius norm
 ! of the difference to matmul for each variant.
 program Psmall
  use Mmatmul, only: zd_matmul_module
  implicit none
  integer, parameter :: dp = kind(1.0d0)
  integer, allocatable::seed(:)
  integer:: sz, n, u, ios
  double complex, allocatable :: AZ(:,:)
  double complex, allocatable :: CZ1(:,:), CZ2(:,:), CZ3(:,:)
  double precision, allocatable :: BD(:,:), AR(:,:), AI(:,:)

  external :: zd_matmul_external

  ! -- seed the generator from /dev/urandom --
  call  random_seed(size = sz)  ! Finds the size sz of the seed
  allocate (seed(sz))
  open (newunit=u, file='/dev/urandom', access='stream', form='UNFORMATTED', &
        action='read', status='old', iostat=ios)
  if (ios == 0) then
    read (u, iostat=ios) seed
    close (u)
  end if
  if (ios == 0) then
    ! the seed that was read must actually be installed to have any effect
    call random_seed(put = seed)
  else
    ! /dev/urandom unavailable or unreadable: fall back to processor seeding
    call random_seed()
  end if

  ! -- prepare random arrays --
  n= 500
  allocate(AR(n,n),AI(n,n),BD(n,n),AZ(n,n),CZ1(n,n),CZ2(n,n),CZ3(n,n))
  call random_number(AR)
  call random_number(AI)
  call random_number(BD)
  AZ=cmplx(AR,AI,kind=dp)

  ! -- fast, but inconvenient --
  ! AR and AI are reused as work arrays from here on; their contents are overwritten
  call zd_matmul_module(n, AZ, BD, CZ1, AR, AI)

  ! Problematic next line
  ! NOTE(review): AZ and CZ2 are double complex but the dummies of
  ! zd_matmul_external are double precision; the call goes through an
  ! implicit interface, so the compiler does not check this.
  call zd_matmul_external(n, AZ,BD,CZ2, AR, AI)

  CZ3 = matmul(AZ, BD)

  print*,sqrt(sum(abs(cz1-cz3)**2)),sqrt(sum(abs(cz2-cz3)**2))
end program Psmall

我正在使用 gfortran 编译器(gcc 版本 5.3.0)@ macOS 10.13.5。

标签: arraysfortranblas

解决方案


我会采用类似下面的做法,它把乘法划分为若干易于管理的块(分块计算)。它符合标准,相比 openblas(在我的笔记本电脑上)速度有明显提升,并且对于大型矩阵,使用的内存明显少于问题中提出的实现。代码仅针对单线程编写和测试,并行化留作练习。您还可以通过调整缓冲区大小来获得更好的性能。

代码:

!> Explicit interfaces for the BLAS level-3 routines dgemm and zgemm, so that
!> calls are checked by the compiler under Implicit None ( Type, External ).
!>
!> Array arguments are assumed-size, matching the way BLAS addresses its
!> operands through a leading dimension. The result matrix c is InOut because
!> the routines compute c := alpha*op(a)*op(b) + beta*c and therefore read c
!> whenever beta is non-zero.
Module blas_interfaces_module

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64

  Implicit None ( Type, External )

  Interface

     !> Real double precision general matrix-matrix multiply.
     Subroutine dgemm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc )
       Import wp
       Implicit None ( Type, External )
       Character                 , Intent( In    ) :: transa
       Character                 , Intent( In    ) :: transb
       Integer                   , Intent( In    ) :: m
       Integer                   , Intent( In    ) :: n
       Integer                   , Intent( In    ) :: k
       Real( wp )                , Intent( In    ) :: alpha
       Real( wp ), Dimension( * ), Intent( In    ) :: a
       Integer                   , Intent( In    ) :: lda
       Real( wp ), Dimension( * ), Intent( In    ) :: b
       Integer                   , Intent( In    ) :: ldb
       Real( wp )                , Intent( In    ) :: beta
       Real( wp ), Dimension( * ), Intent( InOut ) :: c
       Integer                   , Intent( In    ) :: ldc
     End Subroutine dgemm

     !> Complex double precision general matrix-matrix multiply.
     Subroutine zgemm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc )
       Import wp
       Implicit None ( Type, External )
       Character                    , Intent( In    ) :: transa
       Character                    , Intent( In    ) :: transb
       Integer                      , Intent( In    ) :: m
       Integer                      , Intent( In    ) :: n
       Integer                      , Intent( In    ) :: k
       Complex( wp )                , Intent( In    ) :: alpha
       Complex( wp ), Dimension( * ), Intent( In    ) :: a
       Integer                      , Intent( In    ) :: lda
       Complex( wp ), Dimension( * ), Intent( In    ) :: b
       Integer                      , Intent( In    ) :: ldb
       Complex( wp )                , Intent( In    ) :: beta
       Complex( wp ), Dimension( * ), Intent( InOut ) :: c
       Integer                      , Intent( In    ) :: ldc
     End Subroutine zgemm

  End Interface

  Public :: dgemm
  Public :: zgemm

  Private

End Module blas_interfaces_module

!> Blocked complex-by-real matrix multiplication built on real dgemm calls.
Module mm_complex_real_module

  Implicit None ( Type, External )

  Public :: mm_complex_real

  Private

Contains

  !> Compute c = a * b for a complex n-by-n matrix a and a real n-by-n matrix b.
  !>
  !> The product is evaluated block by block with blocks of at most
  !> n_buff_max rows/columns. For each block of a the real and imaginary
  !> parts are copied into real buffers, multiplied by the matching block of
  !> b with dgemm, and the two real results are accumulated into c as the
  !> real and imaginary parts.
  !>
  !> n : order of the matrices; for n < 1 the routine returns immediately.
  !> a : complex left operand.
  !> b : real right operand.
  !> c : complex result, fully overwritten.
  Subroutine mm_complex_real( n, a, b, c )

    Use, Intrinsic :: iso_fortran_env, Only : wp => real64

    Use blas_interfaces_module, Only : dgemm

    Implicit None ( Type, External )

    Integer                             , Intent( In    ) :: n
    Complex( wp ), Dimension( 1:n, 1:n ), Intent( In    ) :: a
    Real   ( wp ), Dimension( 1:n, 1:n ), Intent( In    ) :: b
    Complex( wp ), Dimension( 1:n, 1:n ), Intent(   Out ) :: c

    ! Largest block edge; bounds the size of the four work buffers
    Integer, Parameter :: n_buff_max = 2048

    Real( wp ), Dimension( :, : ), Allocatable :: a_real, a_imag
    Real( wp ), Dimension( :, : ), Allocatable :: c_real, c_imag

    Integer :: i_start, j_start, k_start
    Integer :: i_end  , j_end  , k_end
    Integer :: i_len  , j_len  , k_len
    Integer :: n_buff
    Integer :: i, j, k

    ! Nothing to compute for empty matrices; this also guarantees n_buff >= 1,
    ! which is required as the stride of the block loops below
    If( n < 1 ) Return

    n_buff = Min( n_buff_max, n )
    Allocate( a_real( 1:n_buff, 1:n_buff ) )
    Allocate( a_imag( 1:n_buff, 1:n_buff ) )
    Allocate( c_real( 1:n_buff, 1:n_buff ) )
    Allocate( c_imag( 1:n_buff, 1:n_buff ) )

    c = 0.0_wp

    ! The block loops must include a block that starts exactly at n
    ! (e.g. n = 1, or n = n_buff_max + 1); counted loops ensure that.
    ! j runs over the contracted index: columns of a, rows of b
    Do j_start = 1, n, n_buff

       j_end = Min( j_start + n_buff - 1, n )
       j_len = j_end - j_start + 1

       ! i runs over the rows of a and c
       Do i_start = 1, n, n_buff

          i_end = Min( i_start + n_buff - 1, n )
          i_len = i_end - i_start + 1

          ! Split the current block of a into real and imaginary buffers
          Do j = 1, j_len
             Do i = 1, i_len
                a_real( i, j ) = Real ( a( i + i_start - 1, j + j_start - 1 ), wp )
                a_imag( i, j ) = Aimag( a( i + i_start - 1, j + j_start - 1 )     )
             End Do
          End Do

          ! k runs over the columns of b and c
          Do k_start = 1, n, n_buff

             k_end = Min( k_start + n_buff - 1, n )
             k_len = k_end - k_start + 1

             ! b is addressed in place: first element of the block plus the
             ! leading dimension of the full matrix
             Call dgemm( 'N', 'N', i_len, k_len, j_len, 1.0_wp, a_real                    , Size( a_real, Dim = 1 ), &
                                                                b     ( j_start, k_start ), Size( b     , Dim = 1 ), &
                                                        0.0_wp, c_real                    , Size( c_real, Dim = 1 ) )
             Call dgemm( 'N', 'N', i_len, k_len, j_len, 1.0_wp, a_imag                    , Size( a_imag, Dim = 1 ), &
                                                                b     ( j_start, k_start ), Size( b     , Dim = 1 ), &
                                                        0.0_wp, c_imag                    , Size( c_imag, Dim = 1 ) )

             ! Accumulate this block's contribution into c
             Do k = k_start, k_end
                Do i = i_start, i_end
                   c( i, k ) = c( i, k ) + &
                        Cmplx( c_real( i - i_start + 1, k - k_start + 1 ), &
                               c_imag( i - i_start + 1, k - k_start + 1 ), Kind = wp )
                End Do
             End Do

          End Do

       End Do

    End Do

  End Subroutine mm_complex_real

End Module mm_complex_real_module

!> Benchmark driver: for n = 500, 1000, ..., 8000 it times the intrinsic
!> Matmul (reference), zgemm on a complex copy of b, and mm_complex_real,
!> and prints the elapsed time of each together with the maximum absolute
!> deviation of the two BLAS-based results from the reference.
Program testit

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64, li => int64, stdout => output_unit

  Use mm_complex_real_module, Only : mm_complex_real
  Use blas_interfaces_module, Only : dgemm, zgemm

  Implicit None ( Type, External )

  ! Output formats: timing only, and timing plus error
  Character( Len = * ), Parameter :: fmt_time = &
       '( a, t12, "Rank = ", i5, t30, "Time = ", f9.4 )'
  Character( Len = * ), Parameter :: fmt_time_error = &
       '( a, t12, "Rank = ", i5, t30, "Time = ", f9.4, t50, "Error = ", g20.10 )'

  Complex( wp ), Dimension( :, : ), Allocatable :: a
  Complex( wp ), Dimension( :, : ), Allocatable :: c_ref, c_blas, c_mine
  Complex( wp ), Dimension( :, : ), Allocatable :: b_complex

  Real( wp ), Dimension( :, : ), Allocatable :: b
  Real( wp ), Dimension( :, : ), Allocatable :: a_real, a_imag

  Real( wp ) :: t_ref, t_blas, t_mine
  Real( wp ) :: err_blas, err_mine

  Integer( li ) :: tick_start, tick_end, tick_rate

  Integer :: n

  Do n = 500, 8000, 500

     Allocate( a( 1:n, 1:n ), b( 1:n, 1:n ) )
     Allocate( c_ref( 1:n, 1:n ), c_blas( 1:n, 1:n ), c_mine( 1:n, 1:n ) )

     ! Random complex a, built from two temporary real matrices
     Allocate( a_real( 1:n, 1:n ), a_imag( 1:n, 1:n ) )
     Call Random_number( a_real )
     Call Random_number( a_imag )
     a = Cmplx( a_real, a_imag, Kind = wp )
     Deallocate( a_imag, a_real )

     Call Random_number( b )

     ! Reference: intrinsic Matmul
     Call system_clock( tick_start, tick_rate )
     c_ref = Matmul( a, b )
     Call system_clock( tick_end, tick_rate )
     t_ref = elapsed_seconds( tick_start, tick_end, tick_rate )

     ! zgemm; the conversion of b to complex is part of the timed region
     Allocate( b_complex( 1:n, 1:n ) )
     Call system_clock( tick_start, tick_rate )
     b_complex = b
     Call zgemm( 'N', 'N', n, n, n, ( 1.0_wp, 0.0_wp ), a        , Size( a        , Dim = 1 ), &
                                                        b_complex, Size( b_complex, Dim = 1 ), &
                                    ( 0.0_wp, 0.0_wp ), c_blas   , Size( c_blas   , Dim = 1 ) )
     Call system_clock( tick_end, tick_rate )
     t_blas = elapsed_seconds( tick_start, tick_end, tick_rate )
     Deallocate( b_complex )

     ! Blocked complex-by-real product
     Call system_clock( tick_start, tick_rate )
     Call mm_complex_real( n, a, b, c_mine )
     Call system_clock( tick_end, tick_rate )
     t_mine = elapsed_seconds( tick_start, tick_end, tick_rate )

     err_blas = Maxval( Abs( c_ref - c_blas ) )
     err_mine = Maxval( Abs( c_ref - c_mine ) )

     Write( stdout, fmt_time       ) 'Reference', n, t_ref
     Write( stdout, fmt_time_error ) 'BLAS', n, t_blas, err_blas
     Write( stdout, fmt_time_error ) 'Mine', n, t_mine, err_mine
     Write( stdout, * )

     Deallocate( c_mine, c_blas, c_ref, b, a )

  End Do

Contains

  !> Convert a pair of system_clock counts into elapsed seconds.
  Pure Function elapsed_seconds( first, last, ticks_per_second ) Result( seconds )

    Integer( li ), Intent( In ) :: first, last, ticks_per_second
    Real( wp ) :: seconds

    seconds = Real( last - first, Kind = wp ) / ticks_per_second

  End Function elapsed_seconds

End Program testit

使用调试标志和编译器版本对小矩阵进行测试:

ijb@ijb-Latitude-5410:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions.  There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.

ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -std=f2018 -fcheck=all -finit-real=snan -fexternal-blas -Wuse-without-only   mm.f90 -g -lopenblas
ijb@ijb-Latitude-5410:~/work/stack$ export OMP_NUM_THREADS=1
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
Reference  Rank =   500      Time =    0.0317
BLAS       Rank =   500      Time =    0.0244    Error =      0.000000000    
Mine       Rank =   500      Time =    0.0221    Error =     0.2049518614E-12

Reference  Rank =  1000      Time =    0.1776
BLAS       Rank =  1000      Time =    0.1810    Error =      0.000000000    
Mine       Rank =  1000      Time =    0.1260    Error =     0.3215549355E-12

Reference  Rank =  1500      Time =    0.5452
BLAS       Rank =  1500      Time =    0.5370    Error =      0.000000000    
Mine       Rank =  1500      Time =    0.3571    Error =     0.4687428402E-12

Reference  Rank =  2000      Time =    1.1895
BLAS       Rank =  2000      Time =    1.1421    Error =      0.000000000    
Mine       Rank =  2000      Time =    0.7319    Error =     0.5084229946E-12

Reference  Rank =  2500      Time =    2.1990
BLAS       Rank =  2500      Time =    2.1870    Error =      0.000000000    
Mine       Rank =  2500      Time =    1.6239    Error =     0.7279509461E-12

Reference  Rank =  3000      Time =    4.3954
BLAS       Rank =  3000      Time =    4.0730    Error =      0.000000000    
Mine       Rank =  3000      Time =    2.6654    Error =     0.8276526716E-12

Reference  Rank =  3500      Time =    6.1562
BLAS       Rank =  3500      Time =    6.4210    Error =      0.000000000    
Mine       Rank =  3500      Time =    4.5101    Error =     0.1016845989E-11

Reference  Rank =  4000      Time =    9.4788
BLAS       Rank =  4000      Time =    9.7646    Error =      0.000000000    
Mine       Rank =  4000      Time =    6.1978    Error =     0.1023181539E-11

ijb@ijb-Latitude-5410:~/work/stack$ 

在打开编译器优化的情况下对大型矩阵进行测试:

ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -std=f2018 -O3 -fexternal-blas  -Wuse-without-only   mm.f90 -lopenblas
ijb@ijb-Latitude-5410:~/work/stack$ export OMP_NUM_THREADS=1
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
Reference  Rank =   500      Time =    0.0324
BLAS       Rank =   500      Time =    0.0244    Error =      0.000000000    
Mine       Rank =   500      Time =    0.0174    Error =     0.1819877365E-12

Reference  Rank =  1000      Time =    0.1849
BLAS       Rank =  1000      Time =    0.1741    Error =      0.000000000    
Mine       Rank =  1000      Time =    0.1075    Error =     0.3215549355E-12

Reference  Rank =  1500      Time =    0.5684
BLAS       Rank =  1500      Time =    0.5422    Error =      0.000000000    
Mine       Rank =  1500      Time =    0.3051    Error =     0.4582862942E-12

Reference  Rank =  2000      Time =    1.1869
BLAS       Rank =  2000      Time =    1.1340    Error =      0.000000000    
Mine       Rank =  2000      Time =    0.6584    Error =     0.5796914040E-12

Reference  Rank =  2500      Time =    2.1756
BLAS       Rank =  2500      Time =    2.2456    Error =      0.000000000    
Mine       Rank =  2500      Time =    1.5541    Error =     0.7279509461E-12

Reference  Rank =  3000      Time =    4.3817
BLAS       Rank =  3000      Time =    4.1084    Error =      0.000000000    
Mine       Rank =  3000      Time =    2.6164    Error =     0.8198074455E-12

Reference  Rank =  3500      Time =    6.0470
BLAS       Rank =  3500      Time =    6.2369    Error =      0.000000000    
Mine       Rank =  3500      Time =    4.0864    Error =     0.9713407673E-12

Reference  Rank =  4000      Time =    9.6197
BLAS       Rank =  4000      Time =    9.5614    Error =      0.000000000    
Mine       Rank =  4000      Time =    6.0322    Error =     0.1048140855E-11

Reference  Rank =  4500      Time =   13.2843
BLAS       Rank =  4500      Time =   13.9924    Error =      0.000000000    
Mine       Rank =  4500      Time =    8.5252    Error =     0.1325804964E-11

Reference  Rank =  5000      Time =   19.0425
BLAS       Rank =  5000      Time =   18.5975    Error =      0.000000000    
Mine       Rank =  5000      Time =   11.7193    Error =     0.1525268984E-11

Reference  Rank =  5500      Time =   25.1735
BLAS       Rank =  5500      Time =   25.2269    Error =      0.000000000    
Mine       Rank =  5500      Time =   15.5540    Error =     0.1607774678E-11

Reference  Rank =  6000      Time =   32.5070
BLAS       Rank =  6000      Time =   35.5822    Error =      0.000000000    
Mine       Rank =  6000      Time =   20.9217    Error =     0.1775845175E-11

Reference  Rank =  6500      Time =   42.5040
BLAS       Rank =  6500      Time =   41.4937    Error =      0.000000000    
Mine       Rank =  6500      Time =   25.5950    Error =     0.1942681535E-11

Reference  Rank =  7000      Time =   51.9414
BLAS       Rank =  7000      Time =   52.3054    Error =      0.000000000    
Mine       Rank =  7000      Time =   31.4987    Error =     0.2033691978E-11

Reference  Rank =  7500      Time =   63.9316
BLAS       Rank =  7500      Time =   63.9044    Error =      0.000000000    
Mine       Rank =  7500      Time =   38.7648    Error =     0.2250884549E-11

Reference  Rank =  8000      Time =   77.2484
BLAS       Rank =  8000      Time =   78.0691    Error =      0.000000000    
Mine       Rank =  8000      Time =   46.3565    Error =     0.2273736754E-11

ijb@ijb-Latitude-5410:~/work/stack$ 

推荐阅读