python - Python中的并行化
问题描述
我有一些我想在 python 中并行做的事情。以下功能:
def _train_users_locally(self):
for i in self.users:
self.users[i].train()
获取一组 'user' 类的实例来训练 pytorch 神经网络:
def train(self):
self.current_model.train()
optim = self.optim(self.current_model.parameters(), lr=self.lr)
criterion = self.criterion()
for epoch in range(self.epochs):
loss_per_epoch = 0
counter = 0
for i, data in enumerate(self.dataloader):
x, y = data
fx = self.current_model(x.unsqueeze(1))
loss = criterion(fx, y)
optim.zero_grad()
loss.backward(retain_graph=True)
optim.step()
loss_per_epoch += loss.item()
print('\rEpoch {}\tBatch: {:.3f}, Loss: {:.3f}'.format(epoch+1, i, loss.item()), end="")
counter += 1
print('\nEpoch {}\t Average Loss: {:.3f}'.format(epoch+1, loss_per_epoch / counter))
什么都没有输出;分配给对象的模型正在更新。我希望每个用户对象同时进行训练,但我一生都无法弄清楚如何做到这一点,因为我能够找到的所有示例都涉及对列表元素的一些处理。
解决方案
想出来了,谢谢高塔姆。
from threading import Thread
def _train_users_locally(self):
threads = []
for i in self.users:
t = Thread(target=self.users[i].train)
threads.append(t)
for t in threads:
t.start()
for t in threads:
t.join()
它似乎按预期工作;但我不确定我是否会遗漏一些会意外弹出的可怕东西。
推荐阅读
- spring-cloud-dataflow - inheritLogging=true 如何以及如何找到日志?
- scala - 在 Scala 中为 null 时,Option(null) 的 getOrElse 不返回 None 类型或默认值
- android - Galaxy Tab A6 读取的蓝牙 rssi 值不一致
- python - 机器人框架 - 使用参数测试关键字失败
- shell - 剪切文件中的最后一列并在 unix 中创建新文件
- python - 具有大 z 轴的 3d 绘图
- c# - 脚本组件转换行值到变量
- c# - Selenium 找不到元素“a”
- kubectl - 如何设置对 Kubernetes 部署的监视?
- javascript - Threejs围绕Y旋转相机