machine-learning - 使用 PyTorch 训练神经网络时出现错误“‘Softmax’对象没有属性‘log_softmax’”
问题描述
我正在研究 MNIST 数据集的分类器。当我运行下面的代码时,我在 line 处收到错误“'Softmax' object has no attribute 'log_softmax'” loss = loss_function(output, y)
。我还没有找到解决问题的方法。如果您能就如何解决该问题提出建议,我将不胜感激。谢谢你。
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, TensorDataset
import torchvision
import torchvision.transforms as transforms
import numpy as np
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
batch_size = 512
# Image transformations of Torchvision will convert to the images to tensor and normalise with mean and standard deviation
transformer = transforms.Compose([transforms.ToTensor(), transforms.Normalize(mean=(0.1307,), std=(0.3081,))])
data_train = DataLoader(torchvision.datasets.MNIST('Data/data/mnist', download=True, train=True, transform=transformer),
batch_size=batch_size, drop_last=False, shuffle=True)
data_test = DataLoader(torchvision.datasets.MNIST('Data/data/mnist', download=True, train=False, transform=transformer),
batch_size=batch_size, drop_last=False, shuffle=True)
class neural_nw(nn.Module):
def __init__(self):
super(neural_nw, self).__init__()
self.fc1 = nn.Linear(784, 128, True)
self.fc2 = nn.Linear(128, 128, True)
self.fc3 = nn.Linear(128, 10, True)
def forward(self, x):
output = torch.sigmoid(self.fc1(x))
output = torch.sigmoid(self.fc2(output))
output = nn.Softmax(self.fc3(output))
return output
MLP = neural_nw()
loss_function = nn.CrossEntropyLoss()
optimiser = optim.Adam(MLP.parameters(), lr = 0.01)
Epochs = 50
for epoch in range(Epochs):
for X, y in data_train:
X = X.view(X.shape[0], -1)
optimiser.zero_grad()
output = MLP.forward(X)
loss = loss_function(output, y)
loss.backward()
optimiser.step()
解决方案
nn.Softmax
定义一个模块,nn.Modules
定义为 Python 类并具有属性,例如,一个nn.LSTM
模块将具有一些内部属性,例如self.hidden_size
. 另一方面,F.softmax
定义操作并需要传递所有参数(包括权重和偏差)。隐含地,模块通常会在方法中的某处调用它们的功能forward
对应物。
这解释了为什么F.softmax
而不是nn.Softmax
解决您的问题。
推荐阅读
- c# - 日期时间格式字符串 C#
- ios - 无法在路径“var/mobile/container/Data/Application/XXXX/Documents/default.realm”打开领域
- typescript - 记录类型,其中值与定义的接口相同,但为数组
- python - 熊猫:日期时间从日期开始不正确地选择日期作为月份
- slack - 在 Slack 工作区的管理员批准安装应用程序后,如何执行某些操作?
- tableau-api - Tableau 按具有逗号分隔值的列中的多个值进行过滤
- javascript - 你能在 JavaScript 中将 1.0 写成数字吗(小数点后只有零)?
- pine-script - 在交易视图中的多空信号后如何获得 PineScript 中的最高价或最低价?
- javascript - 如何在 jquery 数据表上重新关注此警报 1 消息
- next.js - Next.js - 如何在 404 页面上隐藏导航和页脚组件?