Torch network not loading properly

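Problem description

I am trying to build a network for 3x64x64 images in PyTorch. It looks like I trained the network successfully and saved it. The network looks like this: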

class LC_small(nn.Module):
    # conv, dense, w_rsz and h_rsz are defined elsewhere in the script (not shown)
    def __init__(self, c_in, c_out=256):
        super(LC_small, self).__init__()
        self.conv1 = conv(c_in, 64, k=3, stride=1, pad=1)
        self.conv2 = conv(64, 128, k=3, stride=2, pad=1)
        self.conv3 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv4 = conv(128, 128, k=3, stride=2, pad=1)
        self.conv5 = conv(128, 128, k=3, stride=1, pad=1)
        self.conv6 = conv(128, 256, k=3, stride=2, pad=1)
        self.conv7 = conv(256, 256, k=3, stride=1, pad=1)  # int(h/8 x w/8 x 256)
        self.flat = dense(int(w_rsz/8)*int(h_rsz/8)*256, 256)
        self.dense1 = dense(256, 128, False)
        self.dense2 = dense(128, 3, False)

    def forward(self, input):
        out = self.conv1(input)
        out = self.conv2(out)
        out = self.conv3(out)
        out = self.conv4(out)
        out = self.conv5(out)
        out = self.conv6(out)
        out = self.conv7(out)
        out = out.view(out.size(0), -1)  # flatten to (batch, features)
        out = self.flat(out)
        out = self.dense1(out)
        out = self.dense2(out)
        # print(out.shape)
        normal = torch.nn.functional.normalize(out, 2, 1)  # L2-normalize along dim 1

        return normal

I save my model during training:

for epoch in range(10):
    # continue    # assume training has already been done
    total_loss = 0
    # dump parameter names and sizes to a text file
    route_param = open(route_diffuse+'/netparam.txt', 'w')
    for param in lcnet.state_dict():
        route_param.write(str(param)+'\t'+str(lcnet.state_dict()[param].size())+'\n')
    for i, data in enumerate(load_LC, 0):
        input, gtval = data[0].to(dev), data[1].to(dev)
        opt.zero_grad()

        output = lcnet(input)
        loss = crit(output, gtval)
        loss.backward()
        opt.step()
        total_loss += loss.item()
        if i % 10 == 9:
            print(epoch, i, total_loss/10)
            torch.save(lcnet, route_save)  # pickles the whole module, not just its weights
            total_loss = 0

However, when I try to load the network I created, I see the following error message:

Traceback (most recent call last):
  File "E:/DLPrj/venv/torch_practice.py", line 324, in <module>
    ipl,npl = getseqi_np(sq_t,lcnet)   #  data : 8 x 6 x w x h
  File "E:/DLPrj/venv/torch_practice.py", line 133, in getseqi_np
    l1 = net_lc(torch.from_numpy(i1r))
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:/DLPrj/venv/torch_practice.py", line 216, in forward
    out = self.conv1(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\container.py", line 92, in forward
    input = module(input)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\module.py", line 541, in __call__
    result = self.forward(*input, **kwargs)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 345, in forward
    return self.conv2d_forward(input, self.weight)
  File "E:\DLPrj\venv\lib\site-packages\torch\nn\modules\conv.py", line 342, in conv2d_forward
    self.padding, self.dilation, self.groups)
RuntimeError: Expected 4-dimensional input for 4-dimensional weight 64 3 3 3, but got 3-dimensional input of size [64, 64, 3] instead

After this error PyCharm freezes, and I cannot rerun the code until I restart PyCharm.

I also get some warning messages when I train the network:

E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LC_small. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Sequential. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Conv2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type BatchNorm2d. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type LeakyReLU. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "
E:\DLPrj\venv\lib\site-packages\torch\serialization.py:292: UserWarning: Couldn't retrieve source code for container of type Linear. It won't be checked for correctness upon loading.
  "type " + obj.__name__ + ". It won't be checked "

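I don't understand why the network's input size would suddenly change, or why my network was saved incorrectly. Please take a look at my problem. Thank you very much.

Tags: python, pytorch

Solution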


Your first error message occurs because torch.from_numpy(i1r) has the wrong shape. You need to do

np.expand_dims(i1r.transpose(2,0,1), axis=0) 

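and then it will be processed correctly. The convolution expects a 4-dimensional input with a leading batch dimension, which you did not provide, and it expects the channels in the first dimension of each image rather than the last. Your array has shape [64, 64, 3]; transpose(2,0,1) turns it into [3, 64, 64], and expand_dims adds the batch dimension to give [1, 3, 64, 64].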
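As a minimal sketch of that reshaping, assume i1r is a single 64x64 RGB image stored as a NumPy array in height x width x channel order. The random array and the plain nn.Conv2d layer below are illustrative stand-ins, not the asker's actual data or conv helper:

```python
import numpy as np
import torch
import torch.nn as nn

# Hypothetical stand-in for the asker's image: height x width x channels.
i1r = np.random.rand(64, 64, 3).astype(np.float32)

# Move channels first -> (3, 64, 64), then add a batch dimension -> (1, 3, 64, 64).
batch = np.expand_dims(i1r.transpose(2, 0, 1), axis=0)

# Make the memory layout contiguous before handing it to torch.
x = torch.from_numpy(np.ascontiguousarray(batch))

# Stand-in first layer with the same weight shape as in the error message (64, 3, 3, 3).
layer = nn.Conv2d(3, 64, kernel_size=3, stride=1, padding=1)
with torch.no_grad():
    y = layer(x)

print(x.shape, y.shape)
```

Note that torch.from_numpy keeps the array's dtype, so a float64 array would likely need .float() before reaching layers with float32 weights, and the tensor should be moved to the same device as the model, as the training loop does with .to(dev).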

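As for your second error message (the warnings during training), it is probably because you defined conv and dense incorrectly, so something goes wrong when the model is saved.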
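Whatever the cause turns out to be, the warnings come from torch/serialization.py while torch.save(lcnet, route_save) pickles the whole module. A commonly recommended alternative is to save only the state_dict and rebuild the model from its class before loading. The following is a minimal sketch under that assumption; TinyNet and the file name are hypothetical stand-ins, not the asker's LC_small:

```python
import os
import tempfile

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical small model used only to illustrate saving and loading."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.fc = nn.Linear(8 * 64 * 64, 3)

    def forward(self, x):
        out = self.conv(x)
        out = out.view(out.size(0), -1)  # flatten to (batch, features)
        return self.fc(out)


net = TinyNet()

# Save only the parameters and buffers, not the pickled module object.
path = os.path.join(tempfile.mkdtemp(), "tinynet_state.pt")
torch.save(net.state_dict(), path)

# To load, build a fresh instance from the class, then restore the weights.
restored = TinyNet()
restored.load_state_dict(torch.load(path))
restored.eval()

# Both models should now produce the same output for the same input.
x = torch.rand(1, 3, 64, 64)
with torch.no_grad():
    same = torch.allclose(net.eval()(x), restored(x))
print(same)
```

With this approach the class definition has to be importable wherever the weights are loaded, which also makes a mismatch between the saved weights and the current architecture show up as an explicit load_state_dict error.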

