首页 > 解决方案 > flatten_parameters() 有什么作用?

问题描述

我在 RNN 的前向函数中看到了许多使用 flatten_parameters 的 Pytorch 示例

self.rnn.flatten_parameters()

我看到了这个RNNBase,上面写着

重置参数数据指针,以便他们可以使用更快的代码路径

这意味着什么?

标签: pytorch

解决方案


它可能不是您问题的完整答案。但是,如果你看一下flatten_parameters的源代码,你会注意到它调用_cudnn_rnn_flatten_weight

...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...

是完成这项工作的功能。您会发现它实际上所做的是将模型的权重复制到vector<Tensor>(检查params_arr声明)中:

  // Slice off views into weight_buf
  std::vector<Tensor> params_arr;
  size_t params_stride0;
  std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);

  MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
                    params{params_arr, params_stride0};

权重复制进来

  // Copy weights
  _copyParams(weight, params);

另请注意,Reset它们通过在weightsparams.set__orig_param.set_(new_param.view_as(orig_param));

  // Update the storage
  for (size_t i = 0; i < weight.size(0); i++) {
    for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
         orig_param_it != weight[i].end() && new_param_it != params[i].end();
         orig_param_it++, new_param_it++) {
      auto orig_param = *orig_param_it, new_param = *new_param_it;
      orig_param.set_(new_param.view_as(orig_param));
    }
  }

并根据n2798(C++0x 草案)

©ISO/IECN3092

23.3.6 类模板向量

向量是支持随机访问迭代器的序列容器。此外,它还支持(摊销)恒定时间的最后插入和擦除操作;在中间插入和擦除需要线性时间。存储管理是自动处理的,但可以给出提示以提高效率。向量的元素是连续存储的,这意味着如果v是一个向量<T, Allocator>,其中T的类型不是 bool,那么它服从identity&v[n] == &v[0] + nfor all 0 <= n < v.size()


在某些情况下

用户警告:RNN 模块权重不是单个连续内存块的一部分。这意味着它们需要在每次调用时进行压缩,可能会大大增加内存使用量。再次调用压缩权重flatten_parameters()

他们在代码警告中明确建议人们拥有一块连续的内存。


推荐阅读