python - 使用 GPU 使用 PyTorch 进行矩阵分解
问题描述
下面的代码确实可以运行,但它使用 for 循环非常慢。在我的大学,可以使用具有 GPU 资源的服务器。同样,我想了解如何使用批处理来更有效地训练模型。
import torch
import torch.nn as nn
import torch.nn.functional as F
class MatrixFactorization(torch.nn.Module):
def __init__(self, n_items=len(movie_ids), n_factors=300):
super().__init__()
self.vectors = nn.Embedding(n_items, n_factors,sparse=True)
def forward(self, i,j):
feat_i = self.vectors(i)
feat_j = self.vectors(j)
result = (feat_i * feat_j).sum(-1)
return result
model = MatrixFactorization(n_items= len(movie_ids),n_factors=300)
loss_fn = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
epochs = 100
for epoch in range(epochs):
loss = 0
for r,c in zip(r_index, c_index):
i = torch.LongTensor([int(r)])
j = torch.LongTensor([int(c)])
rating = torch.FloatTensor([Xij[i, j]])
# predict
prediction = model(i, j)
loss += loss_fn(prediction, rating)
# Reset the gradients to 0
optimizer.zero_grad()
# backpropagate
loss.backward()
# update weights
optimizer.step()
print(loss)
我尝试了以下更改,但它产生了警告。我不确定为什么我的目标尺寸不匹配,但这似乎是问题的原因。
epochs = 50
for epoch in range(epochs):
loss = 0
# predict
i = torch.LongTensor(r_index)
j = torch.LongTensor(c_index)
ratings = Xij[i, j]
prediction = model(i, j)
loss += loss_fn(prediction, rating)
# Reset the gradients to 0
optimizer.zero_grad()
# backpropagate
loss.backward()
# update weights
optimizer.step()
print(loss)
和警告(不知道我哪里出错了):
/anaconda3/lib/python3.6/site-packages/torch/nn/modules/loss.py:431: UserWarning: Using a target size (torch.Size([1])) that is different to the input size (torch.Size([5931640])). This will likely lead to incorrect results due to broadcasting. Please ensure they have the same size.
return F.mse_loss(input, target, reduction=self.reduction)
解决方案
您的第二个代码段中有错字,
loss += loss_fn(prediction, ratings) # instead of rating
推荐阅读
- opengl - Z 实际上应该是透视除法的什么值?
- html - 在 iOS 和 MacOS 上的 Safari 中,溢出属性工作异常
- git - 来自 gitlab 的新分支不会出现在 git bash 中
- python - 如何创建一个函数,向我显示分类列中具有 0 的行和唯一的数字列中的 ows?
- html - reCAPTCHA 是否可以在没有域名的 html 中工作?
- arrays - 使用 Swift 5.2 解码 JSON 会引发错误“预期解码数组
而是找到了一本字典。” - python - 稳定的 Softmax 函数返回错误的输出
- ocaml - 是否可以在 OCaml 中创建 solib
- c++ - zmq_send() 通过多个连接发送到哪里?
- python - Pygame打开窗口并立即崩溃