pytorch - flatten_parameters() 有什么作用?
问题描述
我在 RNN 的前向函数中看到了许多使用 flatten_parameters 的 Pytorch 示例
self.rnn.flatten_parameters()
我看到了这个RNNBase,上面写着
重置参数数据指针,以便他们可以使用更快的代码路径
这意味着什么?
解决方案
它可能不是您问题的完整答案。但是,如果你看一下flatten_parameters
的源代码,你会注意到它调用_cudnn_rnn_flatten_weight
了
...
NoGradGuard no_grad;
torch::_cudnn_rnn_flatten_weight(...)
...
是完成这项工作的功能。您会发现它实际上所做的是将模型的权重复制到vector<Tensor>
(检查params_arr
声明)中:
// Slice off views into weight_buf
std::vector<Tensor> params_arr;
size_t params_stride0;
std::tie(params_arr, params_stride0) = get_parameters(handle, rnn, rnn_desc, x_desc, w_desc, weight_buf);
MatrixRef<Tensor> weight{weight_arr, static_cast<size_t>(weight_stride0)},
params{params_arr, params_stride0};
权重复制进来
// Copy weights
_copyParams(weight, params);
另请注意,Reset
它们通过在weights
params
.set_
_
orig_param.set_(new_param.view_as(orig_param));
// Update the storage
for (size_t i = 0; i < weight.size(0); i++) {
for (auto orig_param_it = weight[i].begin(), new_param_it = params[i].begin();
orig_param_it != weight[i].end() && new_param_it != params[i].end();
orig_param_it++, new_param_it++) {
auto orig_param = *orig_param_it, new_param = *new_param_it;
orig_param.set_(new_param.view_as(orig_param));
}
}
©ISO/IECN3092
23.3.6 类模板向量
向量是支持随机访问迭代器的序列容器。此外,它还支持(摊销)恒定时间的最后插入和擦除操作;在中间插入和擦除需要线性时间。存储管理是自动处理的,但可以给出提示以提高效率。向量的元素是连续存储的,这意味着如果
v
是一个向量<T, Allocator>
,其中T
的类型不是 bool,那么它服从identity&v[n] == &v[0] + n
for all0 <= n < v.size()
。
在某些情况下
用户警告:RNN 模块权重不是单个连续内存块的一部分。这意味着它们需要在每次调用时进行压缩,可能会大大增加内存使用量。再次调用压缩权重
flatten_parameters()
。
他们在代码警告中明确建议人们拥有一块连续的内存。
推荐阅读
- sql - 我需要创建一个矩阵来显示一年中每个月的订单数量
- sql - 存储过程不在列中返回值
- node.js - Node.js 高阶函数
- amazon-web-services - 我的负载均衡器无法与我的 Fargate 实例通信
- plugins - 为什么我的插件没有给项目命名?
- excel - 减小 ActiveSheet 粘贴大小。大数据显示内存不足
- java - 修改 yaml 文件中的模式的脚本(或程序)
- java - 如何使用 Android LayoutInflater Factory 注入和覆盖布局属性?
- android - 试图在 android studio (Kotlin) 中用另一个 Fragment 替换一个 Fragment
- php - 如何从 Bundle 中注册新的 Twig 命名空间?