python - TypeError: forward() 接受 2 个位置参数,但 3 个在 pytorch 中给出
问题描述
我的训练循环中有以下错误,我真的不明白问题是什么。我目前正在编写这段代码,所以东西不是最终的,但我无法弄清楚这个问题是什么。
我试过用谷歌搜索错误并阅读了一些答案,但似乎仍然无法理解问题的症结所在。
数据集和数据加载器(X 和 Y 已经给了我,它们都是 [2000, 40, 1] 张量)
class TrainingDataset(data.Dataset):
def __init__(self, X, y):
self.X = X
self.y = y
def __len__(self):
return Nf
# returns corresponding input/output pairs
def __getitem__(self, t):
X = self.X[t]
y = self.y[t]
#print(X.shape, y.shape)
return X, y
# prints torch.Size([2000, 40, 1]) torch.Size([2000, 40, 1])
print(x.size(), y.size())
dataset = TrainingDataset(x,y)
batchSize = 20
dataIter = data.DataLoader(dataset, batchSize)
模型:
class Encoder(nn.Module):
def __init__(self, num_inputs = 40, num_outputs = 40):
super(Encoder, self).__init__()
self.num_inputs = num_inputs
self.num_hidden = num_hidden
self.num_outputs = num_outputs
self.layers = nn.Sequential(
nn.Linear(num_inputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs),
nn.ReLU(),
nn.Linear(num_outputs, num_outputs)
)
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
训练循环:
for epoch in range(epochs):
for batch in dataIter:
optimiser.zero_grad()
l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
l.backward()
optimiser.step()
错误:
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
<ipython-input-15-aa1c60616d82> in <module>()
6 for batch in dataIter:
7 optimiser.zero_grad()
----> 8 l = loss(encoder(x_c=batch[0], y_c=batch[1]), batch[1])
9 l.backward()
10 optimiser.step()
2 frames
/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
887 result = self._slow_forward(*input, **kwargs)
888 else:
--> 889 result = self.forward(*input, **kwargs)
890 for hook in itertools.chain(
891 _global_forward_hooks.values(),
TypeError: forward() takes 2 positional arguments but 3 were given
谁能指出我正确的方向?我刚刚开始学习和做 pytorch,所以我还不擅长这些。
解决方案
def forward(self, x_c, y_c):
return self.layers(x_c, y_c)
你的错误就在这里,这个函数应该只有 1 个参数除了self
.
推荐阅读
- c# - 输入不是有效的 Base64 字符串,因为它包含非 base 64 字符
- java - 我制作了一个 Java 程序但有重复问题
- node.js - npm 更新目录路径混淆
- javascript - Razor - 为 radioButtonFor 赋予“属性”
- c# - 使用 Hotmail SMTP C# 发送电子邮件在本地工作,但不在实时服务器上
- javascript - javascript 只工作一次,在第一次
- c++ - 如何让线程等待条件执行操作而不使用过多的 CPU 时间?
- javascript - WebGL 注释未呈现
- python - YOLO v3 Linux VM OpenCV imtest
- node.js - Sequelize select中的乘法字段