c++ - Eigen ConditionType 数组:广播而不是循环的有效方式
问题描述
我有一段性能关键的代码,我需要检查一个数组是否有低于阈值的值,然后有条件地设置另外两个数组的值。我的代码如下所示:
#include <Eigen/Dense>
int main(){
Eigen::ArrayXXd
a (1, 100),
b (2, 100),
c (3, 100);
a.setRandom();
b.setRandom();
c.setRandom();
constexpr double minVal { 1e-8 };
/* the code segment in question */
/* option 1 */
for ( int i=0; i<2; ++i ){
b.row(i) = (a < minVal).select( 0, c.row(i+1) / a );
c.row(i+1) = (a < minVal).select( 0, c.row(i+1) );
}
/* option 2, which is slower */
b = (a < minVal).replicate(2,1).select( 0, c.bottomRows(2) / a.replicate(2,1) );
c.bottomRows(2) = (a < minVal).replicate(2,1).select( 0, c.bottomRows(2) );
return 0;
}
a
检查其值是否达到阈值的数组minVal
具有一行和动态数量的列。其他两个数组b
分别c
有两行和三行,列数与 相同a
。
现在我想以更多eigen
方式执行上述逻辑,在选项 1 中不使用该循环,因为通常情况下,在eigen
性能方面有一些技巧,我在编写原始循环时永远无法匹配。但是,我能想到的唯一方法是选项 2,它明显比选项 1 慢。
执行上述操作的正确有效方法是什么?还是循环已经是我最好的选择?
解决方案
您可以尝试以下方法:
- 使用固定行数和动态列数定义数组类型,即,您可以将Eigen::ArrayXXd替换为Eigen::Array<double, 1/2/3, Eigen::Dynamic>。
- 使用固定大小版本的块操作(请参阅https://eigen.tuxfamily.org/dox/group__TutorialBlockOperations.html),即,您可以将bottomRows(N)替换为bottomRows<N>()并类似地复制(2,1 )与复制<2,1>()。
我已经更改了代码中的数组类型,并包含了第三个选项以及我提到的可能的改进:
#include <Eigen/Dense>
#include <iostream>
#include <chrono>
constexpr int numberOfTrials = 1000000;
constexpr double minVal{ 1e-8 };
typedef Eigen::Array<double, 1, Eigen::Dynamic> Array1Xd;
typedef Eigen::Array<double, 2, Eigen::Dynamic> Array2Xd;
typedef Eigen::Array<double, 3, Eigen::Dynamic> Array3Xd;
inline void option1(const Array1Xd& a, Array2Xd& b, Array3Xd& c)
{
for (int i = 0; i < 2; ++i) {
b.row(i) = (a < minVal).select(0, c.row(i + 1) / a);
c.row(i + 1) = (a < minVal).select(0, c.row(i + 1));
}
}
inline void option2(const Array1Xd& a, Array2Xd& b, Array3Xd& c)
{
b = (a < minVal).replicate(2, 1).select(0, c.bottomRows(2) / a.replicate(2, 1));
c.bottomRows(2) = (a < minVal).replicate(2, 1).select(0, c.bottomRows(2));
}
inline void option3(const Array1Xd& a, Array2Xd& b, Array3Xd& c)
{
b = (a < minVal).replicate<2, 1>().select(0, c.bottomRows<2>() / a.replicate<2, 1>());
c.bottomRows<2>() = (a < minVal).replicate<2, 1>().select(0, c.bottomRows<2>());
}
int main() {
Array1Xd a(1, 100);
Array2Xd b(2, 100);
Array3Xd c(3, 100);
a.setRandom();
b.setRandom();
c.setRandom();
auto tpBegin1 = std::chrono::steady_clock::now();
for (int i = 0; i < numberOfTrials; i++)
option1(a, b, c);
auto tpEnd1 = std::chrono::steady_clock::now();
auto tpBegin2 = std::chrono::steady_clock::now();
for (int i = 0; i < numberOfTrials; i++)
option2(a, b, c);
auto tpEnd2 = std::chrono::steady_clock::now();
auto tpBegin3 = std::chrono::steady_clock::now();
for (int i = 0; i < numberOfTrials; i++)
option3(a, b, c);
auto tpEnd3 = std::chrono::steady_clock::now();
std::cout << "(Option 1) Average execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(tpEnd1 - tpBegin1).count() / (long double)(numberOfTrials) << " us" << std::endl;
std::cout << "(Option 2) Average execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(tpEnd2 - tpBegin2).count() / (long double)(numberOfTrials) << " us" << std::endl;
std::cout << "(Option 3) Average execution time: " << std::chrono::duration_cast<std::chrono::microseconds>(tpEnd3 - tpBegin3).count() / (long double)(numberOfTrials) << " us" << std::endl;
return 0;
}
我获得的平均执行时间如下(i7-9700K,msvc2019,启用优化,NDEBUG):
(Option 1) Average execution time: 0.527717 us
(Option 2) Average execution time: 3.25618 us
(Option 3) Average execution time: 0.512029 us
并启用 AVX2+OpenMP:
(Option 1) Average execution time: 0.374309 us
(Option 2) Average execution time: 3.31356 us
(Option 3) Average execution time: 0.260551 us
我不确定这是否是最“本征”的方式,但我希望它有所帮助!
推荐阅读
- python - 如何检查平铺物理和与精灵动画的碰撞?
- python - 从字典中的值中找到熊猫的平均值?
- javascript - index.js:1375 警告:列表中的每个孩子都应该有一个唯一的“关键”道具
- javascript - JS 控件在第二个滑块上不起作用,不确定如何定位
- sql - 关联不相关表中的值
- initialization - 使用 init 脚本在 dbfs 和 mvn 包中安装 jar 文件
- osgi - 导入包:com.day.cq.wcm.api,version=[1.29,2) 和 org.apache.sling.api.resource,version=[2.12,3) -- 无法解决
- sql - 计算双引号和括号之间的值
- sql - 从 BigQuery 的表中选择几行数据的简单方法?
- azure-data-explorer - 使用查询填充 KQL 中的扩展?