python - Pytorch,不能在 GPU 上运行 CNN。输入类型(torch.FloatTensor)和权重类型(torch.cuda.FloatTensor)应该相同
问题描述
我正在构建一个简单的图像识别卷积神经网络并尝试在我的 GPU 上运行它,但显然我还没有做任何重要的事情。
我检查了如果 GPU 在开始时和训练中可用,请将批次设置为设备(cuda:0)。
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
# Checks if GPU is available otherwise uses CPU
if torch.cuda.is_available():
device = torch.device("cuda:0")
print("Running on the GPU!")
else:
device = torch.device("cpu")
print("Running on the CPU!")
REBUILD_DATA = False
# Data clean up and format
class DogsVsCats():
IMG_SIZE = 50
CATS = "PetImages/Cat"
DOGS = "PetImages/Dog"
LABELS = {CATS: 0, DOGS: 1}
training_data = []
catcount = 0
dogcount = 0
def make_training_data(self):
for label in self.LABELS:
print(label)
for f in tqdm(os.listdir(label)):
try:
path = os.path.join(label, f)
img = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (self.IMG_SIZE, self.IMG_SIZE))
self.training_data.append([np.array(img), np.eye(2)[self.LABELS[label]] ])
if label == self.CATS:
self.catcount += 1
elif label == self.DOGS:
self.dogcount += 1
except Exception as e:
pass
np.random.shuffle(self.training_data)
np.save("training_data.npy", self.training_data)
print("Cats: ", self.catcount)
print("Dogs: ", self.dogcount)
if REBUILD_DATA:
dogsvcats = DogsVsCats()
dogsvcats.make_training_data()
training_data = np.load("training_data.npy", allow_pickle=True)
# print(len("training_data.npy"))
class Net(nn.Module):
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1, 32, 5)
self.conv2 = nn.Conv2d(32, 64, 5)
self.conv3 = nn.Conv2d(64, 128, 5)
x = torch.randn(50,50).view(-1,1,50,50)
self._to_linear = None
self.convs(x)
self.fc1 = nn.Linear(self._to_linear, 512)
self.fc2 = nn.Linear(512, 2)
def convs(self, x):
x = F.max_pool2d(F.relu(self.conv1(x)), (2,2))
x = F.max_pool2d(F.relu(self.conv2(x)), (2,2))
x = F.max_pool2d(F.relu(self.conv3(x)), (2,2))
print(x[0].shape)
if self._to_linear is None:
self._to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
return x
def forward(self, x):
x = self.convs(x)
x = x.view(-1, self._to_linear)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim = 1)
net = Net().to(device)
optimizer = optim.Adam(net.parameters(), lr = 0.001)
loss_function = nn.MSELoss()
X = torch.Tensor([i[0] for i in training_data]).view(-1, 50, 50)
X = X/255.0
y = torch.Tensor([i[1] for i in training_data])
VAL_PCT = 0.1
val_size = int(len(X)*VAL_PCT)
print(val_size)
train_X = X[:-val_size]
train_y = y[:-val_size]
test_X = X[-val_size:]
test_y = y[-val_size:]
BATCH_SIZE = 100
EPOCHS = 1
def train(net):
for epoch in range(EPOCHS):
for i in tqdm(range(0, len(train_X), BATCH_SIZE)):
# print(i, i+BATCH_SIZE)
batch_X = train_X[i:i+BATCH_SIZE].view(-1,1,50,50).to(device)
batch_y = train_y[i:i+BATCH_SIZE].to(device)
net.zero_grad()
outputs = net(batch_X)
loss = loss_function(outputs, batch_y)
loss.backward()
optimizer.step()
print(loss)
correct = 0
total = 0
with torch.no_grad():
for i in tqdm(range(len(test_X))):
real_class = torch.argmax(test_y[i])
net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
predicted_class = torch.argmax(net_out)
if predicted_class == real_class:
correct += 1
total += 1
print("Accuracy: ", round(correct/total,3))
train(net)
对不起,如果问题太简单了。先感谢您!
解决方案
您应该发布错误的行号,但我认为它来自这个 snipit:
with torch.no_grad():
for i in tqdm(range(len(test_X))):
real_class = torch.argmax(test_y[i])
net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
predicted_class = torch.argmax(net_out)
if predicted_class == real_class:
correct += 1
total += 1
您必须将输入放入您net
必须放入设备中,因此可能会更改行
net_out = net(test_X[i].view(-1, 1, 50, 50))[0]
至
net_out = net(test_X[i].view(-1, 1, 50, 50).to(device)[0]
推荐阅读
- r - 从 R 中的唯一大列表对象创建新文本文件
- hibernate - Hibernate StaleObjectStateException 问题调用合并
- android - android:navigationBarDividerColor 需要 API 级别 28
- javascript - 关闭 Glide.js 轮播上的拖动
- javascript - WebSockets 在本地主机上工作,但在远程 Ubuntu 主机上不工作
- scala - Akka Streams - 了解物化何时以及如何工作
- jwplayer - 如何从 JWPlayer 获取#EXT-X-PROGRAM-DATE-TIME
- zooming - 升级到 v4.6.5 后不显示 Shift-zoom 边框
- java - 如何在 Spring 数据中读取 JSON 字符串并写入数据库 MySql
- elixir - Elixir/Phoenix/Guardian - 在 conn 中分配 current_user 不适用于测试