keras - 具有 cRelu 激活的 Keras 序列模型
问题描述
我在创建具有 3 层的 Dense 模型时遇到问题,其中激活函数是 cRelu。cRelu 连接两个 relu(一个负数和一个正数)并在其输出中创建一个两倍大小的张量。尝试在其后添加另一层时,我总是收到大小不匹配错误
model = Sequential()
model.add(Dense(N, input_dim=K, activation=crelu))
model.add(Dense(N//2, activation=crelu))
如何告诉下一层期望 2N 输入和 N?
解决方案
Keras 不希望激活函数改变输出形状。如果要更改它,则应将 crelu 功能包装在一个层中并指定相应的输出形状:
import tensorflow as tf
from keras.layers import Layer
class cRelu(Layer):
def __init__(self, **kwargs):
super(cRelu, self).__init__(**kwargs)
def build(self, input_shape):
super(cRelu, self).build(input_shape)
def call(self, x):
return tf.nn.crelu(x)
def compute_output_shape(self, input_shape):
"""
All axis of output_shape, except the last one,
coincide with the input shape.
The last one is twice the size of the corresponding input
as it's the axis along which the two relu get concatenated.
"""
return (*input_shape[:-1], input_shape[-1]*2)
然后你可以按如下方式使用它
model = Sequential()
model.add(Dense(N, input_dim=K))
model.add(cRelu())
model.add(Dense(N//2))
model.add(cRelu())
推荐阅读
- android - 始终获取结果代码 RESULT_CANCELED
- php - 面对“serialport.parsers.readline 不是函数”错误-NodeJS
- vb.net - arduino vb.net 溢出 ASCII 和
- python - 使用 Python 在 html 标签中查找标签和 id
- amazon-web-services - 无服务器 lambda 全局环境变量
- gradle - JacocoTestReport 排除 gradle7.1.1 中的文件
- node.js - 如何使用 Mongodb 和 Node.js 删除所有帖子和单个用户
- flutter - type '(dynamic) => Product' 不是类型 '(String, dynamic) => MapEntry 的子类型
'的'转变' - windows - 如何解决“Windows 无法访问指定的设备、路径或文件。您可能没有访问该项目的适当权限。” 错误
- python - Discord bot“RuntimeError:无法关闭正在运行的事件循环”错误