首页 > 解决方案 > 接受特征密集矩阵和稀疏矩阵的函数

问题描述

我正在努力将稀疏矩阵支持添加到开源数学库中,并且不希望同时具有矩阵类型DenseSparse矩阵类型的重复函数。

下面的示例显示了一个add函数。一个具有两个功能的工作示例,然后两次尝试失败。下面提供了指向代码示例的神螺栓链接。

我查看了关于编写采用 Eigen 类型的函数的 Eigen 文档,但它们的使用答案Eigen::EigenBase不起作用,因为两者都有可用的特定方法,MatrixBase而这些方法在SparseMatrixBaseEigenBase

https://eigen.tuxfamily.org/dox/TopicFunctionTakingEigenTypes.html

我们使用 C++14,非常感谢您的帮助和您的时间!!

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>

// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;

// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;

// Below two are the generics that work
template <class Derived>
eigen_return_t<Derived> add(const Eigen::MatrixBase<Derived>& A) {
    return A + A;
}

template <class Derived>
eigen_return_t<Derived> add(const Eigen::SparseMatrixBase<Derived>& A) {
    return A + A;
}

int main()
{
  // Fill up the sparse and dense matrices
  tripletList.reserve(4);
  tripletList.push_back(triplet_d(0, 0, 1));
  tripletList.push_back(triplet_d(0, 1, 2));
  tripletList.push_back(triplet_d(1, 0, 3));
  tripletList.push_back(triplet_d(1, 1, 4));

  sparse_mat_d mat(2, 2);
  mat.setFromTriplets(tripletList.begin(), tripletList.end());

  Eigen::Matrix<double, -1, -1> v(2, 2);
  v << 1, 2, 3, 4;

  // Works fine
  sparse_mat_d output = add(mat * mat);
  std::cout << output;

  // Works fine
  Eigen::Matrix<double, -1, -1> output2 = add(v * v);
  std::cout << output2;

} 

而不是两个添加函数,我只想拥有一个同时接受稀疏和密集矩阵的函数,但是下面的尝试没有成功。

模板 模板类型

我的尝试显然很糟糕,但是用add模板模板类型替换上面的两个函数会导致模棱两可的基类错误。

template <template <class> class Container, class Derived>
Container<Derived> add(const Container<Derived>& A) {
    return A + A;    
}

错误:

<source>: In function 'int main()':
<source>:35:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
   35 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
   20 | Container<Derived> add(const Container<Derived>& A) {
      |                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:35:38: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>'
   35 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:40:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
   40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^
<source>:20:20: note: candidate: 'template<template<class> class Container, class Derived> Container<Derived> add(const Container<Derived>&)'
   20 | Container<Derived> add(const Container<Derived>& A) {
      |                    ^~~
<source>:20:20: note:   template argument deduction/substitution failed:
<source>:40:52: note:   'const Container<Derived>' is an ambiguous base class of 'const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>'
   40 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^

我相信这是同样的钻石继承问题:

https://www.fluentcpp.com/2017/05/19/crtp-helper/

使用 std::conditional_t

下面尝试使用conditional_t来推断正确的输入类型

#include <Eigen/Core>
#include <Eigen/Sparse>
#include <iostream>

// Sparse matrix helper
using triplet_d = Eigen::Triplet<double>;
using sparse_mat_d = Eigen::SparseMatrix<double>;
std::vector<triplet_d> tripletList;


// Returns plain object
template <typename Derived>
using eigen_return_t = typename Derived::PlainObject;

// Check it Object inherits from DenseBase
template<typename Derived>
using is_dense_matrix_expression = std::is_base_of<Eigen::DenseBase<std::decay_t<Derived>>, std::decay_t<Derived>>;

// Check it Object inherits from EigenBase
template<typename Derived>
using is_eigen_expression = std::is_base_of<Eigen::EigenBase<std::decay_t<Derived>>, std::decay_t<Derived>>;

// Alias to deduce if input should be Dense or Sparse matrix
template <typename Derived>
using eigen_matrix = typename std::conditional_t<is_dense_matrix_expression<Derived>::value,
 typename Eigen::MatrixBase<Derived>, typename Eigen::SparseMatrixBase<Derived>>;

template <typename Derived>
eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
    return A + A;
}

int main()
{
  tripletList.reserve(4);

  tripletList.push_back(triplet_d(0, 0, 1));
  tripletList.push_back(triplet_d(0, 1, 2));
  tripletList.push_back(triplet_d(1, 0, 3));
  tripletList.push_back(triplet_d(1, 1, 4));

  sparse_mat_d mat(2, 2);
  mat.setFromTriplets(tripletList.begin(), tripletList.end());
  sparse_mat_d output = add(mat * mat);

  std::cout << output;
  Eigen::Matrix<double, -1, -1> v(2, 2);
  v << 1, 2, 3, 4;
  Eigen::Matrix<double, -1, -1> output2 = add(v * v);
  std::cout << output2;

} 

这会引发错误

<source>: In function 'int main()':
<source>:94:38: error: no matching function for call to 'add(const Eigen::Product<Eigen::SparseMatrix<double, 0, int>, Eigen::SparseMatrix<double, 0, int>, 2>)'
   94 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
   79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
      |                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:94:38: note:   couldn't deduce template parameter 'Derived'
   94 |   sparse_mat_d output = add(mat * mat);
      |                                      ^
<source>:99:52: error: no matching function for call to 'add(const Eigen::Product<Eigen::Matrix<double, -1, -1>, Eigen::Matrix<double, -1, -1>, 0>)'
   99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);
      |                                                    ^
<source>:79:25: note: candidate: 'template<class Derived> eigen_return_t<Derived> add(eigen_matrix<Derived>&)'
   79 | eigen_return_t<Derived> add(const eigen_matrix<Derived>& A) {
      |                         ^~~
<source>:79:25: note:   template argument deduction/substitution failed:
<source>:99:52: note:   couldn't deduce template parameter 'Derived'
   99 |   Eigen::Matrix<double, -1, -1> output2 = add(v * v);

这似乎是因为无法像这个链接那样推断出依赖类型的依赖参数。

https://deque.blog/2017/10/12/why-template-parameters-of-dependent-type-names-cannot-be-deduced-and-what-to-do-about-it/

螺栓示例

下面的godbolt可以使用上面的所有实例

https://godbolt.org/z/yKEAsn

有没有办法只有一个功能而不是两个?我们有很多函数可以同时支持稀疏矩阵和稠密矩阵,因此避免代码重复会很好。

编辑:可能的答案

@Max Langhof 建议使用

template <class Mat>
auto add(const Mat& A) {
 return A + A; 
}

Eigen的auto关键字有点危险

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

template <class Mat> 
typename Mat::PlainObject add(const Mat& A) { 
    return A + A; 
}

有效,尽管我不完全确定为什么在这种情况下返回普通对象有效

编辑 编辑

有几个人提到了auto关键字的使用。遗憾的是,Eigen 不能很好地与autoC++11 的第二个中引用的以及下面链接中的 auto 一起使用

https://eigen.tuxfamily.org/dox/TopicPitfalls.html

在某些情况下可以使用 auto ,但我想看看是否有一种通用auto的“ish 方式可以抱怨 Eigen 的模板返回类型”

对于带有 auto 的段错误示例,您可以尝试将 add 替换为

template <typename T1>
auto add(const T1& A) 
{
    return ((A+A).eval()).transpose();
}

标签: c++templateseigenmultiple-inheritance

解决方案


如果你想通过EigenBase<Derived>,你可以提取底层类型使用.derived()(本质上,这只是强制转换为Derived const&):

template <class Derived>
eigen_return_t<Derived> add(const Eigen::EigenBase<Derived>& A_) {
    Derived const& A = A_.derived();
    return A + A;
}

更高级,对于这个特定的示例,由于您使用A了两次,您可以使用内部评估器结构来表达:

template <class Derived>
eigen_return_t<Derived> add2(const Eigen::EigenBase<Derived>& A_) {
    // A is used twice:
    typedef typename Eigen::internal::nested_eval<Derived,2>::type NestedA;
    NestedA A (A_.derived());
    return A + A;
}

这样做的好处是,当传递一个产品时,A_它不会在评估时被评估两次A+A,但如果A_是类似的东西,Block<...>它就不会被不必要地复制。但是,internal并不真正推荐使用功能(其 API 可能随时更改)。


推荐阅读