首页 > 解决方案 > Python中的并行化

问题描述

我有一些我想在 python 中并行做的事情。以下功能:

def _train_users_locally(self):
        for i in self.users:
            self.users[i].train()

获取一组 'user' 类的实例来训练 pytorch 神经网络:

def train(self):
        self.current_model.train()
        optim = self.optim(self.current_model.parameters(), lr=self.lr)
        criterion = self.criterion()
        for epoch in range(self.epochs):
            loss_per_epoch = 0
            counter = 0
            for i, data in enumerate(self.dataloader):
                x, y = data
                fx = self.current_model(x.unsqueeze(1))
                loss = criterion(fx, y)
                optim.zero_grad()
                loss.backward(retain_graph=True)
                optim.step()
                loss_per_epoch += loss.item()
                print('\rEpoch {}\tBatch: {:.3f}, Loss: {:.3f}'.format(epoch+1, i, loss.item()), end="")
                counter += 1
            print('\nEpoch {}\t Average Loss: {:.3f}'.format(epoch+1, loss_per_epoch / counter))

什么都没有输出;分配给对象的模型正在更新。我希望每个用户对象同时进行训练,但我一生都无法弄清楚如何做到这一点,因为我能够找到的所有示例都涉及对列表元素的一些处理。

标签: pythonparallel-processingpytorch

解决方案


想出来了,谢谢高塔姆。

from threading import Thread

def _train_users_locally(self):
        threads = []
        for i in self.users:
            t = Thread(target=self.users[i].train)
            threads.append(t)
        for t in threads:
            t.start()
        for t in threads:
            t.join()

它似乎按预期工作;但我不确定我是否会遗漏一些会意外弹出的可怕东西。


推荐阅读