python - 如何为层中的每个节点为 Keras relu 函数分配自定义 alpha?
问题描述
我想为每个 Keras 激活函数添加一个特定于节点的变量。我希望每个节点都用不同的值( )计算激活值(输出alpha
)。
这可以全局完成,例如使用alpha
relu 激活函数的参数(链接):
# Build Model
...
model.add(Dense(units=128))
model.add(Activation(lambda x: custom_activation(x, alpha=0.1)))
...
我也可以写一个自定义的激活函数,但是alpha
参数也是全局的。(链接)
# Custom activation function
def custom_activation(x, alpha=0.0):
return (K.sigmoid(x + alpha))
# Build Model
...
model.add(Dense(units=128))
model.add(Activation(lambda x: custom_activation(x, alpha=0.1)))
...
在自定义函数中,我目前只能访问以下变量:
(Pdb) locals()
{'x': <tf.Tensor 'dense/Identity:0' shape=(None, 128) dtype=float32>, 'alpha': 0.1}
我想使用自定义激活函数,但对于alpha
网络中的每个节点都是唯一的。例如,如果层中有 128 个单元,那么我希望也有 128 个 alpha 值,每个单元/节点一个。然后我希望激活函数
如何创建alpha
一个层中每个单元/节点唯一的值?
解决方案
我不建议为那个使用 lambda 层,它是 hackish。我建议您编写自己的图层,如下所示:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
# Custom layer
class CustomAct(tf.keras.layers.Layer):
def __init__(self):
super(CustomAct, self).__init__()
def build(self, input_shape):
self.alpha = self.add_weight(name='alpha',
shape=[input_shape[1], ],
initializer='uniform',
trainable=True)
super(CustomAct, self).build(input_shape)
def call(self, x):
return tf.sigmoid(x+self.alpha)
def get_alpha(self):
return self.alpha
inputs = np.random.random([16, 32]).astype(np.float32)
# Model
model = tf.keras.models.Sequential()
model.add(tf.keras.Input(inputs.shape[-1]))
model.add(tf.keras.layers.Dense(128))
model.add(CustomAct())
# Test
model.compile(loss="MSE")
alpha_after_initialization = model.layers[-1].get_alpha()
plt.plot(alpha_after_initialization.numpy())
x = np.random.random([18, 32])
y = np.random.random([18, 128])
for _ in range(20):
model.fit(x, y)
out_after_20_steps = alpha_after_initialization = model.layers[-1].get_alpha()
plt.plot(alpha_after_initialization.numpy())
plt.show()
当然,您应该将所有 tf 引用更改为您的 keras 引用。
推荐阅读
- python - 有没有比python中的networkx更有效的方法来计算最短路径问题?
- cucumber - 用于导入测试执行结果的 Xray Rest API 调用错误
- apache-spark - 有没有办法用 apache flink 读取镶木地板文件?
- python - 使用 Numba 进行调试
- python - python:当字典的值是列表时,如何检查字典的任何元素中是否存在值?
- http - Safari 14 如何在 HTTP 与 HTTPS 中处理 iframe cookie 的区别
- powerbi - Power BI - countif / averageif 与 dax
- python - Python:定义函数以包含属性列表作为变量
- python-3.x - 如何在列表中的列表中拆分单个字符?
- html - 调整窗口大小时,Div 移开