python - 将 pytorch 权重导出到 keras
问题描述
我有一个预训练的 pytorch 模型,我需要在另一个 keras 模型中使用它的权重。
我正在尝试使用 GitHub 上的 pytorch2keras 存储库,将 pytorch 的 .pth 权重转换为 keras 的 .h5 格式。
我的模型如下所示:
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
class BaseNetwork(nn.Module):
    """Base module that stores a name and a channel count and offers weight init.

    Args:
        name: Identifier returned by :meth:`name`.
        channels: Value stored on the instance as ``_channels`` (default 1).
    """

    def __init__(self, name, channels=1):
        super(BaseNetwork, self).__init__()
        self._name = name
        self._channels = channels

    def name(self):
        """Return the name passed to the constructor."""
        return self._name

    def initialize_weights(self):
        """Re-initialize all Conv2d, BatchNorm2d and Linear submodules in place.

        * ``Conv2d``: Xavier-uniform weights with gain ``sqrt(2)``, zero bias.
        * ``BatchNorm2d``: weight filled with 1, bias filled with 0.
        * ``Linear``: weights drawn from N(0, 0.01), zero bias.

        Parameters that are absent (``bias=False`` / ``affine=False``) are
        skipped instead of raising ``AttributeError`` on ``None``.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                # The underscore variant is the supported in-place initializer;
                # the non-underscore name is deprecated.
                nn.init.xavier_uniform_(m.weight, gain=np.sqrt(2))
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm2d):
                # weight/bias are None when the layer was built with affine=False.
                if m.weight is not None:
                    nn.init.constant_(m.weight, 1)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.Linear):
                nn.init.normal_(m.weight, 0, 0.01)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
class ImageWiseNetwork(BaseNetwork):
    """Convolutional feature extractor followed by a 4-way linear classifier.

    The feature stack takes ``12 * channels`` input channels and ends in a
    1x1 convolution producing a single channel. The classifier's first layer
    takes ``1 * 16 * 16`` flattened features, so the feature stack must
    output a 16x16 map (with the two stride-2 convolutions this appears to
    correspond to a 64x64 spatial input; confirm against the training data).

    Args:
        channels: Multiplier for the input channel count; also used to build
            the network name ``'iw' + str(channels)``.
    """

    def __init__(self, channels=1):
        super().__init__('iw' + str(channels), channels)

        def conv_bn_relu(n_in, n_out, kernel, stride, pad=0):
            # One Conv2d -> BatchNorm2d -> ReLU unit, returned as a flat list
            # so the Sequential indices match a hand-written layer listing.
            return [
                nn.Conv2d(in_channels=n_in, out_channels=n_out,
                          kernel_size=kernel, stride=stride, padding=pad),
                nn.BatchNorm2d(n_out),
                nn.ReLU(inplace=True),
            ]

        feature_layers = []
        # Block 1
        feature_layers += conv_bn_relu(12 * channels, 64, 3, 1, pad=1)
        feature_layers += conv_bn_relu(64, 64, 3, 1, pad=1)
        feature_layers += conv_bn_relu(64, 64, 2, 2)
        # Block 2
        feature_layers += conv_bn_relu(64, 128, 3, 1, pad=1)
        feature_layers += conv_bn_relu(128, 128, 3, 1, pad=1)
        feature_layers += conv_bn_relu(128, 128, 2, 2)
        # Final 1x1 convolution collapses the 128 channels into one.
        feature_layers.append(
            nn.Conv2d(in_channels=128, out_channels=1, kernel_size=1, stride=1)
        )
        self.features = nn.Sequential(*feature_layers)

        classifier_layers = []
        for n_in, n_out in ((1 * 16 * 16, 128), (128, 128), (128, 64)):
            classifier_layers += [
                nn.Linear(n_in, n_out),
                nn.ReLU(inplace=True),
                nn.Dropout(0.5, inplace=True),
            ]
        classifier_layers.append(nn.Linear(64, 4))
        self.classifier = nn.Sequential(*classifier_layers)

        self.initialize_weights()

    def forward(self, x):
        """Return log-softmax scores (over dim 1) for the input batch ``x``."""
        feature_map = self.features(x)
        # Flatten everything except the batch dimension for the linear layers.
        flattened = feature_map.view(feature_map.size(0), -1)
        logits = self.classifier(flattened)
        return F.log_softmax(logits, dim=1)
创建对象模型,model = ImageWiseNetwork()
然后将训练后的权重加载为:
# Raw string: in a normal literal the "\t" in this Windows path would be read as a TAB.
model.load_state_dict(torch.load(r'Path\to\weights\weights_iw1.pth'))
其次是
# Random batch in [0, 1) with shape (3, 12, 512, 512), wrapped as a float tensor Variable.
input_np = np.random.uniform(0, 1, (3, 12, 512, 512))
input_var = Variable(torch.FloatTensor(input_np))
接着
# Convert the PyTorch model to a Keras model using the pytorch2keras converter.
# NOTE(review): the shape tuple passed here is (3, 512, 512), while input_var is
# built from an array of shape (3, 12, 512, 512), i.e. 12 channels per sample.
# Confirm which shape the converter expects for this argument.
from converter import pytorch_to_keras
k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)
我在回溯中收到以下错误:
> > RuntimeError Traceback (most recent call last) <ipython-input-29-7c5d264109b9> in <module>()
> 1 from converter import pytorch_to_keras
> ----> 2 k_model = pytorch_to_keras(model, input_var, (3, 512, 512,), verbose=True)
>
> ~\Downloads\pytorch2keras-master\pytorch2keras-master\pytorch2keras\converter.py
> in pytorch_to_keras(model, args, input_shape, change_ordering,
> training, verbose)
> 84
> 85 with set_training(model, training):
> ---> 86 trace, torch_out = torch.jit.get_trace_graph(model, args)
> 87
> 88 if orig_state_dict_keys != _unique_state_dict(model).keys():
>
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_trace_graph(f, args, kwargs, nderivs)
> 253 if not isinstance(args, tuple):
> 254 args = (args,)
> --> 255 return LegacyTracedModule(f, nderivs=nderivs)(*args, **kwargs)
> 256
> 257
>
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> __call__(self, *input, **kwargs)
> 487 hook(self, input)
> 488 if torch.jit._tracing:
> --> 489 result = self._slow_forward(*input, **kwargs)
> 490 else:
> 491 result = self.forward(*input, **kwargs)
>
> ~\Anaconda3\lib\site-packages\torch\nn\modules\module.py in
> _slow_forward(self, *input, **kwargs)
> 465 def _slow_forward(self, *input, **kwargs):
> 466 input_vars = tuple(torch.autograd.function._iter_tensors(input))
> --> 467 tracing_state = torch.jit.get_tracing_state(input_vars)
> 468 if not tracing_state:
> 469 return self.forward(*input, **kwargs)
>
> ~\Anaconda3\lib\site-packages\torch\jit\__init__.py in
> get_tracing_state(args)
> 33 if not torch._C._is_tracing(args):
> 34 return None
> ---> 35 return torch._C._get_tracing_state(args)
> 36
> 37
>
> RuntimeError:
> C:\ProgramData\Miniconda3\conda-bld\pytorch_1524549877902\work\torch/csrc/jit/tracer.h:117:
> getTracingState: Assertion `var_state == state` failed.
我无法弄清楚这个错误。导致该错误的可能原因是什么?
解决方案
推荐阅读
- c++ - 将类型分配给变量并将其传递给模板
- javascript - 没有“Access-Control-Allow-Origin”标头存在错误?
- reactjs - 反应未知错误“index.js:63 Uncaught TypeError: Cannot read property 'nodeName' of null”
- google-data-studio - 有没有办法在没有刷新报告的情况下在 Google Data Studio 中每分钟动态显示日期时间?
- javascript - 使用服务器端 Asp 项填充 SweetAlert2 Html 选择输入项
- python - 子进程未创建 ffmpeg 命令的输出文件
- android - 如何解决 Android 上的这个 adb 授予权限错误?
- html - svg 元素上的渐变不工作
- python - 在 Windows 中键入 python -v 时,python 显示要安装的东西的列表
- ios - 快速比较发布时间和当前时间