首页 > 解决方案 > 从 C++ lambda 函数创建 MPI 约简运算符

问题描述

我想编写一个包装 MPI_Allreduce 的函数,它接受任何二元运算符(如 std::reduce)作为 MPI 的归约运算符。特别是,这种函数的用户可以使用 lambda。

以下简单的示例代码说明了这一点:

#include <mpi.h>
#include <iostream>
#include <functional>

template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {

  auto lambda=[op](void *a, void *b, int *len, MPI_Datatype *){
                double *aa=static_cast<double *>(a);
                double *bb=static_cast<double *>(bb);
                for (int i=0; i<*len; ++i) {
                  bb[i]=op(aa[i], bb[i]);
                }
              };


  // MPI_User_function is a typedef to: void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * )
  MPI_User_function *opPtr=/* black magic code that get the function pointer from the lambda */;
  MPI_Op mpiOp;
  MPI_Op_create(*opPtr, 1, &mpiOp);
  MPI_Allreduce(MPI_IN_PLACE, data, len, MPI_DOUBLE, mpiOp, MPI_COMM_WORLD);
  MPI_Op_free(&mpiOp);
}

int main() {

  MPI_Init(nullptr, nullptr);

  double data[4]={1.,2.,3.,4.};

  reduce(data, 4, [](double a, double b){return a+b;});

  int pRank;
  MPI_Comm_rank(MPI_COMM_WORLD, &pRank);
  if (pRank==0) {
    for (int i=0; i<4; ++i) {
      std::cout << data[i] << " ";
    }
    std::cout << std::endl;
  }

  MPI_Finalize();

  return 1;
}

缺少的部分是从函数中的 lambda 获取函数指针的代码reduce。从几个相关的问题来看,从捕获 lambda 获取函数指针的问题似乎很棘手,但可以解决。但是我没有在这个简单的代码上工作(我尝试了一些技巧与 std::function、std::bind、将 lambda 存储在静态变量中)......所以一点帮助会很棒!


编辑:在@noma 回答之后,我在 goldbolt 中尝试了以下没有 MPI 的简化代码

#include <iostream>
#include <functional>

typedef double MPI_Datatype;

template<typename BinaryOp, BinaryOp op> // older standards
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
{}


template<typename BinaryOp>
void reduce(double *data, int len, BinaryOp op) {

  typedef void (MPI_User_function) ( void * a, void * b, int * len, MPI_Datatype * );
  MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;
}

int main() {

  double data[4]={1.,2.,3.,4.};  
  reduce(data, 4, [](double a, double b){return a+b;});

  return 1;
}

它在一些编译器上编译。结果如下:

带有 -std=c++17 的 icc 19.0.0(或带有 -std=c++14 的 icc 19.0.1)的错误消息很有趣:

   <source>(15): error: expression must have a constant value

MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;

                                                      ^

      detected during instantiation of "void reduce(double *, int, BinaryOp) [with BinaryOp=lambda [](double, double)->double]" at line 21               

事实上,我并不真正理解作为函数reduce的第二个模板参数的函数的运行时参数“op”变量的传递non_lambda......它是一个不起眼的 c++17 功能,只有一些编译器支持?

标签: c++lambdampifunction-pointers

解决方案


我认为 lambda 方法在这里是不可能的,因为它是一个捕获 lambda,请参阅 https://stackoverflow.com/a/28746827/7678171

我们可以在BinaryOp这里使用带有模板值参数的函数模板而不是 Lambda。这假设BinaryOp是一个函数指针,或者是一个无捕获的 lambda,可以转换为一个。我们介绍的lambdareduce

template<auto op> // this is C++17, so use --std=c++17
// template<typename BinaryOp, BinaryOp op> // older standards
void non_lambda(void *a, void *b, int *len, MPI_Datatype *)
{
    double *aa=static_cast<double *>(a);
    double *bb=static_cast<double *>(bb);
    for (int i=0; i<*len; ++i) {
        bb[i]=op(aa[i], bb[i]);
    }
}

那么黑魔法线是:

/* black magic code that get the function pointer from the lambda */
MPI_User_function *opPtr = &non_lambda<+op>; // NOTE: the + implies the lamda to function pointer conversion here
// MPI_User_function *opPtr = &non_lambda<decltype(+op), +op>; // older standards;

希望这可以帮助。


注意:我使用 Clang 6.0 编译了这个,但 g++ 7.5 失败(可能是编译器错误?):

error: no matches converting function ‘non_lambda’ to type ‘void (*)(void*, void*, int*, struct ompi_datatype_t**)’
   MPI_User_function *opPtr = &non_lambda<+op>;
                      ^~~~~
note: candidate is: template<auto op> void non_lambda(void*, void*, int*, ompi_datatype_t**)
 void non_lambda(void *a, void *b, int *len, MPI_Datatype *)

也许较新的 g++ 版本可以工作。


推荐阅读