首页 > 解决方案 > Torch.nn 有特定的激活函数吗?

问题描述

我的网络有一个带有 Relu 激活函数的输出层,但我希望输出类似于“Relu+1”,即我希望输出都大于 1 并且具有相同形状的 Relu 函数。我应该如何更改我的 torch.nn 网络?我的代码是这样的:

self.actor = nn.Sequential(
                     nn.Linear(state_dim, 256),
                     nn.ReLU(),
                     nn.Linear(256, 256),
                     nn.ReLU(),
                     nn.Linear(256, action_dim),
                     nn.ReLU()
)

标签: pythonpytorch

解决方案


我能想到的方法有两种。

  1. self.actor一个nn.Module对象
class Actor(nn.Module):
    def __int__(self, state_dim, action_dim):
        super().__init__()
        self.linear1 = nn.Linear(state_dim, 256)
        self.relu = nn.ReLU()
        self.linear2 = nn.Linear(256, 256)
        self.linear3 = nn.Linear(256, action_dim)

    def forward(self, x):
        x = self.linear1(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x)
        x = self.linear2(x)
        x = self.relu(x) + 1
        return x

class ......
    self.actor = Actor(state_dim, action_dim)
  1. 创建一个Module类来执行此操作并将其添加到self.actor
class Add1(nn.Module):
    def forward(self, x):
        return x + 1

class ......
    self.actor = nn.Sequential(
                     nn.Linear(state_dim, 256),
                     nn.ReLU(),
                     nn.Linear(256, 256),
                     nn.ReLU(),
                     nn.Linear(256, action_dim),
                     nn.ReLU(),
                     Add1()
)

推荐阅读