python - 如何移除模型 Tensorflow Keras 中的特定神经元
问题描述
有没有办法去除模型中的特定神经元?
例如,我有一个具有 512 个神经元的 Dense 层的模型。有没有办法去除所有内部有索引的神经元list_indeces
?当然,移除一个神经元会影响下一层,甚至是前一层。
例子:
我在多篇论文中都有这个通用模型:
data_format = 'channels_last'
input_shape = [28, 28, 1]
max_pool = functools.partial(
tf.keras.layers.MaxPooling2D,
pool_size=(2, 2),
padding='same',
data_format=data_format)
conv2d = functools.partial(
tf.keras.layers.Conv2D,
kernel_size=5,
padding='same',
data_format=data_format,
activation=tf.nn.relu)
model = tf.keras.models.Sequential([
conv2d(filters=32, input_shape=input_shape),
max_pool(),
conv2d(filters=64),
max_pool(),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(512, activation=tf.nn.relu),
tf.keras.layers.Dense(10 if only_digits else 62),
])
return model
假设tf.keras.layers.Dense(512, activation=tf.nn.relu)
我想从层中移除 100 个神经元,基本上将它们关闭。
当然,我将有一个新模型,tf.keras.layers.Dense(412, activation=tf.nn.relu)
而不是该层,tf.keras.layers.Dense(512, activation=tf.nn.relu)
但是这种修改也应该传播到下一层的权重,因为从密集层的神经元到下一层的连接也被删除了。
关于如何做到这一点的任何意见?我可以通过执行以下操作手动执行此操作:
如果我得到正确的模型形状是这个:[5, 5, 1, 32], [32], [5, 5, 32, 64], [64], [3136, 512], [512], [512, 62], [62]
所以我可以做这样的事情:
- 生成我需要的所有索引并在里面相同
list_indices
- 访问层的权重
tf.keras.layers.Dense(512, activation=tf.nn.relu)
,并创建一个包含所有权重的张量list_indices
- 将新的权重张量分配给
tf.keras.layers.Dense(412, activation=tf.nn.relu)
子模型的层
问题是我不知道如何获得下一层权重的正确权重,这些权重与我刚刚创建的权重的索引以及我应该分配给子模型的下一层的权重相对应。我希望我已经清楚地解释了自己。
谢谢,莱拉。
解决方案
您的操作在文献中被称为selective dropout
,实际上不需要每次都创建不同的模型,您只需将所选神经元的输出乘以 0,这样下一层的输入就不会接受这些激活帐户。
请注意,如果您“关闭”该层中的一个神经元,Ln
它不会完全“关闭”该层中的任何神经元Ln+1
,假设两者都是全连接层(密集):该Ln+1
层中的每个神经元都连接到所有神经元在上一层。换句话说,在全连接(密集)层中移除一个神经元不会影响下一层的维度。
Multiply Layer
您可以使用(Keras)简单地实现此操作。缺点是需要学习如何使用Keras 函数式 API。还有其他方法,但比这更复杂(例如自定义层),功能性API在许多方面都非常有用和强大,非常建议阅读!
你的模型会变成这样:
data_format = 'channels_last'
input_shape = [28, 28, 1]
max_pool = ...
conv2d = ...
# convert a list of indexes to a weight tensor
def make_index_weights(indexes):
# converting indexes to a list of weights
indexes = [ float(i not in indexes) for i in range(units) ]
# converting indexes from list/numpy to tensor
indexes = tf.convert_to_tensor(indexes)
# reshaping to the correct format
indexes = tf.reshape(indexes, (1, units))
# ensuring it is a float tensor
indexes = tf.cast(indexes, 'float32')
return indexes
# layer builder utility
def selective_dropout(units, indexes, **kwargs):
indexes = make_index_weights(indexes)
dense = tf.keras.layers.Dense(units, **kwargs)
mul = tf.keras.layers.Multiply()
# return the tensor builder
return lambda inputs: mul([ dense(inputs), indexes ])
input_layer = tf.keras.layers.Input(input_shape)
conv_1 = conv2d(filters=32, input_shape=input_shape)(input_layer)
maxp_1 = max_pool()(conv_1)
conv_2 = conv2d(filters=64)(maxp_1)
maxp_2 = max_pool()(conv_2)
flat = tf.keras.layers.Flatten()(maxp_2)
sel_drop_1 = selective_dropout(512, INDEXES, activation=tf.nn.relu)(flat)
dense_2 = tf.keras.layers.Dense(10 if only_digits else 62)(sel_drop_1)
output_layer = dense2
model = tf.keras.models.Model([ input_layer ], [ output_layer ])
return model
INDEXES
现在你只需要根据你需要删除的那些神经元的索引来建立你的列表。
在您的情况下,张量的形状为 ,1x512
因为密集层中有 512 个权重(单位/神经元),因此您需要为索引提供尽可能多的权重。该selective_dropout
函数允许传递要丢弃的索引列表,并自动建立所需的张量。
例如,如果您想删除神经元 1、10、12,您只需将列表传递[1, 10, 12]
给函数,它将在这些位置以及所有其他位置产生一个1x512
张量。0.0
1.0
编辑:
正如您所提到的,您严格需要减少模型中参数的大小。
每个密集层由关系描述y = Wx + B
,其中W
是内核(或权重矩阵)并且B
是偏置向量。W
是INPUTxOUTPUT
维度矩阵,其中INPUT
是最后一层输出形状,OUTPUT
是层中神经元/单元/权重的数量;B
只是一个维度向量1xOUTPUT
(但我们对此不感兴趣)。
现在的问题是你N
在层中丢弃神经元,这会导致层中权重Ln
的下降。让我们实践一些数字。在您的情况下(假设为真),您从以下开始:NxOUTPUT
Ln+1
only_digits
Nx512 -> 512x10 (5120 weights)
并且在丢掉 100 个神经元之后(意味着丢掉 100*10=1000 个权重)
Nx412 -> 412x10 (4120 weights)
现在矩阵的每一列都W
描述了一个神经元(作为权重向量,其维度等于前一层输出维度,在我们的例子中为 512 或 412)。矩阵的行代表前一层中的单个神经元。
表示 layer 的第一个神经元和 layer的第一个神经元之间的W[0,0]
关系。n
n+1
W[0,0] -> 1st n, 1st n+1
W[0,1] -> 2nd n, 1st n+1
W[1,0] -> 1st n, 2nd n+1
等等。所以你可以从这个矩阵中删除所有与你删除的神经元索引相关的行:index 0 -> row 0
。
您可以使用W
密集层将矩阵作为张量访问dense.kernel
推荐阅读
- websocket - WebSocket:失败:WebSocket 握手期间出错:net::ERR_INVALID_HTTP_RESPONSE
- node.js - React Native 使用 fetch 向 Node.js 服务器发送数据
- jhipster - 带有 Easticsearch 的 v5.0.1 的新 JHipster 应用程序生成(HTTP 代码 500)服务器错误
- webpack - 将 Sass 与 Vue-loader webpack 堆栈正确集成
- html - CSS页脚网格响应
- angularjs - AngularJS $http.post 不向 Django REST 后端发送数据
- java - 重启后Hbase数据被删除
- azure-sql-database - 无法导出 Azure SQL 数据库
- r - 直方图中的条数 - R
- python - 从另一个字典中删除重复值