python - 如何在pytorch中包含一个属性让用户决定是否使用GPU?
问题描述
我是 Python 和 pytorch 的初学者,我很抱歉没有让自己清楚或没有使用正确的术语。
我正在尝试编写一个线性回归类,它允许用户决定是否要在 GPU 上运行模型。
以下是我的代码非常愚蠢。
class LinearRegression():
def __init__(self, dtype = torch.float64):
self.dtype = dtype
self.bias = None
self.weight = None
def fit(self, x, y, std = None, device = "cuda"):
if x.dtype is not self.dtype:
x = x.type(dtype = self.dtype)
if y.dtype is not self.dtype:
y = y.type(dtype = self.dtype)
#let the user decide whether they want it to be standardized
if std == True:
mean = torch.mean(input = x, dim=0).to(device)
sd = torch.std(input = x, dim=0).to(device)
x = x.to(device)
x = (x-mean)/sd
u = torch.ones(size = (x.size()[0], 1), dtype = self.dtype).to(device)
x_design = torch.cat([u, x], dim = 1).to(device)
y = y.to(device)
parameter = torch.inverse(
torch.transpose(x_design, dim0=0, dim1=1) @ x_design).to(device) @ \
torch.transpose(x_design, dim0=0, dim1=1).to(device) @ y
self.bias = parameter[0, 0]
self.weight = parameter[1:, 0]
return self
基本上我只是将 .to(device) 添加到每个张量,以便它可以在用户希望使用的设备上运行。但是,我确信有更好的方法可以做到这一点,也许包括在__int__
?
我不确定如何更有效地编写它,以便在添加新功能时不必包含 .to(device) 。
解决方案
推荐阅读
- ubuntu-12.04 - Postfix 收不到邮件。Helo 命令 未经授权
- java - 带节点的堆栈的推送方法
- python - “系列”对象没有属性“getformat”
- r - 自相关函数 $r_k$ 在滞后 $k$ 处的线图
- elasticsearch - 如何搜索弹性搜索并按值计算行/结果
- scala - Py4JJavaError:调用 o10695.write 时出错。: java.lang.StackOverflowError
- python - 用于视频处理的 Argparse 语法
- amazon-web-services - 防止用户自行使用联合身份提供商 (FIP) 注册,但如果管理员添加,则允许使用 FIP 登录
- oracle19c - oracle.jdbc.OracleDatabaseException: ORA-01722: 无效号码
- python - Pycharm在放入并从队列中获取后丢失对象指示符