首页 > 解决方案 > 使用 OpenNMT 进行迁移学习

问题描述

我正在使用 OpenNMT-py 在 MIDI 音乐文件上训练一个变压器模型,但结果很差,因为我只能访问一个与我想研究的风格有关的小数据集。为了帮助模型学习一些有用的东西,我想使用更大的其他音乐风格的数据集进行预训练,然后使用小数据集微调结果。

我想在预训练后冻结变压器的编码器端,让解码器部分自由地进行微调。如何使用 OpenNMT-py 做到这一点?

标签: pythonpytorchtransformertransfer-learningopennmt

解决方案


请更具体地说明您的问题并显示一些代码,这将帮助您从 SO 社区获得富有成效的回应。

如果我在你的位置并想冻结一个神经网络组件,我会简单地做:

for name, param in self.encoder.named_parameters():
    param.requires_grad = False

在这里,我假设您有一个如下所示的 NN 模块。

class Net(nn.Module):
    def __init__(self, params):
        super(Net, self).__init__()

        self.encoder = TransformerEncoder(num_layers,
                                        d_model, 
                                        heads, 
                                        d_ff, 
                                        dropout, 
                                        embeddings,
                                        max_relative_positions)

    def foward(self):
        # write your code

推荐阅读