arrays - fortran 中 blas 的包装器
问题描述
对评论的回应
有几个人指出我试图违反 Fortran 标准。情况并非如此,我希望编写完全可移植且符合标准的代码。下面是我的一种天真做法,显然它行不通。因此,我不知道解决方案,但我非常清楚自己需要什么功能;具体用什么方法实现并不重要。我也不一定局限于 Fortran,尽管使用 Fortran 更可取。
原帖
我最近发现,BLAS 库中的 ZGEMM 在计算复数矩阵与实数矩阵的乘积时性能并不理想。用两次 DGEMM 调用替换一次复数-复数 ZGEMM 调用,可以获得约 3 倍的加速(而不是天真预期的 1.5 倍)!
!> Complex-by-real matrix product built from two real DGEMM calls.
module Mmatmul
  implicit none
  private
  public :: zd_matmul_module

  !> Same kind as DOUBLE PRECISION, which is what DGEMM operates on.
  integer, parameter :: dp = kind(1.0d0)

contains

  !> Compute C = A * B for an n x n complex matrix A and an n x n real matrix B.
  !!
  !! Instead of promoting B to complex and calling ZGEMM, the real and the
  !! imaginary parts of A are each multiplied by B with a real DGEMM.
  !! Only the leading n x n parts of the arrays are referenced.
  !!
  !! @param[in]    n   order of the matrices; nothing is done when n <= 0
  !! @param[in]    a   complex matrix A, at least n x n
  !! @param[in]    b   real matrix B, at least n x n
  !! @param[out]   c   complex result A*B, at least n x n
  !! @param[inout] T1  real workspace, at least n x n (overwritten)
  !! @param[inout] T2  real workspace, at least n x n (overwritten)
  subroutine zd_matmul_module(n, a, b, c, T1, T2)
    integer,     intent(in)    :: n
    complex(dp), intent(in)    :: a(:,:)
    real(dp),    intent(in)    :: b(:,:)
    complex(dp), intent(out)   :: c(:,:)
    real(dp),    intent(inout) :: T1(:,:), T2(:,:)
    external :: dgemm

    ! DGEMM requires leading dimensions >= 1, so an empty problem is skipped.
    if (n <= 0) return

    ! Re(C) = Re(A) * B.  The real part is extracted with intrinsics rather than
    ! DCOPY on the complex array: passing a COMPLEX actual to a REAL dummy is
    ! non-conforming and rejected by gfortran >= 10 within a single file.
    T1(1:n, 1:n) = real(a(1:n, 1:n), dp)
    call dgemm('N', 'N', n, n, n, 1.0_dp, T1, size(T1, 1), b, size(b, 1), &
               0.0_dp, T2, size(T2, 1))
    ! Park the real part in C while T2 is reused for the imaginary part.
    c(1:n, 1:n) = cmplx(T2(1:n, 1:n), 0.0_dp, dp)

    ! Im(C) = Im(A) * B
    T1(1:n, 1:n) = aimag(a(1:n, 1:n))
    call dgemm('N', 'N', n, n, n, 1.0_dp, T1, size(T1, 1), b, size(b, 1), &
               0.0_dp, T2, size(T2, 1))
    c(1:n, 1:n) = cmplx(real(c(1:n, 1:n), dp), T2(1:n, 1:n), dp)
  end subroutine zd_matmul_module

end module Mmatmul
这个子程序非常直观,但它不如 BLAS 子程序灵活。例如,两个临时矩阵 T1 和 T2 必须具有正确的类型和形状,这很不方便。再比如,假设我需要执行 3 次复数与实数(双精度)的矩阵乘法,结果分别是 rank-2、rank-3 和 rank-4 数组。这意味着要编写 3 个专用子程序,并准备 6 个不同的临时矩阵!
另一方面,只要传入的数组大小正确,BLAS 子例程并不关心其参数的类型和形状。这对我来说是一个非常理想的特性。
因此,我想在 Fortran 中为 BLAS 子例程编写一个包装器,使它像 BLAS 本身一样不做类型和形状检查。我打算使用外部子程序,以绕过 Fortran 的类型和秩(rank)检查。我是有意这样做的,因为这部分确实是我代码的瓶颈。
我想出了这个实现
!> Compute C = A * B for an n x n complex A and an n x n real B via two DGEMMs.
!!
!! The dummies are assumed-size DOUBLE PRECISION arrays, so the complex
!! matrices a and c are viewed as interleaved (re, im) pairs of doubles.
!! NOTE(review): callers pass COMPLEX arrays here through an implicit
!! interface; that type punning is not standard-conforming and relies on the
!! compiler passing the array base address unchanged.
!!
!! @param[in]    n   order of the matrices; nothing is done when n <= 0
!! @param[in]    a   2*n*n doubles: complex A as interleaved (re, im)
!! @param[in]    b   n*n doubles: real B, column-major, leading dimension n
!! @param[out]   c   2*n*n doubles: complex result as interleaved (re, im)
!! @param[inout] T1  workspace of at least n*n doubles (overwritten)
!! @param[inout] T2  workspace of at least n*n doubles (overwritten)
subroutine zd_matmul_external(n, a, b, c, T1, T2)
  implicit none
  integer,          intent(in)    :: n
  double precision, intent(in)    :: a(*), b(*)
  double precision, intent(out)   :: c(*)
  double precision, intent(inout) :: T1(*), T2(*)
  external :: dcopy, dgemm

  ! DGEMM requires leading dimensions >= 1, so an empty problem is skipped.
  if (n <= 0) return

  ! Real part: gather the odd-indexed doubles of a, multiply by B, and
  ! scatter the product into the odd-indexed doubles of c.
  call dcopy(n*n, a(1), 2, T1, 1)
  call dgemm('N', 'N', n, n, n, 1.0D0, T1, n, b, n, 0.0D0, T2, n)
  call dcopy(n*n, T2, 1, c(1), 2)

  ! Imaginary part: the same with the even-indexed doubles.
  call dcopy(n*n, a(2), 2, T1, 1)
  call dgemm('N', 'N', n, n, n, 1.0D0, T1, n, b, n, 0.0D0, T2, n)
  call dcopy(n*n, T2, 1, c(2), 2)
end subroutine zd_matmul_external
它可以编译,但运行时报错,提示存在内存问题。由于它能通过编译,并且结构与 BLAS 的实现相似,我认为它在语法上没有错误。但问题出在哪里?
为了完整起见,这是主程序
!> Compare three ways of computing a complex-by-real matrix product:
!> the module routine, the external (type-punning) routine and intrinsic MATMUL.
program Psmall
  use Mmatmul, only: zd_matmul_module
  implicit none
  integer, parameter :: dp = kind(1.0d0)
  integer, allocatable :: seed(:)
  integer :: sz, n, u, ios
  complex(dp), allocatable :: AZ(:,:)
  complex(dp), allocatable :: CZ1(:,:), CZ2(:,:), CZ3(:,:)
  real(dp), allocatable :: BD(:,:), AR(:,:), AI(:,:)
  external :: zd_matmul_external

  ! Seed the generator from /dev/urandom and actually install that seed.
  ! If the device cannot be opened or read, fall back to a processor-dependent seed.
  call random_seed(size = sz)
  allocate(seed(sz))
  open(newunit=u, file='/dev/urandom', access='stream', form='unformatted', &
       action='read', status='old', iostat=ios)
  if (ios == 0) then
    read(u, iostat=ios) seed
    close(u)
  end if
  if (ios == 0) then
    call random_seed(put = seed)
  else
    call random_seed()
  end if

  ! -- prepare random arrays --
  n = 500
  allocate(AR(n,n), AI(n,n), BD(n,n), AZ(n,n), CZ1(n,n), CZ2(n,n), CZ3(n,n))
  call random_number(AR)
  call random_number(AI)
  call random_number(BD)
  AZ = cmplx(AR, AI, dp)

  ! AR and AI are no longer needed, so they double as the n x n workspaces.
  call zd_matmul_module(n, AZ, BD, CZ1, AR, AI)
  ! Implicit-interface call: complex arrays are passed to real dummies on purpose.
  call zd_matmul_external(n, AZ, BD, CZ2, AR, AI)
  CZ3 = matmul(AZ, BD)

  ! Frobenius-norm differences against the intrinsic MATMUL reference.
  print *, sqrt(sum(abs(CZ1 - CZ3)**2)), sqrt(sum(abs(CZ2 - CZ3)**2))
end program Psmall
我正在使用 gfortran 编译器(gcc 版本 5.3.0)@ macOS 10.13.5。
解决方案
我会采用类似下面的做法:把乘法分成若干易于处理的块。它符合标准,(在我的笔记本电脑上)比 OpenBLAS 的 ZGEMM 显著更快,并且对于大型矩阵,它的内存占用明显少于问题中提出的实现。代码只针对单线程编写和测试,并行化留作练习。您还可以通过调整缓冲区大小获得更好的性能。
代码:
!> Explicit interfaces for the BLAS level-3 routines DGEMM and ZGEMM, so that
!> calls are checked by the compiler instead of going through implicit interfaces.
Module blas_interfaces_module

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64

  Implicit None ( Type, External )

  Interface

     !> C := alpha * op( A ) * op( B ) + beta * C for double-precision real matrices.
     !! C is read on entry whenever beta /= 0, so it is Intent( InOut ), not Intent( Out ).
     Subroutine dgemm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc )
       Import wp
       Implicit None ( Type, External )
       Character                  , Intent( In    ) :: transa
       Character                  , Intent( In    ) :: transb
       Integer                    , Intent( In    ) :: m
       Integer                    , Intent( In    ) :: n
       Integer                    , Intent( In    ) :: k
       Real( wp )                 , Intent( In    ) :: alpha
       Real( wp ), Dimension( * ) , Intent( In    ) :: a
       Integer                    , Intent( In    ) :: lda
       Real( wp ), Dimension( * ) , Intent( In    ) :: b
       Integer                    , Intent( In    ) :: ldb
       Real( wp )                 , Intent( In    ) :: beta
       Real( wp ), Dimension( * ) , Intent( InOut ) :: c
       Integer                    , Intent( In    ) :: ldc
     End Subroutine dgemm

     !> C := alpha * op( A ) * op( B ) + beta * C for double-precision complex matrices.
     !! C is read on entry whenever beta /= 0, so it is Intent( InOut ), not Intent( Out ).
     Subroutine zgemm( transa, transb, m, n, k, alpha, a, lda, b, ldb, beta, c, ldc )
       Import wp
       Implicit None ( Type, External )
       Character                     , Intent( In    ) :: transa
       Character                     , Intent( In    ) :: transb
       Integer                       , Intent( In    ) :: m
       Integer                       , Intent( In    ) :: n
       Integer                       , Intent( In    ) :: k
       Complex( wp )                 , Intent( In    ) :: alpha
       Complex( wp ), Dimension( * ) , Intent( In    ) :: a
       Integer                       , Intent( In    ) :: lda
       Complex( wp ), Dimension( * ) , Intent( In    ) :: b
       Integer                       , Intent( In    ) :: ldb
       Complex( wp )                 , Intent( In    ) :: beta
       Complex( wp ), Dimension( * ) , Intent( InOut ) :: c
       Integer                       , Intent( In    ) :: ldc
     End Subroutine zgemm

  End Interface

  Public :: dgemm
  Public :: zgemm

  Private

End Module blas_interfaces_module
!> Blocked complex-by-real matrix multiplication built on real DGEMM.
Module mm_complex_real_module

  Implicit None ( Type, External )

  Public :: mm_complex_real

  Private

Contains

  !> Compute c = a * b for a complex n x n matrix a and a real n x n matrix b.
  !!
  !! The product is formed block by block. Each block of a is split into its
  !! real and imaginary parts, and each part is multiplied by the matching block
  !! of b with a real DGEMM. The workspace is therefore bounded by four
  !! n_buff x n_buff real arrays, independent of n.
  !!
  !! @param[in]  n  order of the matrices; c is empty/zero when n < 1
  !! @param[in]  a  complex n x n matrix
  !! @param[in]  b  real n x n matrix
  !! @param[out] c  complex n x n result a * b
  Subroutine mm_complex_real( n, a, b, c )

    Use, Intrinsic :: iso_fortran_env, Only : wp => real64
    Use blas_interfaces_module, Only : dgemm

    Implicit None ( Type, External )

    Integer                              , Intent( In  ) :: n
    Complex( wp ), Dimension( 1:n, 1:n ), Intent( In  ) :: a
    Real   ( wp ), Dimension( 1:n, 1:n ), Intent( In  ) :: b
    Complex( wp ), Dimension( 1:n, 1:n ), Intent( Out ) :: c

    ! Largest block edge; tune for cache size and memory budget.
    Integer, Parameter :: n_buff_max = 2048

    Real( wp ), Dimension( :, : ), Allocatable :: a_real, a_imag
    Real( wp ), Dimension( :, : ), Allocatable :: c_real, c_imag

    Integer :: i_start, j_start, k_start
    Integer :: i_end  , j_end  , k_end
    Integer :: i_len  , j_len  , k_len
    Integer :: n_buff

    c = 0.0_wp

    ! With n < 1 the block step below would be zero, which is not a valid DO step.
    If( n < 1 ) Return

    n_buff = Min( n_buff_max, n )

    Allocate( a_real( 1:n_buff, 1:n_buff ) )
    Allocate( a_imag( 1:n_buff, 1:n_buff ) )
    Allocate( c_real( 1:n_buff, 1:n_buff ) )
    Allocate( c_imag( 1:n_buff, 1:n_buff ) )

    ! Stepped DO loops visit every block start up to and including n, so the
    ! final (possibly partial or single-element) block is always processed.
    ! j runs over the summed (inner) dimension of the product.
    Do j_start = 1, n, n_buff
       j_end = Min( j_start + n_buff - 1, n )
       j_len = j_end - j_start + 1

       Do i_start = 1, n, n_buff
          i_end = Min( i_start + n_buff - 1, n )
          i_len = i_end - i_start + 1

          ! Split the ( i, j ) block of a once; it is reused for every k block.
          a_real( 1:i_len, 1:j_len ) = Real ( a( i_start:i_end, j_start:j_end ), wp )
          a_imag( 1:i_len, 1:j_len ) = Aimag( a( i_start:i_end, j_start:j_end ) )

          Do k_start = 1, n, n_buff
             k_end = Min( k_start + n_buff - 1, n )
             k_len = k_end - k_start + 1

             ! The buffers keep leading dimension n_buff even for smaller blocks;
             ! b is addressed in place by sequence association from b( j_start, k_start ).
             Call dgemm( 'N', 'N', i_len, k_len, j_len, 1.0_wp, a_real , Size( a_real, Dim = 1 ), &
                                                                b( j_start, k_start ), Size( b, Dim = 1 ), &
                                                        0.0_wp, c_real , Size( c_real, Dim = 1 ) )
             Call dgemm( 'N', 'N', i_len, k_len, j_len, 1.0_wp, a_imag , Size( a_imag, Dim = 1 ), &
                                                                b( j_start, k_start ), Size( b, Dim = 1 ), &
                                                        0.0_wp, c_imag , Size( c_imag, Dim = 1 ) )

             ! Accumulate this j block's contribution into the ( i, k ) block of c.
             c( i_start:i_end, k_start:k_end ) = c( i_start:i_end, k_start:k_end ) + &
                  Cmplx( c_real( 1:i_len, 1:k_len ), c_imag( 1:i_len, 1:k_len ), Kind = wp )
          End Do

       End Do

    End Do

  End Subroutine mm_complex_real

End Module mm_complex_real_module
!> Benchmark for square sizes 500, 1000, ..., 8000. It compares three methods:
!> intrinsic MATMUL, ZGEMM with b promoted to complex, and the blocked
!> mm_complex_real. Times and max-abs errors are reported against MATMUL.
Program testit

  Use, Intrinsic :: iso_fortran_env, Only : wp => real64, li => int64, stdout => output_unit
  Use mm_complex_real_module, Only : mm_complex_real
  Use blas_interfaces_module, Only : dgemm, zgemm

  Implicit None ( Type, External )

  ! Report layouts: timing only (reference) and timing plus error column.
  Character( Len = * ), Parameter :: fmt_time = &
       '( a, t12, "Rank = ", i5, t30, "Time = ", f9.4 )'
  Character( Len = * ), Parameter :: fmt_err  = &
       '( a, t12, "Rank = ", i5, t30, "Time = ", f9.4, t50, "Error = ", g20.10 )'

  Complex( wp ), Dimension( :, : ), Allocatable :: a, b_complex
  Complex( wp ), Dimension( :, : ), Allocatable :: c_ref, c_blas, c_mine
  Real   ( wp ), Dimension( :, : ), Allocatable :: b

  Real( wp )    :: t_ref, t_blas, t_mine
  Integer( li ) :: tick0, tick1, tick_rate
  Integer       :: order

  Do order = 500, 8000, 500

     Allocate( a( 1:order, 1:order ), b( 1:order, 1:order ) )
     Allocate( c_ref ( 1:order, 1:order ), c_blas( 1:order, 1:order ), &
               c_mine( 1:order, 1:order ) )

     ! Random complex a, built from two independent real samples; the
     ! temporaries are released automatically when the block ends.
     Block
       Real( wp ), Dimension( :, : ), Allocatable :: re_part, im_part
       Allocate( re_part( 1:order, 1:order ), im_part( 1:order, 1:order ) )
       Call Random_number( re_part )
       Call Random_number( im_part )
       a = Cmplx( re_part, im_part, Kind = wp )
     End Block
     Call Random_number( b )

     ! 1) Intrinsic reference.
     Call system_clock( tick0, tick_rate )
     c_ref = Matmul( a, b )
     Call system_clock( tick1, tick_rate )
     t_ref = elapsed( tick0, tick1, tick_rate )

     ! 2) ZGEMM; promoting b to complex is included in the measured time.
     Allocate( b_complex( 1:order, 1:order ) )
     Call system_clock( tick0, tick_rate )
     b_complex = b
     Call zgemm( 'N', 'N', order, order, order, ( 1.0_wp, 0.0_wp ), a        , Size( a        , Dim = 1 ), &
                                                                   b_complex, Size( b_complex, Dim = 1 ), &
                                                 ( 0.0_wp, 0.0_wp ), c_blas   , Size( c_blas   , Dim = 1 ) )
     Call system_clock( tick1, tick_rate )
     t_blas = elapsed( tick0, tick1, tick_rate )
     Deallocate( b_complex )

     ! 3) Blocked complex-by-real product.
     Call system_clock( tick0, tick_rate )
     Call mm_complex_real( order, a, b, c_mine )
     Call system_clock( tick1, tick_rate )
     t_mine = elapsed( tick0, tick1, tick_rate )

     Write( stdout, fmt_time ) 'Reference', order, t_ref
     Write( stdout, fmt_err  ) 'BLAS', order, t_blas, Maxval( Abs( c_ref - c_blas ) )
     Write( stdout, fmt_err  ) 'Mine', order, t_mine, Maxval( Abs( c_ref - c_mine ) )
     Write( stdout, * )

     Deallocate( c_mine, c_blas, c_ref, b, a )

  End Do

Contains

  !> Seconds elapsed between two system_clock counts taken at the given count rate.
  Function elapsed( t0, t1, rate ) Result( seconds )
    Integer( li ), Intent( In ) :: t0, t1, rate
    Real( wp ) :: seconds
    seconds = Real( t1 - t0, Kind = wp ) / rate
  End Function elapsed

End Program testit
使用调试标志和编译器版本对小矩阵进行测试:
ijb@ijb-Latitude-5410:~/work/stack$ gfortran --version
GNU Fortran (Ubuntu 9.3.0-17ubuntu1~20.04) 9.3.0
Copyright (C) 2019 Free Software Foundation, Inc.
This is free software; see the source for copying conditions. There is NO
warranty; not even for MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -std=f2018 -fcheck=all -finit-real=snan -fexternal-blas -Wuse-without-only mm.f90 -g -lopenblas
ijb@ijb-Latitude-5410:~/work/stack$ export OMP_NUM_THREADS=1
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
Reference Rank = 500 Time = 0.0317
BLAS Rank = 500 Time = 0.0244 Error = 0.000000000
Mine Rank = 500 Time = 0.0221 Error = 0.2049518614E-12
Reference Rank = 1000 Time = 0.1776
BLAS Rank = 1000 Time = 0.1810 Error = 0.000000000
Mine Rank = 1000 Time = 0.1260 Error = 0.3215549355E-12
Reference Rank = 1500 Time = 0.5452
BLAS Rank = 1500 Time = 0.5370 Error = 0.000000000
Mine Rank = 1500 Time = 0.3571 Error = 0.4687428402E-12
Reference Rank = 2000 Time = 1.1895
BLAS Rank = 2000 Time = 1.1421 Error = 0.000000000
Mine Rank = 2000 Time = 0.7319 Error = 0.5084229946E-12
Reference Rank = 2500 Time = 2.1990
BLAS Rank = 2500 Time = 2.1870 Error = 0.000000000
Mine Rank = 2500 Time = 1.6239 Error = 0.7279509461E-12
Reference Rank = 3000 Time = 4.3954
BLAS Rank = 3000 Time = 4.0730 Error = 0.000000000
Mine Rank = 3000 Time = 2.6654 Error = 0.8276526716E-12
Reference Rank = 3500 Time = 6.1562
BLAS Rank = 3500 Time = 6.4210 Error = 0.000000000
Mine Rank = 3500 Time = 4.5101 Error = 0.1016845989E-11
Reference Rank = 4000 Time = 9.4788
BLAS Rank = 4000 Time = 9.7646 Error = 0.000000000
Mine Rank = 4000 Time = 6.1978 Error = 0.1023181539E-11
ijb@ijb-Latitude-5410:~/work/stack$
在打开编译器优化的情况下对大型矩阵进行测试:
ijb@ijb-Latitude-5410:~/work/stack$ gfortran -Wall -Wextra -std=f2018 -O3 -fexternal-blas -Wuse-without-only mm.f90 -lopenblas
ijb@ijb-Latitude-5410:~/work/stack$ export OMP_NUM_THREADS=1
ijb@ijb-Latitude-5410:~/work/stack$ ./a.out
Reference Rank = 500 Time = 0.0324
BLAS Rank = 500 Time = 0.0244 Error = 0.000000000
Mine Rank = 500 Time = 0.0174 Error = 0.1819877365E-12
Reference Rank = 1000 Time = 0.1849
BLAS Rank = 1000 Time = 0.1741 Error = 0.000000000
Mine Rank = 1000 Time = 0.1075 Error = 0.3215549355E-12
Reference Rank = 1500 Time = 0.5684
BLAS Rank = 1500 Time = 0.5422 Error = 0.000000000
Mine Rank = 1500 Time = 0.3051 Error = 0.4582862942E-12
Reference Rank = 2000 Time = 1.1869
BLAS Rank = 2000 Time = 1.1340 Error = 0.000000000
Mine Rank = 2000 Time = 0.6584 Error = 0.5796914040E-12
Reference Rank = 2500 Time = 2.1756
BLAS Rank = 2500 Time = 2.2456 Error = 0.000000000
Mine Rank = 2500 Time = 1.5541 Error = 0.7279509461E-12
Reference Rank = 3000 Time = 4.3817
BLAS Rank = 3000 Time = 4.1084 Error = 0.000000000
Mine Rank = 3000 Time = 2.6164 Error = 0.8198074455E-12
Reference Rank = 3500 Time = 6.0470
BLAS Rank = 3500 Time = 6.2369 Error = 0.000000000
Mine Rank = 3500 Time = 4.0864 Error = 0.9713407673E-12
Reference Rank = 4000 Time = 9.6197
BLAS Rank = 4000 Time = 9.5614 Error = 0.000000000
Mine Rank = 4000 Time = 6.0322 Error = 0.1048140855E-11
Reference Rank = 4500 Time = 13.2843
BLAS Rank = 4500 Time = 13.9924 Error = 0.000000000
Mine Rank = 4500 Time = 8.5252 Error = 0.1325804964E-11
Reference Rank = 5000 Time = 19.0425
BLAS Rank = 5000 Time = 18.5975 Error = 0.000000000
Mine Rank = 5000 Time = 11.7193 Error = 0.1525268984E-11
Reference Rank = 5500 Time = 25.1735
BLAS Rank = 5500 Time = 25.2269 Error = 0.000000000
Mine Rank = 5500 Time = 15.5540 Error = 0.1607774678E-11
Reference Rank = 6000 Time = 32.5070
BLAS Rank = 6000 Time = 35.5822 Error = 0.000000000
Mine Rank = 6000 Time = 20.9217 Error = 0.1775845175E-11
Reference Rank = 6500 Time = 42.5040
BLAS Rank = 6500 Time = 41.4937 Error = 0.000000000
Mine Rank = 6500 Time = 25.5950 Error = 0.1942681535E-11
Reference Rank = 7000 Time = 51.9414
BLAS Rank = 7000 Time = 52.3054 Error = 0.000000000
Mine Rank = 7000 Time = 31.4987 Error = 0.2033691978E-11
Reference Rank = 7500 Time = 63.9316
BLAS Rank = 7500 Time = 63.9044 Error = 0.000000000
Mine Rank = 7500 Time = 38.7648 Error = 0.2250884549E-11
Reference Rank = 8000 Time = 77.2484
BLAS Rank = 8000 Time = 78.0691 Error = 0.000000000
Mine Rank = 8000 Time = 46.3565 Error = 0.2273736754E-11
ijb@ijb-Latitude-5410:~/work/stack$
推荐阅读
- java - Tomcat 安装目录无效。它缺少预期的文件或文件夹 tcruntime-ctl.sh
- java - NetworkBoundResource 类中 MediatorLiveData 的使用
- django - 将自定义 url 操作参数添加到 django-cms
- python - 框架内的 Python 网格管理器
- cakephp - CakePHP 3:[ManyToMany] 使用多个特定标签获取书签
- python - 将同一天的所有数据收集到一行
- listview - 如何在主详细信息页面中添加导航页面?
- python - 创建具有 REAL 值(包括 UNIQUE 约束)的 SQLITE 表的最佳方法是什么?
- php - 未识别的索引错误,数据库中没有结果
- php - 从访问器返回 HTML 是否可以接受?