python - Pytorch 不使用 cuda 设备
问题描述
我有以下代码:
from __future__ import print_function
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import scipy.io
folder = 'small/'
mat = scipy.io.loadmat(folder+'INISTATE.mat');
ini_state = np.float32(mat['ini_state']);
ini_state = torch.from_numpy(ini_state);
ini_state = ini_state.cuda();
mat = scipy.io.loadmat(folder+'TARGET.mat');
target = np.float32(mat['target']);
target = torch.from_numpy(target);
target = target.cuda();
class MLPNet(nn.Module):
def __init__(self):
super(MLPNet, self).__init__()
self.fc1 = nn.Linear(3, 64)
self.fc2 = nn.Linear(64, 128)
self.fc3 = nn.Linear(128, 128)
self.fc4 = nn.Linear(128, 41)
def forward(self, x):
x = F.relu(self.fc1(x))
x = F.relu(self.fc2(x))
x = F.relu(self.fc3(x))
x = self.fc4(x)
return x
def name(self):
return "MLP"
model = MLPNet();
model = model.cuda();
criterion = nn.MSELoss();
criterion = criterion.cuda();
learning_rate = 0.001;
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
batch_size = 20
iter_size = int(target.size(0)/batch_size)
print(iter_size)
for epoch in range(50):
for i in range(iter_size):
start = i*batch_size;
end = (i+1)*batch_size-1;
samples = ini_state[start:end,:];
labels = target[start:end,:];
optimizer.zero_grad() # zero the gradient buffer
outputs = model(samples)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
if (i+1) % 500 == 0:
print("Epoch %s, batch %s, loss %s" % (epoch, i, loss))
if (epoch+1) % 7 == 0:
for g in optimizer.param_groups:
g['lr'] = g['lr']*0.1;
但是当我训练简单的 MLP 时,CPU 使用率大约是 100%,而 gpu 只有大约 10%。阻止使用 GPU 的问题是什么?
解决方案
实际上,您的模型确实在 GPU 而不是 CPU 上运行。GPU使用率低的原因是您的模型和批量大小都很小,这需要较低的计算成本。您可以尝试将批量大小增加到 1000 左右,GPU 使用率应该更高。事实上,PyTorch 阻止了混合 CPU 和 GPU 数据的操作,例如,您不能将 GPU 张量和 CPU 张量相乘。所以通常你的网络的一部分不太可能在 CPU 上运行而另一部分在 GPU 上运行,除非你故意设计它。
顺便说一句,神经网络需要数据混洗。由于您使用的是小批量训练,因此在每次迭代中,您都希望小批量接近整个数据集。如果没有数据混洗,小批量中的样本很可能是高度相关的,这导致参数更新的估计有偏差。PyTorch 提供的数据加载器可以帮助您进行数据混洗。
推荐阅读
- python - 如何通过在单个单元格中删除 NaN 来调整数据框的大小?
- php - 通过 hasManyThrough Laravel Eloquent 获取多个表
- python - 如何使用 python API 将文件嵌入到数据流作业中
- python - Django 管理操作:仅使用一种方法为所有选择生成操作
- android - 启动时重新启动服务 8.1 问题
- android - 房间数据库中的外键
- c# - 如何使用 DisplayFormat 属性格式化表中的值?
- javascript - 在 Firebase 中存储注册数据
- c++ - 如何使用单个返回语句返回对变量的引用
- python - 硒python返回所有输出