python - 自定义激活函数的可训练参数向量
问题描述
我对 Keras 有点陌生,并且正在尝试使用具有可训练参数的自定义激活函数进行一些实验。我创建了下面的代码,它本质上是 ReLU 激活函数的变体。它目前计算alpha*h1 + (1 - alpha)*h2
whereh1 = relu(x)
和h2 = relu(-x)
,希望帮助处理常规 ReLU 函数可以创建的死神经元。alpha
我想知道是否可以修改此代码以生成可训练参数向量来进一步测试这个想法,而不是仅仅拥有这个可训练参数。任何建议或帮助将不胜感激。
class CustomLayer(Layer):
def __init__(self, alpha, **kwargs):
self.alpha = alpha
super(CustomLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], self.alpha),
initializer='uniform',
trainable=True)
super(CustomLayer, self).build(input_shape)
def call(self,x):
h1 = K.relu(x)
h2 = K.relu(-x)
return self.kernal*h1 + (1 - self.kernal)*h2
def compute_output_shape(self, input_shape):
return (input_shape[0], self.alpha)
解决方案
有几件事:
- 您当前正在使用 设置网络的输出形状
alpha
,这几乎可以肯定是不正确的。 - 您可以将层的内核定义为对您尝试执行的操作有意义的任何大小。这是您要为激活函数创建可训练参数的地方。
- 由于这是一个激活函数,您可能希望输出与输入具有相同的形状。
尝试类似:
from keras import backend as K
class CustomLayer(Layer):
# You actually don't need to redefine __init__ since we don't need to
# pass any custom parameters. I'm leaving it here to show that it changed
# from your example.
def __init__(self, **kwargs):
super(CustomLayer, self).__init__(**kwargs)
def build(self, input_shape):
self.kernel = self.add_weight(name='kernel',
shape=(input_shape[1], 1),
initializer='uniform',
trainable=True)
super(CustomLayer, self).build(input_shape)
def call(self, x):
h1 = K.relu(x)
h2 = K.relu(-x)
return h1*self.kernal + h2*(1 - self.kernel)
def compute_output_shape(self, input_shape):
return input_shape
我假设您希望alpha
输入向量中的每个特征都有一个不同的参数,这是我在build()
方法中创建的。我也假设input_shape = [batch_size, num_features]
或类似的东西。该方法在and和内核call()
之间执行元素乘法,将两半相加。基本上相同的成本函数,每个特征都有一个独特的。h1
h2
alpha
您可能需要为此进行一些调试,因为我没有运行它的示例。
这是关于编写您自己的图层的文档的链接,您似乎已经拥有,但为了完整起见,我将其包括在此处。
推荐阅读
- swift - 每次按下按钮时打印结果
- arrays - 如何使用 PowerShell 从数组 JSON 中的对象获取键值?
- php - Auth-Attempt Laravel 5.8 只需刷新即可再次登录
- javascript - 如何在 ReactJs 中使用 RadioButton 设置状态
- javascript - 在下拉框中选择出生日期后,我的年龄无法实时显示。为什么?
- csv - 如何在 VS Code 中基于 CSV 文件中的组对行进行着色?
- mysql - ALL子句中的MySQL多列
- r - 以编程方式在 rmardown 中为各种语言创建代码片段
- javascript - Checkbox 在 React todo 应用程序中单击时不会更改其值
- php - 自动将字符串附加到匹配的短语 php