deep-learning - “网络”对象没有属性“参数”
问题描述
我对机器学习相当陌生。我从 youtube 教程中学会了编写此代码,但我不断收到此错误
Traceback (most recent call last):
File "<input>", line 1, in <module>
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_bundle/pydev_umd.py", line 197, in runfile
pydev_imports.execfile(filename, global_vars, local_vars) # execute the script
File "/Applications/PyCharm.app/Contents/plugins/python/helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
exec(compile(contents+"\n", file, 'exec'), glob, loc)
File "/Users/aniket/Desktop/DeepLearning/PythonLearningPyCharm/CatVsDogs.py", line 109, in <module>
optimizer = optim.Adam(net.parameters(), lr=0.001) # tweaks the weights from what I understand
AttributeError: 'Net' object has no attribute 'parameters'
这是网络类
class Net():
def __init__(self):
super().__init__()
self.conv1 = nn.Conv2d(1,32,5)
self.conv2 = nn.Conv2d(32,64,5)
self.conv3 = nn.Conv2d(64,128,5)
self.to_linear = None
x = torch.randn(50,50).view(-1,1,50,50)
self.Conv2d_Linear_Link(x)
self.fc1 = nn.Linear(self.to_linear, 512)
self.fc2 = nn.Linear(512, 2)
def Conv2d_Linear_Link(self , x):
x = F.max_pool2d(F.relu(self.conv1(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv2(x)),(2,2))
x = F.max_pool2d(F.relu(self.conv3(x)),(2,2))
if self.to_linear is None :
self.to_linear = x[0].shape[0]*x[0].shape[1]*x[0].shape[2]
return x
def forward(self, x):
x = self.Conv2d_Linear_Link(x)
x = F.relu(self.fc1(x))
x = self.fc2(x)
return F.softmax(x, dim=1)
这是功能火车
def train():
for epoch in range(epochs):
for i in tqdm(range(0,len(X_train), batch)):
batch_x = train_X[i:i + batch].view(-1, 1, 50, 50)
batch_y = train_y[i:i + batch]
net.zero_grad() # i don't understand why we do this but we do we don't want the probabilites adding up
output = net(batch_x)
loss = loss_function(output, batch_y)
loss.backward()
optimizer.step()
print(loss)
以及优化器和损失函数和数据
optimizer = optim.Adam(net.parameters(), lr=0.001) # tweaks the weights from what I understand
loss_function = nn.MSELoss() # gives the loss
解决方案
你不是子类化nn.Module
。它应该如下所示:
class Net(nn.Module):
def __init__(self):
super().__init__()
这允许您的网络继承nn.Module
类的所有属性,例如parameters
属性。
推荐阅读
- javascript - Fullcalendar 仅删除特定资源组或资源 ID
- android - 谷歌文档查看器不能在本地主机上工作吗?
- spring-boot - 自定义对象映射器在springboot中不生效
- javascript - ReactJS:已被 CORS 策略阻止:对预检请求的响应未通过访问控制检查
- c# - 无法将字符串“名称”转换为字典键类型 - 创建一个 TypeConverter 以从字符串转换为键类型对象
- kubernetes - Spinnaker 中存储设置的自定义配置文件
- python - How does this while-loop compute? (x = function(x))
- keras - 如何在谷歌 colab 中找到 keras.json 文件?
- node.js - 将 Strapi API 部署到 Plesk
- java - X509TrustManager 的 checkServerTrusted 方法中验证了哪些所有参数?