c++ - 从 C++ lambda 函数创建 MPI 约简运算符
问题描述
我想编写一个包装 MPI_Allreduce 的函数,它接受任何二元运算符(如 std::reduce)作为 MPI 的归约运算符。特别是,这种函数的用户可以使用 lambda。
以下简单的示例代码说明了这一点:
#include <mpi.h>
#include <iostream>
#include <functional>
template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {
auto lambda=[op](void *a, void *b, int *len, MPI_Datatype *){
double *aa=static_cast<double *>(a);
double *bb=static_cast<double *>(bb);
for (int i=0; i<*len; ++i) {
bb[i]=op(aa[i], bb[i]);
}
};
// MPI_User_function is a typedef to: void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * )
MPI_User_function *opPtr=/* black magic code that get the function pointer from the lambda */;
MPI_Op mpiOp;
MPI_Op_create(*opPtr, 1, &mpiOp);
MPI_Allreduce(MPI_IN_PLACE, data, len, MPI_DOUBLE, mpiOp, MPI_COMM_WORLD);
MPI_Op_free(&mpiOp);
}
int main() {
MPI_Init(nullptr, nullptr);
double data[4]={1.,2.,3.,4.};
reduce(data, 4, [](double a, double b){return a+b;});
int pRank;
MPI_Comm_rank(MPI_COMM_WORLD, &pRank);
if (pRank==0) {
for (int i=0; i<4; ++i) {
std::cout << data[i] << " ";
}
std::cout << std::endl;
}
MPI_Finalize();
return 1;
}
缺少的部分是从函数中的 lambda 获取函数指针的代码reduce
。从几个相关的问题来看,从捕获 lambda 获取函数指针的问题似乎很棘手,但可以解决。但是我没有在这个简单的代码上工作(我尝试了一些技巧与 std::function、std::bind、将 lambda 存储在静态变量中)......所以一点帮助会很棒!
编辑:在@noma 回答之后,我在 goldbolt 中尝试了以下没有 MPI 的简化代码
#include <iostream>
#include <functional>
typedef double MPI_Datatype;
template<typename BinaryOp, BinaryOp op> // older standards
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
{}
template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {
typedef void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * );
MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
}
int main() {
double data[4]={1.,2.,3.,4.};
reduce(data, 4, [](double a, double b){return a+b;});
return 1;
}
它在一些编译器上编译。结果如下:
- icc >= 19.0.1(使用-std=c++17):好的
- clang++ >= 5.0.0 (with --std=c++17): OK
- clang++ 10.0.0 (with --std=c++14): NOK
- g++ 9.3(使用--std=c++17):NOK
- icc >= 19.0.0 (with -std=c++17) : NOK
带有 -std=c++17 的 icc 19.0.0(或带有 -std=c++14 的 icc 19.0.1)的错误消息很有趣:
<source>(15): error: expression must have a constant value
MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
^
detected during instantiation of "void reduce(double *, int, BinaryOp) [with BinaryOp=lambda [](double, double)->double]" at line 21
事实上,我并不真正理解作为函数reduce
的第二个模板参数的函数的运行时参数“op”变量的传递non_lambda
......它是一个不起眼的 c++17 功能,只有一些编译器支持?
解决方案
我认为 lambda 方法在这里是不可能的,因为它是一个捕获 lambda,请参阅 https://stackoverflow.com/a/28746827/7678171
我们可以在BinaryOp
这里使用带有模板值参数的函数模板而不是 Lambda。这假设BinaryOp
是一个函数指针,或者是一个无捕获的 lambda,可以转换为一个。我们介绍的lambda
是reduce
:
template<auto op> // this is C++17, so use --std=c++17
// template<typename BinaryOp, BinaryOp op> // older standards
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
{
double *aa=static_cast<double *>(a);
double *bb=static_cast<double *>(bb);
for (int i=0; i<*len; ++i) {
bb[i]=op(aa[i], bb[i]);
}
}
那么黑魔法线是:
/* black magic code that get the function pointer from the lambda */
MPI_User_function *opPtr = &non_lambda<+op>; // NOTE: the + implies the lamda to function pointer conversion here
// MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
希望这可以帮助。
注意:我使用 Clang 6.0 编译了这个,但 g++ 7.5 失败(可能是编译器错误?):
error: no matches converting function ‘non_lambda’ to type ‘void (*)(void*, void*, int*, struct ompi_datatype_t**)’
MPI_User_function *opPtr = &non_lambda<+op>;
^~~~~
note: candidate is: template<auto op> void non_lambda(void*, void*, int*, ompi_datatype_t**)
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
也许较新的 g++ 版本可以工作。
推荐阅读
- python - 正则表达式将(所有匹配项或无匹配项)最后修复为一个
- javascript - DirtyForms 不适用于输入隐藏字段
- sql - 如何在从一个表插入另一个表之前修改选择查询的数据
- python-3.x - Python 会自动跳出循环吗?
- flutter - 使用存储和 cookie(不使用 Firebase)在 Flutter 中进行身份验证?
- r - 使用 combine() 和 aperm() 函数进行训练——第一个参数的问题,必须是一个数组
- python - tkinter
Windows 和 Linux 上的不同行为 - javascript - 通用文字类型
- python - Discord.py 添加反应
- javascript - 分配给变量的 innerHTML 不能实时更新?