c++ - 从整数值到类型的动态转换(C++11 模板元编程?)
问题描述
我正在尝试通过模板减少代码重复。我已经将大部分代码移到了这个帮助iterate_function_from_CSC_helper
器,它现在是一个模板。然而,这个函数仍然重复了很多代码,只是为了调用模板的不同特化:
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (data_type == C_API_DTYPE_FLOAT32) {
if (col_ptr_type == C_API_DTYPE_INT32) {
return iterate_function_from_CSC_helper<float, int32_t>(col_ptr, indices, data, col_idx);
} else if (col_ptr_type == C_API_DTYPE_INT64) {
return iterate_function_from_CSC_helper<float, int64_t>(col_ptr, indices, data, col_idx);
}
} else if (data_type == C_API_DTYPE_FLOAT64) {
if (col_ptr_type == C_API_DTYPE_INT32) {
return iterate_function_from_CSC_helper<double, int32_t>(col_ptr, indices, data, col_idx);
} else if (col_ptr_type == C_API_DTYPE_INT64) {
return iterate_function_from_CSC_helper<double, int64_t>(col_ptr, indices, data, col_idx);
}
}
Log::Fatal("Unknown data type in CSC matrix");
return nullptr;
}
我想自动将运行时收到的整数映射到 float/double 和 int32_t/int64_t 类型,并用它们调用模板data_type
。col_ptr_dtype
像这样的东西:
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr, int col_ptr_type, const int32_t* indices, const void* data, int data_type, int64_t ncol_ptr, int64_t , int col_idx) {
CHECK(col_idx < ncol_ptr && col_idx >= 0);
if (<TTag<data_col>::invalid_type || TTag<col_ptr_type>::invalid_type) {
Log::Fatal("Unknown data type in CSC matrix");
return nullptr;
}
return iterate_function_from_CSC_helper<TTag<data_type>::type, TTag<col_ptr_type>::type>(col_ptr, indices, data, col_idx);
}
那可能吗?我假设通过一些元编程可以消除这种情况。
我尝试了以下方法,但无法dummy_IterateFunctionFromCSC
使用非 const 输入(在运行时会出现这种情况):
#include <cstdint>
#include <stdio.h>
#include <iostream>
#include <type_traits>
#define C_API_DTYPE_FLOAT32 (0) /*!< \brief float32 (single precision float). */
#define C_API_DTYPE_FLOAT64 (1) /*!< \brief float64 (double precision float). */
#define C_API_DTYPE_INT32 (2) /*!< \brief int32. */
#define C_API_DTYPE_INT64 (3) /*!< \brief int64. */
struct TTagInvalidType {}; //! Meant for invalid types in TTag.
template <int C_API_DTYPE>
struct TTag {
using type = TTagInvalidType;
};
template<>
struct TTag<C_API_DTYPE_FLOAT32> {
using type = float;
};
template <>
struct TTag<C_API_DTYPE_FLOAT64> {
using type = double;
};
template <>
struct TTag<C_API_DTYPE_INT32> {
using type = int32_t;
};
template <>
struct TTag<C_API_DTYPE_INT64> {
using type = int64_t;
};
template <typename T>
void example_f () {
T x = 3.6;
std::cout << x << "\n";
}
template <>
void example_f<TTagInvalidType>() {
std::cout << "Abort!\n";
}
template<int x>
void dummy_IterateFunctionFromCSC() {
f<typename TTag<x>::type>();
}
int main() {
const int m = 2; // Doesn't work for non const integers (true at runtime)
dummy_IterateFunctionFromCSC<m>();
}
这可以编译,但只能使用常量 m,而不是例如从用户接收到的整数。
这是不可能的,因为类型调度必须在编译时计算吗?或者有可能以及如何?:D
谢谢 :)
解决方案
将运行时值转换为编译时值确实需要一些 if/switch 就像你做的那样。
您可以通过额外的拆分来避免一些重复:
C++17 可能有助于减少std::variant
一些实用程序的冗长:
template <typename T> struct type_identity { using type = T; };
// type should be an enum
std::variant<type_identity<int32_t>, type_identity<int64_t>> to_compile_int_type(int type)
{
switch (type) {
case C_API_DTYPE_INT32: return type_identity<int32_t>{};
case C_API_DTYPE_INT64: return type_identity<int64_t>{};
default:
Log::Fatal("Unknown int data type");
throw "unknown type";
}
}
// type should be an enum
std::variant<type_identity<float>, type_identity<double>> to_compile_float_type(int type)
{
switch (type) {
case C_API_DTYPE_FLOAT32: return type_identity<float>{};
case C_API_DTYPE_FLOAT64: return type_identity<double>{};
default:
Log::Fatal("Unknown float data type");
throw "unknown type";
}
}
进而
std::function<std::pair<int, double>(int idx)>
IterateFunctionFromCSC(const void* col_ptr,
int col_ptr_type,
const int32_t* indices,
const void* data,
int data_type,
int64_t ncol_ptr,
int64_t ,
int col_idx)
{
CHECK(col_idx < ncol_ptr && col_idx >= 0);
std::visit(
[&](auto intvar, auto floatvar){
using inttype = typename decltype(intvar)::type;
using floattype = typename decltype(floatvar)::type;
return iterate_function_from_CSC_helper<floatype, inttype>(col_ptr, indices, data, col_idx);
},
to_compile_int_type(col_ptr_type),
to_compile_float_type(data_type)
);
}
推荐阅读
- java - 尝试访问 pojo 时出现 NoMessageBodyWriterFoundFailure
- azure - Azure - 带有正则表达式的 Web API 和路由模板
- python - 如何手动安装 Numpy (Linux)?
- css - CSS Grid 将主页分成两部分
- docker - traefik docker 标签中基于主机和路径的路由规则的混合
- html - 如何根据悬停效果设置背景颜色,即当鼠标悬停在为类的背景颜色定义的相同 css 定义中的 li 项目上时?
- python - 对一个列表进行排序以使两个列表具有正确的顺序对应关系
- knockout.js - Knockout.js foreach 没有给出错误并且没有显示结果
- python - 如何在 pyspark 列表达式中引用名称中带有连字符的列?
- c++ - 如何访问子类中的小部件?