首页 > 解决方案 > 如何用可变参数函数覆盖 C++ 类中的运算符?

问题描述

这里的 C++ 新手:我想创建一个模板类来创建不同数据类型和d维度的张量,其中d由形状指定。例如,具有形状的张量(2, 3, 5)具有 3 个维度,包含 24 个元素。我使用一维向量存储所有数据元素,并希望使用形状信息访问元素以查找元素。

我想覆盖()运算符以访问元素。由于维度可以变化,因此()操作员的输入参数的数量也可以变化。从技术上讲,我可以使用向量作为输入参数,但 C++ 似乎也支持可变参数函数。但是,我无法绕过它。

到目前为止我所拥有的:

#ifndef TENSOR_HPP
#define TENSOR_HPP

#include <vector>
#include <numeric>
#include <algorithm>
#include <stdexcept>
#include <iostream>
#include <stdarg.h>


template <typename T> class Tensor {

    private:
        std::vector<T> m_data;
        std::vector<std::size_t> m_shape;
        std::size_t m_size;
        
    public:
        // Constructors
        Tensor(std::vector<T> data, std::vector<std::size_t> shape);

        // Destructor
        ~Tensor();

        // Access the individual elements                                                                                                                                                                                               
        T& operator()(std::size_t&... d_args);
        
};


template <typename T> Tensor<T>::Tensor(std::vector<T> data, std::vector<std::size_t> shape) {
    // Calculate number of data values based on shape
    m_size = std::accumulate(std::begin(shape), std::end(shape), 1, std::multiplies<std::size_t>());
    // Check if calculated number of values match the actual number
    if (data.size() != m_size) {
        throw std::length_error("Tensor shape does not match the number of data values");
    } 
    // All good from here
    m_data = data;
    m_shape = shape;
}

template <typename T> T& Tensor<T>::operator() (std::size_t&... d_args) {
    // Return something to avoid warning
    return m_data[0];
};

template <typename T> Tensor<T>::~Tensor() {
    //delete[] m_values;
};


#endif

不,当我执行以下操作时:

std::vector<float> data = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24};
std::vector<std::size_t> shape = {2, 3, 4};
Tensor<float> tensor(data, shape);

tensor(2,0,3); // <-- What I would like to do

// Possible workaround with vector which I would like to avoid
// std::vector<std::size_t> index = {2,0,3};
// tensor(index);

我得到错误:

tensor2.hpp:27:33: error: expansion pattern ‘std::size_t&’ {aka ‘long unsigned int&’} contains no parameter packs

()使用可变参数函数覆盖运算符的正确方法是什么?

标签: c++variadic-templatesvariadic-functions

解决方案


您可以添加具有尽可能多的重载的辅助函数,以计算正确的索引以访问项目:

    T& getData(int dim1) { return m_data[dim1];}
    T& getData(int dim1, int dim2) { return m_data[ dim1* m_shape[1] + dim2 ];}
    T& getData(int dim1, int dim2, int dim3) { return m_data[ dim1*m_shape[1]*m_shape[2] + dim2*m_shape[2] + dim3 ];}

那么operator()可能看起来像:

    template<class ... Args>                                                                                                                                                                                           
    T& operator()(Args... d_args) {
        static_assert( (std::is_integral_v<Args> && ...) ); // [1]
        return getData(d_args...);
    }

通过 [1] 我们限制()仅使用整数类型。

现场演示


推荐阅读