python - PyTorch gets stuck during training
Problem description
I have this code:
my_model.py:
import torch
import torchvision
import torchvision.transforms as transforms

num_workers = 1
classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])
trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=num_workers)
testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=num_workers)
model_tuils.py:
def train_network(net, number_of_epoch, trainloader, optimizer, criterion):
    for epoch in range(number_of_epoch):  # loop over the dataset multiple times
        running_loss = 0.0
        for i, data in enumerate(trainloader, 0):
            # get the inputs; data is a list of [inputs, labels]
            inputs, labels = data
            # zero the parameter gradients
            optimizer.zero_grad()
            # forward + backward + optimize
            outputs = net(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # print statistics
            running_loss += loss.item()
            if i % 2000 == 1999:  # print every 2000 mini-batches
                print('[%d, %5d] loss: %.3f' %
                      (epoch + 1, i + 1, running_loss / 2000))
                running_loss = 0.0
    print('Finished Training')
When I run my code, it gets stuck in the training loop on the first iteration, at this line:
for i, data in enumerate(trainloader, 0):
    # get the inputs; data is a list of [inputs, labels]
    inputs, labels = data <----------------------------this line
    # zero the parameter gradients
    optimizer.zero_grad()
When I try to find out what is wrong with the program, I end up in this file: /Users/user/.pyenv/versions/3.7.8/lib/python3.7/multiprocessing/queues.py:
def _feed(buffer, notempty, send_bytes, writelock, close, ignore_epipe,
          onerror, queue_sem):
    debug('starting thread to feed data to pipe')
    nacquire = notempty.acquire
    nrelease = notempty.release
    nwait = notempty.wait
    bpopleft = buffer.popleft
    sentinel = _sentinel
    if sys.platform != 'win32':
        wacquire = writelock.acquire
        wrelease = writelock.release
    else:
        wacquire = None
    while 1:
        try:
            nacquire()
            try:
                if not buffer:
                    nwait() <------------------------------ This line
            finally:
                nrelease()
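To illustrate why a single worker still involves more than one thread of execution, here is a minimal stdlib-only sketch of one worker process sending items to its parent through a multiprocessing.Queue. The names worker and collect_from_one_worker are hypothetical and are not PyTorch internals. As far as I understand DataLoader's design, num_workers=1 means one extra process besides the main one, connected by multiprocessing queues, and every process that puts onto such a queue starts a background feeder thread running _feed. A feeder thread parked at nwait() is therefore usually just idle, waiting for something to send, and probably not the cause of the hang by itself.

```python
import multiprocessing as mp
import os


def worker(q, n):
    # Runs in its own process, much like a single DataLoader worker.
    # q.put() only appends to a local buffer; a background feeder thread
    # (the _feed function in multiprocessing/queues.py) pickles each item
    # and writes it to the underlying pipe.
    for i in range(n):
        q.put((os.getpid(), i))


def collect_from_one_worker(n, timeout=5.0):
    """Hypothetical helper: start ONE worker process and read n items from it."""
    q = mp.Queue()
    p = mp.Process(target=worker, args=(q, n))
    p.start()
    try:
        # q.get() blocks until the worker's feeder thread delivers an item;
        # the timeout turns a silent hang into a queue.Empty exception.
        items = [q.get(timeout=timeout) for _ in range(n)]
    finally:
        p.join(timeout)
        if p.is_alive():
            p.terminate()
    return os.getpid(), items


if __name__ == "__main__":
    parent_pid, items = collect_from_one_worker(3)
    print("parent pid:", parent_pid)
    print("worker pids:", {pid for pid, _ in items})
```

The worker's pid differs from the parent's, so even "one worker" means two processes plus at least one feeder thread. The same idea of bounding the wait applies to DataLoader, which accepts a timeout argument (in seconds) for collecting a batch from workers, so a stuck worker raises an error instead of blocking forever.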
What am I doing wrong? My num_workers is 1, so there shouldn't be multiple threads.
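One way to narrow this down is sketched below under stated assumptions: it keeps the structure of train_network from model_tuils.py but feeds it a synthetic TensorDataset (a hypothetical stand-in for CIFAR-10, so nothing is downloaded) and sets num_workers=0, which loads batches in the main process with no worker processes or multiprocessing queues. If this runs while num_workers=1 hangs, the worker process, not the training loop, is the likely culprit. The run_demo helper and the tiny linear model are illustrative only. The `if __name__ == "__main__":` guard is required when the multiprocessing start method is spawn (the default on Windows, and on macOS from Python 3.8), and it is harmless otherwise.

```python
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset


def train_network(net, number_of_epoch, trainloader, optimizer, criterion):
    # Condensed version of the loop in the question; returns the last batch loss.
    last_loss = float("nan")
    for epoch in range(number_of_epoch):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(net(inputs), labels)
            loss.backward()
            optimizer.step()
            last_loss = loss.item()
    return last_loss


def run_demo(num_workers=0):
    """Hypothetical helper: train on random data with the given num_workers."""
    torch.manual_seed(0)
    # Synthetic stand-in for CIFAR-10: 64 random 3x32x32 "images", 10 classes.
    images = torch.randn(64, 3, 32, 32)
    labels = torch.randint(0, 10, (64,))
    loader = DataLoader(TensorDataset(images, labels), batch_size=4,
                        shuffle=True, num_workers=num_workers)
    net = nn.Sequential(nn.Flatten(), nn.Linear(3 * 32 * 32, 10))
    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)
    return train_network(net, 1, loader, optimizer, nn.CrossEntropyLoss())


if __name__ == "__main__":
    # num_workers=0: batches are produced in this process, so the
    # multiprocessing machinery from queues.py is not involved at all.
    print("final batch loss:", run_demo(num_workers=0))
```

If num_workers=0 works, retrying with num_workers=1 inside the same `__main__` guard helps separate a start-method or import problem from a problem in the dataset itself.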