python - Pytorch:RuntimeError:mat1 dim 1 必须匹配 mat2 dim 0
问题描述
使用 resnet50 模型。自定义最后一层,它显示运行时错误..我是 PyTorch 的新手,我不断收到错误 mat1 dim1 must match mat1 dim0
这是我的网络代码
from torchvision import models
model = models.resnet50(pretrained=True)
for param in model.parameters():
param.requires_grad = False
class Identity(nn.Module):
def __init__(self):
super(Identity, self).__init__()
def forward(self, x):
return x
model.avgpool = Identity()
model.fc = nn.Linear(2048, 2, bias=True)
for param in model.fc.parameters():
param.requires_grad = True
model = model.to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)
def train(num_epoch, model):
for epoch in range(0, 3):
losses = []
model.train()
loop = tqdm(enumerate(train_loader), total=len(train_loader))
for batch_idx, (data, targets) in loop:
data = data.to(device=device)
targets = targets.to(device=device)
scores = model.forward(data)
loss = criterion(scores, targets)
optimizer.zero_grad()
losses.append(loss)
loss.backward()
optimizer.step()
loop.set_description(f"Epoch {epoch+1}/{num_epoch} process: {int((batch_idx / len(train_loader)) * 100)}")
loop.set_postfix(loss=loss.data.item())
train(1, model)
RuntimeError: mat1 dim 1 必须匹配 mat2 dim 0
解决方案
此错误来自nn.Linear
您更改。
回想一下,nn.Linear
计算一个简单的矩阵点积,因此来自前一层的输入维度必须等于weight
矩阵形状(您将其设置为2048
)。
我的猜测是,由于您删除了该model.avgpool
层,因此您现在有超过 2048 个输入维度,从而导致您得到错误。
顺便说一句,你不需要自己实现“身份”层,pytorch 已经有了nn.Identity
.
推荐阅读
- python-3.x - Python 3 上的 Sqlite,带有 Spatialite 并完全支持空间索引(即 rtree)
- stored-procedures - Firebird Trace SP 的 / 步骤
- php - 如何永久增加 max_input_vars?php.ini 方法会定期被覆盖
- android - 我需要加快从 javaMail.Message 类型获取内容和管理标志的过程
- android - Android App 在 android studio 中不断关闭
- python - 将字符串转换为 false 布尔值
- spring - Spring RestTemplate 似乎不是线程安全的 wrt 标头
- javascript - Ajax 请求中的同步 XMLHttpRequest
- python - Can I do multiple prints on the same line in Python
- vim - 我必须编辑什么文件才能使 vim 自动从某些上下文开始?像 Visual Studio