首页 > 解决方案 > 从 Keras 转换为 Pytorch - conv2d

问题描述

我正在尝试将以下 Keras 代码转换为 PyTorch。

    tf.keras.Sequential([
          Conv2D(128, 1, activation=tf.nn.relu),
          Conv2D(self.channel_n, 1, activation=None),
    ])

使用 self.channels=16 创建模型摘要时,我得到以下摘要。

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
conv2d (Conv2D)              (1, 3, 3, 128)            6272      
_________________________________________________________________
conv2d_1 (Conv2D)            (1, 3, 3, 16)             2064      
=================================================================
Total params: 8,336
Trainable params: 8,336
Non-trainable params: 0

一个人将如何转换?

我已经尝试过这样的:

import torch
from torch import nn

class CellCA(nn.Module):
    def __init__(self, channels, dim=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(in_channels=channels,out_channels=dim, kernel_size=1),
            nn.ReLU(),
            nn.Conv2d(in_channels=dim, out_channels=channels, kernel_size=1),
        )
    def forward(self, x):
        return self.net(x)

但是,我得到 4240 参数

标签: keraspytorch

解决方案


如果您正确配置了初始通道(在本例中为 48),则上述尝试是正确的。


推荐阅读