首页 > 解决方案 > torchtext 的 BucketIterator 可以将所有批次填充到相同的长度吗?

问题描述

我最近开始使用 torchtext 来替换我的胶水代码,我遇到了一个问题,我想在我的架构中使用注意力层。为了做到这一点,我需要知道我的训练数据的最大序列长度。

问题在于,torchtext.data.BucketIterator它会按批次进行填充:

# All 4 examples in the batch will be padded to maxlen in the batch
train_iter = torchtext.data.BucketIterator(dataset=train, batch_size=4)

是否有某种方法可以确保所有训练示例都填充到相同的长度;即训练中的maxlen?

标签: python-3.xpytorchpreprocessortorchtext

解决方案


实例化 atorchtext.data.Field时,有一个可选的关键字参数fix_length,当设置该参数时,它定义了所有样本将被填充的长度;默认情况下它没有设置,这意味着灵活的填充。


推荐阅读