nlp - 您如何生成 pytorch bert 预训练神经网络的 ONNX 表示?
问题描述
我正在尝试为pytorch-pretrained-bert run_classifier.py 示例生成一个 ONNX 文件。
在这种情况下,我根据主 README.md 使用以下参数运行它:
export GLUE_DIR=/tmp/glue_data
python run_classifier.py \
--task_name MRPC \
--do_train \
--do_eval \
--do_lower_case \
--data_dir $GLUE_DIR/MRPC/ \
--bert_model bert-base-uncased \
--max_seq_length 128 \
--train_batch_size 32 \
--learning_rate 2e-5 \
--num_train_epochs 3.0 \
--output_dir /tmp/mrpc_output/
在第 552 行修改/添加了以下代码:
# Save a trained model
model_to_save = model.module if hasattr(model, 'module') else model # Only save the model it-self
output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
if args.do_train:
torch.save(model_to_save.state_dict(), output_model_file)
# Save ONNX
msl = args.max_seq_length
dummy_input = torch.randn(1, msl, msl, msl, num_labels, device="cpu")
output_onnx_file = os.path.join(args.output_dir, "classifier.onnx")
torch.onnx.export(model, dummy_input, output_onnx_file)
dummy_input 应该对应于 bert 预训练模型输入。我认为 1 的 sample_batch_size 似乎适合我的需要。
一些人建议参数应该与模型 forward() 方法的参数匹配。在这种情况下:
class BertForSequenceClassification(PreTrainedBertModel):
def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):
在这种情况下,我认为排名是:
input_ids: 1 x 128 <the max_seq_length specified in the args>
token_type_ids: 1 x max_seq_length
attention_mask: 1 x max_seq_length
labels: 1 x 2 <the number of labels for MRPC>
所以有效的调用是:
dummy_input = torch.randn(1, 128, 128, 128, 2, device="cpu")
不幸的是,这会产生一个错误:
Exception has occurred: RuntimeError
The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 4. Target sizes: [1, 128, 128, 128, 2]. Tensor sizes: [1, 128]
这似乎是相当简单的事情。建议赞赏!
解决方案
推荐阅读
- python - 无法使用 python 包 web.py 提供本地静态 CSS 文件,仅 HTML 显示
- vb.net - 将数据源绑定到gridview后如何在Gridview中显示时隐藏列
- visual-studio-code - VS Code 在键入时不显示 Salesforce 对象或类符号。代码完成不起作用
- javascript - 从对象数组中删除按对象属性小于其他对象的元素
- php - PHP AdWords API v201809 - 如何使用 LocationCriterionService 从 GEO_PERFORMANCE_REPORT 获取城市名称?
- sql - 多连接和处理速度
- c# - C# 中 Nuget 包中的 AWS 凭证
- sql - 为 Spark ML 编码转置或旋转一组分类变量的最佳方法
- csv - 是否可以让 gnuplot 忽略具有“N/A”而不是数字的 csv 行?
- spring - 在另一个类中两次获取同一类的实例