keras - 在 keras 中制作“非全连接”(单连接?)神经网络
问题描述
我不知道我要查找的名称,但我想在 keras 中创建一个层,其中每个输入都乘以它自己的独立权重和偏差。例如,如果有 10 个输入,则将有 10 个权重和 10 个偏差,每个输入将乘以它的权重并与它的偏差相加得到 10 个输出。
例如这里是一个简单的密集网络:
from keras.layers import Input, Dense
from keras.models import Model
N = 10
input = Input((N,))
output = Dense(N)(input)
model = Model(input, output)
model.summary()
可以看到,这个模型有 110 个参数,因为它是全连接的:
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_2 (InputLayer) (None, 10) 0
_________________________________________________________________
dense_2 (Dense) (None, 10) 110
=================================================================
Total params: 110
Trainable params: 110
Non-trainable params: 0
_________________________________________________________________
我想output = Dense(N)(input)
用类似的东西替换output = SinglyConnected()(input)
,这样模型现在有 20 个参数:10 个权重和 10 个偏差。
解决方案
创建自定义层:
class SingleConnected(Layer):
#creator
def __init__(self, **kwargs):
super(SingleConnected, self).__init__(**kwargs)
#creates weights
def build(self, input_shape):
weight_shape = (1,) * (len(input_shape) - 1)
weight_shape = weight_shape + (input_shape[-1]) #(....., input)
self.kernel = self.add_weight(name='kernel',
shape=weight_shape,
initializer='uniform',
trainable=True)
self.bias = self.add_weight(name='bias',
shape=weight_shape,
initializer='zeros',
trainable=True)
self.built=True
#operation:
def call(self, inputs):
return (inputs * self.kernel) + self.bias
#output shape
def compute_output_shape(self, input_shape):
return input_shape
#for saving the model - only necessary if you have parameters in __init__
def get_config(self):
config = super(SingleConnected, self).get_config()
return config
使用图层:
model.add(SingleConnected())
推荐阅读
- python - 通过 docker 连接到 smtp 服务器
- docker - Google Artifact Registry 的 docker 服务器 url 是什么?
- c# - 禁用我桌面的 windows 键,但不是 C# 的形式
- android - 类路径中的运行时 JAR 文件的版本为 1.4,比 API 版本 1.5 旧?
- scala - sbt 项目中的 Java Actor
- macos - git 说在我的 cd 中单击现有文件时会创建一个二进制文件
- html - 奇怪的网页渲染
- php - 不能用??在刀片模板中,Laravel
- python - 有没有办法用 python 对高于值和系列阈值(连续数字高于阈值)的数据点进行分类?
- apache-flink - Flink:在DataStream API的批处理模式下左连接相当于Dataset API?