tensorflow - 在 TensorFlow 2.4.1 中将 PRelu 激活实现为函数
问题描述
我正在尝试在 tensorflow 2.4.1 中实现PReLU激活,如此处给出的如何在 Tensorflow 中实现 PReLU 激活?
出现以下错误
ValueError: Variable alpha already exists, disallowed. Did you mean to set reuse=True or reuse=tf.AUTO_REUSE in VarScope?
这是我的代码
def prelu(_x):
alphas = tf.compat.v1.get_variable('alpha', _x.get_shape()[-1],
initializer=tf.constant_initializer(0.0), dtype=tf.float32)
pos = tf.nn.relu(_x)
neg = alphas * (_x - abs(_x)) * 0.5
return pos + neg
任何帮助表示赞赏。
笔记 :
我不想使用层接口 tf.keras.layers.PReLU,因为它不能作为参数传递给 conv2D,如下所示
Conv2D(filters, (3, 3), padding='same', activation='prelu')
解决方案
这是tensorflow作为函数而不是层PRelu
的实现,它可作为内置激活层和(我认为应该使用)PRelu。
def prelu_advanced(scope=None):
def prelu_plus(x):
with tf.compat.v1.variable_scope(name_or_scope=scope,
default_name="prelu", reuse=True):
alpha = tf.compat.v1.get_variable("prelu", shape=x.get_shape()[-1],
dtype=x.dtype, initializer=tf.constant_initializer(0.0))
pos = tf.nn.relu(x)
neg = alpha * (x - abs(x)) * 0.5
return tf.maximum(0.0, x) + alpha * tf.minimum(0.0, x)
return prelu_plus
查看
foo = tf.constant([1.0, -0.1, -1.0, 0.5, 0.5], dtype = tf.float32)
# layer (built-in)
tf.keras.layers.PReLU(alpha_initializer="zeros", alpha_regularizer=None,
alpha_constraint=None, shared_axes=None)(foo).numpy()
array([1. , 0. , 0. , 0.5, 0.5], dtype=float32)
# function
x = prelu_advanced(scope='prelu1')
x(foo).numpy()
array([1. , 0. , 0. , 0.5, 0.5], dtype=float32)
假人训练
input = tf.keras.Input(shape=(28, 28, 1))
x = tf.keras.layers.Flatten()(input)
x = tf.keras.layers.Dense(128, activation=prelu_advanced(scope='prelu'))(x)
y = tf.keras.layers.Dense(units=10, activation='softmax')(x)
func_model = tf.keras.Model(inputs=[input], outputs=[y])
func_model.compile(
loss = tf.keras.losses.CategoricalCrossentropy(),
metrics = tf.keras.metrics.CategoricalAccuracy(),
optimizer = tf.keras.optimizers.Adam())
func_model.fit(x_train, y_train)
4s 2ms/step - loss: 0.2690 - categorical_accuracy: 0.9242
推荐阅读
- javascript - 当只有一个函数应该运行时,两个函数运行
- html - 为什么通过vue组件渲染行时表格样式会移出
- python - 使用 Glue 将数据从 RDS 移动到 S3
- html - 按此顺序显示 .div 或图像
- reactjs - React - 在添加时淡入,在删除项目时淡出
- python - /articles/article/19 处的 ValueError 'article_image' 属性没有与之关联的文件
- php - 我可以使用 PHP 内置网络服务器在 localhost 上运行简单的 CGI 脚本吗?
- java - spring boot oauth ExceptionHandlerExceptionResolver : NestedServletException: 嵌套异常是 java.lang.StackOverflowError]
- java - 多线程访问相同数据但获取最新数据?
- python - 如何将 matplotlib.ConciseDateConverter 与 Seaborn 热图一起使用?