首页 > 解决方案 > 从整数值到类型的动态转换(C++11 模板元编程?)

问题描述

我正在尝试通过模板减少代码重复。我已经将大部分代码移到了这个帮助iterate_function_from_CSC_helper器,它现在是一个模板。然而,这个函数仍然重复了很多代码,只是为了调用模板的不同特化:

std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);
  if (data_type == C_API_DTYPE_FLOAT32) {
    if (col_ptr_type == C_API_DTYPE_INT32) {
      return iterate_function_from_CSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
      return iterate_function_from_CSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
    }    
  } else if (data_type == C_API_DTYPE_FLOAT64) {
    if (col_ptr_type == C_API_DTYPE_INT32) {
      return iterate_function_from_CSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
    } else if (col_ptr_type == C_API_DTYPE_INT64) {
      return iterate_function_from_CSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
    }
  }
  Log::Fatal("Unknown data type in CSC matrix");
  return nullptr;
}

我想自动将运行时收到的整数映射到 float/double 和 int32_t/int64_t 类型,并用它们调用模板data_typecol_ptr_dtype像这样的东西:

std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
  CHECK(col_idx < ncol_ptr && col_idx >= 0);

  if (<TTag<data_col>::invalid_type || TTag<col_ptr_type>::invalid_type) {
    Log::Fatal("Unknown data type in CSC matrix");
    return nullptr;
  }

  return iterate_function_from_CSC_helper<TTag<data_type>::type, TTag<col_ptr_type>::type>(col_ptr, indices, data, col_idx);

}

那可能吗?我假设通过一些元编程可以消除这种情况。

我尝试了以下方法,但无法dummy_IterateFunctionFromCSC使用非 const 输入(在运行时会出现这种情况):

#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <type_traits>


#define C_API_DTYPE_FLOAT32 (0)  /*!< \brief float32 (single precision float). */
#define C_API_DTYPE_FLOAT64 (1)  /*!< \brief float64 (double precision float). */
#define C_API_DTYPE_INT32   (2)  /*!< \brief int32. */
#define C_API_DTYPE_INT64   (3)  /*!< \brief int64. */



struct TTagInvalidType {}; //! Meant for invalid types in TTag.

template <int C_API_DTYPE>
struct TTag {
  using type = TTagInvalidType;  
};

template<>
struct TTag<C_API_DTYPE_FLOAT32> {
  using type = float;
};

template <>
struct TTag<C_API_DTYPE_FLOAT64> {
  using type = double;
};

template <>
struct TTag<C_API_DTYPE_INT32> {
  using type = int32_t;
};

template <>
struct TTag<C_API_DTYPE_INT64> {
  using type = int64_t;
};



template <typename T>
void example_f () {
    T x = 3.6;
    std::cout << x << "\n";
}

template <>
void example_f<TTagInvalidType>() {    
    std::cout << "Abort!\n";
}

template<int x>
void dummy_IterateFunctionFromCSC() {
    f<typename TTag<x>::type>();
}

int main() {
    const int m = 2;  // Doesn't work for non const integers (true at runtime)
    dummy_IterateFunctionFromCSC<m>();
}

这可以编译,但只能使用常量 m,而不是例如从用户接收到的整数。

这是不可能的,因为类型调度必须在编译时计算吗?或者有可能以及如何?:D

谢谢 :)

标签: c++c++11template-meta-programming

解决方案


将运行时值转换为编译时值确实需要一些 if/switch 就像你做的那样。

您可以通过额外的拆分来避免一些重复:

C++17 可能有助于减少std::variant一些实用程序的冗长:

template <typename T> struct type_identity { using type = T; };

// type should be an enum
std::variant<type_identity<int32_t>, type_identity<int64_t>> to_compile_int_type(int type)
{
    switch (type) {
        case C_API_DTYPE_INT32: return type_identity<int32_t>{};
        case C_API_DTYPE_INT64: return type_identity<int64_t>{};
        default:
            Log::Fatal("Unknown int data type");
            throw "unknown type";
    }
}

// type should be an enum
std::variant<type_identity<float>, type_identity<double>> to_compile_float_type(int type)
{
    switch (type) {
        case C_API_DTYPE_FLOAT32: return type_identity<float>{};
        case C_API_DTYPE_FLOAT64: return type_identity<double>{};
        default:
            Log::Fatal("Unknown float data type");
            throw "unknown type";
    }
}

进而

std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr,
                       int col_ptr_type,
                       const int32_t* indices,
                       const void* data,
                       int data_type,
                       int64_t ncol_ptr,
                       int64_t ,
                       int col_idx)
{
    CHECK(col_idx < ncol_ptr && col_idx >= 0);
    std::visit(
        [&](auto intvar, auto floatvar){
            using inttype = typename decltype(intvar)::type;
            using floattype = typename decltype(floatvar)::type;

            return iterate_function_from_CSC_helper<floatype, inttype>(col_ptr, indices, data, col_idx);
        },
        to_compile_int_type(col_ptr_type),
        to_compile_float_type(data_type)
    );
}

推荐阅读