首页 > 解决方案 > C++ 有限差分微分 - 设计

问题描述

A

class A {
  std::vector<double> values_;
public:
  A(const std::vector<double> &values) : values_(values){};

  void bumpAt(std::size_t i, const double &df) {
    values_[i] += df;

  virtual method1();
  virtual method2();
  ...
}

class B : public A {
  overrides methods
  ...
}

为简单起见,请考虑以下功能:

double foo(input1, input2, ..., const A &a, const B &b, inputK, ...) {
 /* do complex stuff */
 return ...;
}

我们想区分foo()它的论点。因此,一阶灵敏度d foo/d a是一个std::vector<double>大小等于a.size()。同样的道理也适用d foo/d b

一个简单的实现如下:


// compute d foo/d a 
std::vector<double> computeDfDa(input1, input2, ..., const A &a, const B &b, inputK, ..., double da = 1.0){
  std::vector<double> dfda = {};

  auto aUp = a.copy();
  auto aDown = a.copy();
  for (auto i = 0; i < a.size(); ++i) {

      // bump up
      aUp.bumpAt(i, da);
      // bump down
      aDown.bumpAt(i, -da);

      auto up = foo(input1, input2, ..., aUp, b, inputK, ...);
      auto down = foo(input1, input2, ..., aDown, b, inputK, ...);

      auto derivative = (up - down) / 2.0 / da;
      dfda.pushback(derivative);

      // revert bumps
      aUp.bumpAt(i, -da);
      aDown.bumpAt(i, da);
  }
   return dfda;
}

// compute d foo/d b 
std::vector<double> computeDfDb(input1, input2, ..., const A &a, const B &b, inputK, ..., double db = 0.01){
  std::vector<double> dfdb = {};

  auto bUp = b.copy();
  auto bDown = b.copy();
  for (auto i = 0; i < a.size(); ++i) {

      // bump up
      bUp.bumpAt(i, db);
      // bump down
      bDown.bumpAt(i, -db);

      auto up = foo(input1, input2, ..., a, bUp, inputK, ...);
      auto down = foo(input1, input2, ..., a, bDown, inputK, ...);

      auto derivative = (up - down) / 2.0 / db;
      dfdb.pushback(derivative);

      // revert bumps
      bUp.bumpAt(i, -db);
      bDown.bumpAt(i, db);
  }
   return dfdb;
}

这很好用,但是我们有基本相同的代码 forcomputeDfDa()和 for computeDfDb()

是否有任何设计模式允许拥有一个独特的(可能是模板化的)函数来自动理解要碰撞的输入?

请注意ab在输入中的位置不是可交换的。

如果 的复杂性和输入的数量foo()更大,那么简单的解决方案将生成大量无用的代码,因为我们必须为 . 的每个输入编写一个computeDfDx()函数。xfoo()

标签: c++templatesdesign-patternstraitsexpression-templates

解决方案


由于compute顺序相同但迭代循环通过不同的容器,您可以重构此函数。

std::vector<double> computeLoop( std::vector<double> &v, std::vector<double> const &computeArg1, std::vector<double> const &computeArg2, double d = 1.0 )
{
  std::vector<double> dfd = {};
  for (auto i = 0; i < v.size(); ++i) {
      // bump up
      v[i] += d;
      auto up = compute(computeArg1, computeArg2);
      v[i] -= d;

      // bump down
      v[i] -= d;
      auto down = compute(computeArg1, computeArg2);
      v[i] += d;

      auto derivative = (up - down) / 2.0 / d;
      dfd.pushback(derivative);
  }
}

实际通话。

auto dfda = computeLoop( a, a, b );
auto dfdb = computeLoop( b, a, b );

让我们v通过引用传递,但它可能会导致维护问题。因为可能与orv是相同的引用,但是在这方面并不明显。将来有人可能会不自觉地破坏代码。computeArg1computeArg2computeLoop


推荐阅读