python - Pytorch 上的 1D CNN:mat1 和 mat2 形状不能相乘(10x3 和 10x2)
问题描述
我有一个包含 500 个大小的样本和 2 种标签类型的时间序列,并且想要构建一个带有 pytorch 的 1D CNN:
class Simple1DCNN(torch.nn.Module):
def __init__(self):
super(Simple1DCNN, self).__init__()
self.layer1 = torch.nn.Conv1d(in_channels=50,
out_channels=20,
kernel_size=5,
stride=2)
self.act1 = torch.nn.ReLU()
self.layer2 = torch.nn.Conv1d(in_channels=20,
out_channels=10,
kernel_size=1)
self.fc1 = nn.Linear(10* 1 * 1, 2)
def forward(self, x):
x = x.view(1, 50,-1)
x = self.layer1(x)
x = self.act1(x)
x = self.layer2(x)
x = self.fc1(x)
return x
model = Simple1DCNN()
model(torch.tensor(np.random.uniform(-10, 10, 500)).float())
但收到此错误消息:
Traceback (most recent call last):
File "so_pytorch.py", line 28, in <module>
model(torch.tensor(np.random.uniform(-10, 10, 500)).float())
File "/Users/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "so_pytorch.py", line 23, in forward
x = self.fc1(x)
File "/Users/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/Users/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 93, in forward
return F.linear(input, self.weight, self.bias)
File "/Users/lib/python3.8/site-packages/torch/nn/functional.py", line 1692, in linear
output = input.matmul(weight.t())
RuntimeError: mat1 and mat2 shapes cannot be multiplied (10x3 and 10x2)
我究竟做错了什么?
解决方案
该行的输出x = self.layer2(x)
(也是下一行的输入x = self.fc1(x)
)的形状是torch.Size([1, 10, 3])
。
现在从 的定义来看self.fc1
,它期望输入的最后一个维度是10 * 1 * 1
,10
而您的输入3
因此有错误。
我不知道你想做什么,但假设你想做的是;
- 将整个
500
尺寸序列标记为两个标签之一,您可以这样做。
# replace self.fc1 = nn.Linear(10* 1 * 1, 2) with
self.fc1 = nn.Linear(10 * 3, 2)
# replace x = self.fc1(x) with
x = x.view(1, -1)
x = self.fc1(x)
- 将每个时间步标记
10
为两个标签之一,然后执行此操作。
# replace self.fc1 = nn.Linear(10* 1 * 1, 2) with
self.fc1 = nn.Linear(2, 2)
1的输出形状为(batch size, 2),2的输出形状为(batch size, 10, 2)。
推荐阅读
- python - 使用 Lambda 在 AWS S3 中为 DynamoDB 读取 .csv - TypeError:需要一个类似字节的对象,而不是“str”
- linux - 防止用户通过 SSH 登录以与服务器不同的用户身份登录
- wso2 - WSO2 AM - 发布 API - 由:javax.naming.ConfigurationException 引起
- javascript - 如何使用随机哈希删除事件?
- sapui5 - OData V4 布尔值意外转换为 ODatamodel V4 绑定的“否”/“是”字符串
- bash - 尝试列出主机上的所有用户时出错
- mongodb - MongoDB 领域 - 过滤查询
- python - 在 FastAPI 中使用 DB 依赖项,而无需通过函数树传递它
- javascript - 无法在 javascript 中实现单例模式
- python - 如何让我的代码分析满足特定条件的表格部分?