首页 > 解决方案 > 如何在pytorch中包含一个属性让用户决定是否使用GPU?

问题描述

我是 Python 和 pytorch 的初学者,我很抱歉没有让自己清楚或没有使用正确的术语。

我正在尝试编写一个线性回归类,它允许用户决定是否要在 GPU 上运行模型。

以下是我的代码非常愚蠢。

class LinearRegression():
    def __init__(self, dtype = torch.float64):  
        self.dtype = dtype
        self.bias = None
        self.weight = None
        
    def fit(self, x, y, std = None, device = "cuda"): 
        if x.dtype is not self.dtype:
            x = x.type(dtype = self.dtype)
        if y.dtype is not self.dtype:
            y = y.type(dtype = self.dtype)

        #let the user decide whether they want it to be standardized 
        if std == True:
          mean = torch.mean(input = x, dim=0).to(device)
          sd = torch.std(input = x, dim=0).to(device)
          x = x.to(device)
          x = (x-mean)/sd
        u = torch.ones(size = (x.size()[0], 1), dtype = self.dtype).to(device)
        x_design = torch.cat([u, x], dim = 1).to(device)
        y = y.to(device)
        parameter = torch.inverse(
            torch.transpose(x_design, dim0=0, dim1=1) @ x_design).to(device) @ \
                    torch.transpose(x_design, dim0=0, dim1=1).to(device) @ y

        self.bias = parameter[0, 0]
        self.weight = parameter[1:, 0]
        return self

基本上我只是将 .to(device) 添加到每个张量,以便它可以在用户希望使用的设备上运行。但是,我确信有更好的方法可以做到这一点,也许包括在__int__?

我不确定如何更有效地编写它,以便在添加新功能时不必包含 .to(device) 。

标签: pythonclasspytorchgpu

解决方案


推荐阅读