c++ - 如何用可变参数函数覆盖 C++ 类中的运算符?
问题描述
这里的 C++ 新手:我想创建一个模板类来创建不同数据类型和d
维度的张量,其中d
由形状指定。例如,具有形状的张量(2, 3, 5)
具有 3 个维度,包含 24 个元素。我使用一维向量存储所有数据元素,并希望使用形状信息访问元素以查找元素。
我想覆盖()
运算符以访问元素。由于维度可以变化,因此()
操作员的输入参数的数量也可以变化。从技术上讲,我可以使用向量作为输入参数,但 C++ 似乎也支持可变参数函数。但是,我无法绕过它。
到目前为止我所拥有的:
#ifndef TENSOR_HPP
#define TENSOR_HPP
#include <vector>
#include <numeric>
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <stdarg.h>
template <typename T> class Tensor {
private:
std::vector<T> m_data;
std::vector<std::size_t> m_shape;
std::size_t m_size;
public:
// Constructors
Tensor(std::vector<T> data, std::vector<std::size_t> shape);
// Destructor
~Tensor();
// Access the individual elements
T& operator()(std::size_t&... d_args);
};
template <typename T> Tensor<T>::Tensor(std::vector<T> data, std::vector<std::size_t> shape) {
// Calculate number of data values based on shape
m_size = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<std::size_t>());
// Check if calculated number of values match the actual number
if (data.size() != m_size) {
throw std::length_error("Tensor shape does not match the number of data values");
}
// All good from here
m_data = data;
m_shape = shape;
}
template <typename T> T& Tensor<T>::operator() (std::size_t&... d_args) {
// Return something to avoid warning
return m_data[0];
};
template <typename T> Tensor<T>::~Tensor() {
//delete[] m_values;
};
#endif
不,当我执行以下操作时:
std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
std::vector<std::size_t> shape = {2, 3, 4};
Tensor<float> tensor(data, shape);
tensor(2,0,3); // <-- What I would like to do
// Possible workaround with vector which I would like to avoid
// std::vector<std::size_t> index = {2,0,3};
// tensor(index);
我得到错误:
tensor2.hpp:27:33: error: expansion pattern ‘std::size_t&’ {aka ‘long unsigned int&’} contains no parameter packs
()
使用可变参数函数覆盖运算符的正确方法是什么?
解决方案
您可以添加具有尽可能多的重载的辅助函数,以计算正确的索引以访问项目:
T& getData(int dim1) { return m_data[dim1];}
T& getData(int dim1, int dim2) { return m_data[ dim1* m_shape[1] + dim2 ];}
T& getData(int dim1, int dim2, int dim3) { return m_data[ dim1*m_shape[1]*m_shape[2] + dim2*m_shape[2] + dim3 ];}
那么operator()
可能看起来像:
template<class ... Args>
T& operator()(Args... d_args) {
static_assert( (std::is_integral_v<Args> && ...) ); // [1]
return getData(d_args...);
}
通过 [1] 我们限制()
仅使用整数类型。
推荐阅读
- xcode - Xcode 服务器:xcscontrol 命令总是失败
- javascript - 查找矩形的未使用空间
- .net-core - .Net 或 .Net Core 中是否有等效的 Spring Cloud Stream?
- c++ - Protobuf:了解 proto 文件的编译输出
- javascript - 不能在父级的 [ngClass] 逻辑中使用子级输入值
- ruby-on-rails - rails capistrano deploy 成功,但我无法访问我的部署 rails 主页
- c# - C# LibreOffice 进程不等待退出
- bash - 为什么 pkill 没有杀死命令 python blink_led.py&
- r - 如何重命名数据框中的单个列
- c# - 试图让玩家旋转垂直于模型的网格。统一