python - 有没有办法替换 Pytorch 中用于 DDP(DistributedDataParallel) 的“allreduce_hook”?
问题描述
我知道 Pytorch DDP 使用 'allreduce_hook' 作为默认通信挂钩。有没有办法用“quantization_pertensor_hook”或“powerSGD_hook”替换这个默认挂钩。有一个官方的Pytorch 文档介绍了 DDP 通信钩子,但我仍然对如何在实践中做到这一点感到困惑。
这就是我启动流程组并创建 DDP 模型的方式。
import torch.distributed as dist
import torch.nn as nn
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])
有没有办法根据这段代码声明我想要的钩子?
解决方案
这可以完成这项工作
dist.init_process_group(backend='nccl', init_method='env://', world_size=args.world_size, rank=rank)
model = nn.parallel.DistributedDataParallel(model, device_ids=[0])
state = powerSGD.PowerSGDState(process_group=None, matrix_approximation_rank=1, start_powerSGD_iter=10, min_compression_rate=0.5)
model.register_comm_hook(state, powerSGD.powerSGD_hook)
...
推荐阅读
- android - 如何更改 xml 中用户输入的颜色?
- xml - 如何在 xml 文件中添加 n 个节点?
- javascript - 将表数据输入值克隆到整行 onclick 按钮
- c# - ASP.NET Core 和 MediatR:从处理程序发送请求?
- python - 如何在 python 中使用 .format 打印出数据形状
- java - 如何在java中使用OpenHtmlToPdf库下载pdf
- c# - WPF 转换器:返回在另一个文件中创建的 DrawingBrush
- c# - 使用正则表达式替换/删除文件名的特定部分
- c# - 猜数字游戏窗体
- postgresql - Postgres,查询性能取决于前端工具吗?