python - 如何指定神经网络层中的连接(在 keras 中)?
问题描述
我想指定两层之间的连接。我有输入、权重和一个稀疏定义的矩阵,
假设这些是我的数据:
import numpy as np
import random
from tensorflow.keras.utils import to_categorical
X = np.random.rand(500,100)
y = []
for i in range(0,500):
n = random.randint(0,1)
y.append(n)
y_enc = to_categorical(y)
go =np.random.randint(2, size=(100, 50))
所以我创建了一个线性层,它正是这样做的,它需要一个输入张量,创建权重并将权重与稀疏矩阵相乘。
from tensorflow.keras import activations, constraints
class LinearGO(keras.layers.Layer):
def __init__(self, units=3, input_dim=3, zeroes = None,
activation=None, **kwargs):
super(LinearGO, self).__init__( **kwargs)
self.units = units
self.zeroes = zeroes
self.unit_number = self.zeroes.shape[-1]
self.activation = activation
self.activation_fn = activations.get(activation)
def build(self, input_shape):
self.kernel = self.add_weight(
shape=(input_shape[-1], self.unit_number ),
initializer="glorot_uniform",
trainable=True,
)
self.b = self.add_weight(
shape=(self.unit_number,), initializer="glorot_uniform", trainable=True
)
self.sparse_mat = tf.convert_to_tensor(self.zeroes , dtype=tf.float32)
#self.w = tf.multiply(self.w, self.sparse_mat)
def call(self, inputs):
output = tf.matmul(inputs, self.kernel * self.sparse_mat ) + self.b
if self.activation_fn is not None:
output = self.activation_fn(output)
return output
def compute_output_shape(self, input_shape):
return (input_shape[0], self.unit_number)
def get_config(self):
config = super(Linear, self).get_config()
config.update({"units": self.units})
return config
#example of usage
linear_layer = LinearGO( zeroes = go)
y = linear_layer(X)
print(y)
我可以在 keras 模型中使用,我可以编译和拟合模型。但是当然效果不好,准确性很差。因为也许这些权重为零,但它们仍然参与反向传播,并且由于我的矩阵非常稀疏,所以它是一个问题。
在这个答案中,他们对卷积层做了类似的事情,他们停止了遮罩的渐变,我想在我的自定义密集层中实现类似的东西:
解决方案
推荐阅读
- google-cloud-platform - 在哪里可以查看 PubSub 服务帐号?
- sql - 字段减法sql server
- python - 为什么会弹出 KeyError 以及如何防止它
- macros - 是否可以在表达式中执行代码?
- javascript - 给定一个函数,如何确定它的定义位置
- selenium - selenium grid 4 连接远程驱动程序比 3 需要更长的时间?
- python - 捕获由于 Ruby 中的括号/括号/大括号不匹配而导致的异常
- python - 如何在 python 中播放 GIF 文件?
- pytorch - 张量重塑查询
- c# - 比较 C# .NET 中列表中的元素是否几乎相等