python - 将模型配置从 allennlp 0.9.0 更新到 1.3.0
问题描述
我正在尝试使用 Udify 库,针对依存句法分析(dependency parsing)对预训练的多语言 BERT 模型进行微调。这个库使用的是 allennlp==0.9.0,而我需要使用 allennlp==1.3.0,因此正在尝试相应地更新代码。在对配置文件做了一些修改后,模型 forward() 方法的输入出现了问题:传入的 tokens 参数在我的例子中是一个如下所示的字典:
tokens = {
"bert": {
"bert": tensor(...),
"bert-offsets": tensor(...),
"bert-type-ids": tensor(...),
"mask": tensor(...)
},
"tokens": {
"tokens": tensor(...)
}
}
此方法中发生错误:
Traceback (most recent call last):
File "train.py", line 74, in <module>
train_model(train_params, serialization_dir, recover=bool(args.resume))
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/commands/train.py", line 236, in train_model
model = _train_worker(
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/commands/train.py", line 466, in _train_worker
metrics = train_loop.run()
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/commands/train.py", line 528, in run
return self.trainer.train()
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/training/trainer.py", line 966, in train
return self._try_train()
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/training/trainer.py", line 1001, in _try_train
train_metrics = self._train_epoch(epoch)
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/training/trainer.py", line 716, in _train_epoch
batch_outputs = self.batch_outputs(batch, for_training=True)
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/allennlp/training/trainer.py", line 604, in batch_outputs
output_dict = self._pytorch_model(**batch)
File "/home/lcur0308/.conda/envs/atcs-project/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
result = self.forward(*input, **kwargs)
File "/home/lcur0308/atcs-project/multilingual-interference/udify/udify/models/udify_model.py", line 120, in forward
self._apply_token_dropout(tokens)
File "/home/lcur0308/atcs-project/multilingual-interference/udify/udify/models/udify_model.py", line 197, in _apply_token_dropout
tokens["tokens"] = self.token_dropout(
File "/home/lcur0308/atcs-project/multilingual-interference/udify/udify/models/udify_model.py", line 243, in token_dropout
device = tokens.device
AttributeError: 'dict' object has no attribute 'device'
我认为 self.token_dropout() 期望接收一个张量(tokens["tokens"] = tensor(...)),但实际收到的却是一个包含张量的字典(tokens["tokens"] = {"tokens": tensor(...)})。不过,我不知道该如何解决这个问题。当然,我可以用一个临时的变通办法,比如传入 tokens["tokens"]["tokens"] 而不是 tokens["tokens"]。但我预感这个错误是配置中某个更深层问题的副作用,这样的快速修补并不能真正解决问题。
为了与 1.3.0 兼容做了一些修改后,我当前的配置如下(抱歉贴了这么长一段,但我不确定哪一部分与问题相关):
{
"dataset_reader": {
"lazy": false,
"token_indexers": {
"tokens": {
"type": "single_id",
"lowercase_tokens": true
},
"bert": {
"type": "udify-bert-pretrained",
"pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt",
"do_lowercase": false,
"use_starting_offsets": true
}
}
},
"train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu",
"validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu",
"test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu",
"vocabulary": {
"type": "from_files",
"directory": "data/concat-exp-mix/vocab/concat-exp-mix/vocabulary/"
},
"model": {
"word_dropout": 0.2,
"mix_embedding": 12,
"layer_dropout": 0.1,
"tasks": ["deps"],
"pretrained_model": "bert-base-multilingual-cased",
"text_field_embedder": {
"type": "udify_embedder",
"dropout": 0.5,
"allow_unmatched_keys": true,
"embedder_to_indexer_map": {
"bert": ["bert", "bert-offsets"]
},
"token_embedders": {
"bert": {
"type": "udify-bert-pretrained",
"pretrained_model": "bert-base-multilingual-cased",
"requires_grad": true,
"dropout": 0.15,
"layer_dropout": 0.1,
"combine_layers": "all"
}
}
},
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"decoders": {
"upos": {
"encoder": {
"type": "pass_through",
"input_dim": 768
}
},
"feats": {
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"adaptive": true
},
"lemmas": {
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"adaptive": true
},
"deps": {
"tag_representation_dim": 256,
"arc_representation_dim": 768,
"encoder": {
"type": "pass_through",
"input_dim": 768
}
}
}
},
"data_loader": {
"batch_sampler":{
"batch_size": 16
}
},
"trainer": {
"num_epochs": 5,
"patience": 40,
"optimizer": {
"type": "adamw",
"betas": [0.9, 0.99],
"weight_decay": 0.01,
"lr": 1e-3,
"parameter_groups": [
[["^text_field_embedder.*.bert_model.embeddings",
"^text_field_embedder.*.bert_model.encoder"], {}],
[["^text_field_embedder.*._scalar_mix",
"^text_field_embedder.*.pooler",
"^scalar_mix",
"^decoders",
"^shared_encoder"], {}]
]
},
"learning_rate_scheduler": {
"type": "ulmfit_sqrt",
"model_size": 1,
"warmup_steps": 392,
"start_step": 392,
"factor": 5.0,
"gradual_unfreezing": true,
"discriminative_fine_tuning": true,
"decay_factor": 0.04
}
},
"udify_replace": [
"dataset_reader.token_indexers",
"model.text_field_embedder",
"model.encoder",
"model.decoders.xpos",
"model.decoders.deps.encoder",
"model.decoders.upos.encoder",
"model.decoders.feats.encoder",
"model.decoders.lemmas.encoder",
"trainer.learning_rate_scheduler",
"trainer.optimizer"
]
}
适用于 allennlp 0.9.0 的相应配置如下:
{
"dataset_reader": {
"lazy": false,
"token_indexers": {
"tokens": {
"type": "single_id",
"lowercase_tokens": true
},
"bert": {
"type": "udify-bert-pretrained",
"pretrained_model": "config/archive/bert-base-multilingual-cased/vocab.txt",
"do_lowercase": false,
"use_starting_offsets": true
}
}
},
"train_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-train.conllu",
"validation_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-dev.conllu",
"test_data_path": "data/ud-treebanks-v2.3/UD_English-EWT/en_ewt-ud-test.conllu",
"vocabulary": {
"directory_path": "data/vocab/english_only_expmix4/vocabulary"
},
"model": {
"word_dropout": 0.2,
"mix_embedding": 12,
"layer_dropout": 0.1,
"tasks": ["deps"],
"text_field_embedder": {
"type": "udify_embedder",
"dropout": 0.5,
"allow_unmatched_keys": true,
"embedder_to_indexer_map": {
"bert": ["bert", "bert-offsets"]
},
"token_embedders": {
"bert": {
"type": "udify-bert-pretrained",
"pretrained_model": "bert-base-multilingual-cased",
"requires_grad": true,
"dropout": 0.15,
"layer_dropout": 0.1,
"combine_layers": "all"
}
}
},
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"decoders": {
"upos": {
"encoder": {
"type": "pass_through",
"input_dim": 768
}
},
"feats": {
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"adaptive": true
},
"lemmas": {
"encoder": {
"type": "pass_through",
"input_dim": 768
},
"adaptive": true
},
"deps": {
"tag_representation_dim": 256,
"arc_representation_dim": 768,
"encoder": {
"type": "pass_through",
"input_dim": 768
}
}
}
},
"iterator": {
"batch_size": 16
},
"trainer": {
"num_epochs": 5,
"patience": 40,
"num_serialized_models_to_keep": 1,
"should_log_learning_rate": true,
"summary_interval": 100,
"optimizer": {
"type": "bert_adam",
"b1": 0.9,
"b2": 0.99,
"weight_decay": 0.01,
"lr": 1e-3,
"parameter_groups": [
[["^text_field_embedder.*.bert_model.embeddings",
"^text_field_embedder.*.bert_model.encoder"], {}],
[["^text_field_embedder.*._scalar_mix",
"^text_field_embedder.*.pooler",
"^scalar_mix",
"^decoders",
"^shared_encoder"], {}]
]
},
"learning_rate_scheduler": {
"type": "ulmfit_sqrt",
"model_size": 1,
"warmup_steps": 392,
"start_step": 392,
"factor": 5.0,
"gradual_unfreezing": true,
"discriminative_fine_tuning": true,
"decay_factor": 0.04
}
},
"udify_replace": [
"dataset_reader.token_indexers",
"model.text_field_embedder",
"model.encoder",
"model.decoders.xpos",
"model.decoders.deps.encoder",
"model.decoders.upos.encoder",
"model.decoders.feats.encoder",
"model.decoders.lemmas.encoder",
"trainer.learning_rate_scheduler",
"trainer.optimizer"
]
}
解决方案
推荐阅读
- javascript - JavaScript 检查字段是否有相同的值,如果是,继续计算总计
- r - read.csv 内容后看起来真的很混乱
- python - 如何将字符串的python列表传递给sql查询
- java - 在递归问题中传递列表的深拷贝和浅拷贝有什么区别?
- java - 按钮操作的条件
- java - 如何从自定义编辑器 hybris 更新/刷新编辑器区域
- networking - 如何从远程机器读取 VLC 中的 UDP 流
- javascript - 在真实 iPhone 上运行调试方案时 Metrobundler 未重新加载
- java - 如何修复我的 while 语句以便仅从输入中获取数字?
- sql - 查询具有不同条件的多个列