python - 如何在神经网络的隐藏层中对权重矩阵的列实施正交约束?
问题描述
我想训练一个对隐藏层权重矩阵的列具有正交约束的神经网络,即学习的权重矩阵应该具有正交列。如何在 keras 中做到这一点?
解决方案
class WeightsOrthogonalityConstraint(tf.keras.constraints.Constraint):
def __init__(self, encoding_dim, weightage = 1.0, axis = 0):
self.encoding_dim = encoding_dim
self.weightage = weightage
self.axis = axis
def weights_orthogonality(self, w):
if(self.axis==1):
w = tf.keras.backend.transpose(w)
if(self.encoding_dim > 1):
m = tf.keras.backend.dot(tf.keras.backend.transpose(w), tf.Variable(w)) - tf.keras.backend.eye(self.encoding_dim)
return self.weightage * tf.keras.backend.sqrt(tf.keras.backend.sum(tf.keras.backend.square(m)))
else:
m = tf.keras.backend.sum(w ** 2) - 1.
return m
def __call__(self, w):
return self.weights_orthogonality(w)
用法:
x = tf.keras.layers.Dense(2, 'relu', input_shape=(4,), kernel_regularizer=WeightsOrthogonalityConstraint(2))
推荐阅读
- r - 循环回归:创建交互项,存储结果,只提取有意义的项
- youtube-api - 生成新的刷新令牌后 YouTube API v3 无效授权
- mysql - 在从 Dockerfile 构建期间部署 mysql
- angularjs - 如何在 chrome for Android App 上进行测试和调试
- powershell - PowerShell替换文件中的值
- firebase - 如何从 Flutter 应用程序在 Cloud Firestore 中记录带有嵌套数组的 json
- java - 面临的问题 - SpringBoot 应用程序中的 Olingo OData 4.0 Java 库集成
- sql - 2 个相关的案例陈述
- datetime - 谷歌表格转换错误的日期
- git - Git - 保持由特定作者完成的提交