python - 在 Transformer 模型中删除层 (PyTorch / HuggingFace)
问题描述
我偶然发现了一篇关于 Transformer 模型中层丢失的有趣论文,我实际上正在尝试实现它。但是,我想知道执行“层删除”的好习惯是什么。
我有几个想法,但不知道去这里最干净/最安全的方式是什么:
- 掩盖不需要的层(某种修剪)
- 将想要的图层复制到新模型中
如果有人之前已经这样做过或有建议,我会全力以赴!
干杯
解决方案
我认为最安全的方法之一就是简单地跳过前向传递中的给定层。
例如,假设您正在使用BERT
并且您将以下条目添加到配置中:
config.active_layers = [False, True] * 6 # using a 12 layers model
然后你可以BertEncoder
像下面这样修改类:
class BertEncoder(nn.Module):
def __init__(self, config):
super().__init__()
self.config = config
self.layer = nn.ModuleList([BertLayer(config) for _ in range(config.num_hidden_layers)])
self.gradient_checkpointing = False
def forward(
self,
hidden_states,
attention_mask=None,
head_mask=None,
encoder_hidden_states=None,
encoder_attention_mask=None,
past_key_values=None,
use_cache=None,
output_attentions=False,
output_hidden_states=False,
return_dict=True,
):
all_hidden_states = () if output_hidden_states else None
all_self_attentions = () if output_attentions else None
all_cross_attentions = () if output_attentions and self.config.add_cross_attention else None
next_decoder_cache = () if use_cache else None
for i, layer_module in enumerate(self.layer):
########### MAGIC HERE #############
if not self.config.active_layers[i]:
continue
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
layer_head_mask = head_mask[i] if head_mask is not None else None
past_key_value = past_key_values[i] if past_key_values is not None else None
if self.gradient_checkpointing and self.training:
if use_cache:
logger.warning(
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
)
use_cache = False
def create_custom_forward(module):
def custom_forward(*inputs):
return module(*inputs, past_key_value, output_attentions)
return custom_forward
layer_outputs = torch.utils.checkpoint.checkpoint(
create_custom_forward(layer_module),
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
)
else:
layer_outputs = layer_module(
hidden_states,
attention_mask,
layer_head_mask,
encoder_hidden_states,
encoder_attention_mask,
past_key_value,
output_attentions,
)
hidden_states = layer_outputs[0]
if use_cache:
next_decoder_cache += (layer_outputs[-1],)
if output_attentions:
all_self_attentions = all_self_attentions + (layer_outputs[1],)
if self.config.add_cross_attention:
all_cross_attentions = all_cross_attentions + (layer_outputs[2],)
if output_hidden_states:
all_hidden_states = all_hidden_states + (hidden_states,)
if not return_dict:
return tuple(
v
for v in [
hidden_states,
next_decoder_cache,
all_hidden_states,
all_self_attentions,
all_cross_attentions,
]
if v is not None
)
return BaseModelOutputWithPastAndCrossAttentions(
last_hidden_state=hidden_states,
past_key_values=next_decoder_cache,
hidden_states=all_hidden_states,
attentions=all_self_attentions,
cross_attentions=all_cross_attentions,
)
目前,您可能需要BERT
使用新Encoder
层编写您的特殊类。但是,您应该能够从提供的预训练模型中加载权重huggingface
。
BertEncoder
从这里获取的代码
推荐阅读
- d3.js - d3 US choropleth 在边缘映射 3-d 视图
- swift - 在 VStack 中使用列表时获取“表达式类型在没有上下文的情况下不明确”
- c# - 将属性包裹在自定义属性中
- typescript - 路由保护vue3中的参数(路由,来自,下一个)类型
- windows - 将带有字符串键和 PSObject 值的复杂 Hashtable 导出到 .csv [PowerShell]
- java - 通过调用 appendTo 更改 Jsoup 元素
- javascript - 防止节点“fs”在换行符中插入“\r”
- google-cloud-platform - 无法再通过 SSH 连接到我的 Google Cloud VM 实例
- mongodb - 如何为存储混合内容的 MongoDB 集合设置排序规则?
- python - 导入的函数和类之间是否存在性能差异?