c++ - C++ 有限差分微分 - 设计
问题描述
让A
:
class A {
std::vector<double> values_;
public:
A(const std::vector<double> &values) : values_(values){};
void bumpAt(std::size_t i, const double &df) {
values_[i] += df;
virtual method1();
virtual method2();
...
}
class B : public A {
overrides methods
...
}
为简单起见,请考虑以下功能:
double foo(input1, input2, ..., const A &a, const B &b, inputK, ...) {
/* do complex stuff */
return ...;
}
我们想区分foo()
它的论点。因此,一阶灵敏度d foo/d a
是一个std::vector<double>
大小等于a.size()
。同样的道理也适用d foo/d b
。
一个简单的实现如下:
// compute d foo/d a
std::vector<double> computeDfDa(input1, input2, ..., const A &a, const B &b, inputK, ..., double da = 1.0){
std::vector<double> dfda = {};
auto aUp = a.copy();
auto aDown = a.copy();
for (auto i = 0; i < a.size(); ++i) {
// bump up
aUp.bumpAt(i, da);
// bump down
aDown.bumpAt(i, -da);
auto up = foo(input1, input2, ..., aUp, b, inputK, ...);
auto down = foo(input1, input2, ..., aDown, b, inputK, ...);
auto derivative = (up - down) / 2.0 / da;
dfda.pushback(derivative);
// revert bumps
aUp.bumpAt(i, -da);
aDown.bumpAt(i, da);
}
return dfda;
}
// compute d foo/d b
std::vector<double> computeDfDb(input1, input2, ..., const A &a, const B &b, inputK, ..., double db = 0.01){
std::vector<double> dfdb = {};
auto bUp = b.copy();
auto bDown = b.copy();
for (auto i = 0; i < a.size(); ++i) {
// bump up
bUp.bumpAt(i, db);
// bump down
bDown.bumpAt(i, -db);
auto up = foo(input1, input2, ..., a, bUp, inputK, ...);
auto down = foo(input1, input2, ..., a, bDown, inputK, ...);
auto derivative = (up - down) / 2.0 / db;
dfdb.pushback(derivative);
// revert bumps
bUp.bumpAt(i, -db);
bDown.bumpAt(i, db);
}
return dfdb;
}
这很好用,但是我们有基本相同的代码 forcomputeDfDa()
和 for computeDfDb()
。
是否有任何设计模式允许拥有一个独特的(可能是模板化的)函数来自动理解要碰撞的输入?
请注意a
和b
在输入中的位置不是可交换的。
如果 的复杂性和输入的数量foo()
更大,那么简单的解决方案将生成大量无用的代码,因为我们必须为 . 的每个输入编写一个computeDfDx()
函数。x
foo()
解决方案
由于compute
顺序相同但迭代循环通过不同的容器,您可以重构此函数。
std::vector<double> computeLoop( std::vector<double> &v, std::vector<double> const &computeArg1, std::vector<double> const &computeArg2, double d = 1.0 )
{
std::vector<double> dfd = {};
for (auto i = 0; i < v.size(); ++i) {
// bump up
v[i] += d;
auto up = compute(computeArg1, computeArg2);
v[i] -= d;
// bump down
v[i] -= d;
auto down = compute(computeArg1, computeArg2);
v[i] += d;
auto derivative = (up - down) / 2.0 / d;
dfd.pushback(derivative);
}
}
实际通话。
auto dfda = computeLoop( a, a, b );
auto dfdb = computeLoop( b, a, b );
让我们v
通过引用传递,但它可能会导致维护问题。因为可能与orv
是相同的引用,但是在这方面并不明显。将来有人可能会不自觉地破坏代码。computeArg1
computeArg2
computeLoop
推荐阅读
- vb.net - AVL树如何在插入时平衡树
- c# - 使用 timeScale 0f 暂停游戏时如何暂停/恢复协程?
- git - 如何更改提交编辑器中显示的内容?
- c++ - 是否有一种解决方法可以在 c++ 中为短裤定义用户定义的文字?
- bash - 比较 2 个 csv 文件并将常用字段记录合并到第一个文件中
- javascript - Vue 组件通过上下文获取 props
- html - 带间距的 CSS flexbox 装订线
- flutter - 是否可以在颤振中的两个提供者之间进行交互和交换数据?
- csv - Spark:导入异构多个csv
- github - 如何将 git 凭据提供给 Jenkins 管道步骤,以便任何 shell/terraform 代码可以在运行脚本时执行 git clone