python - RuntimeError: stack expects each tensor to be equal size(torch.stack 要求每个张量的大小相同)
问题描述
如果这个问题以前已经有人问过,我先说声抱歉。我确实没有看懂已有的解决方案。
# Token length passed as max_len to every data loader built below.
MAX_LEN = 160
# Number of samples per batch passed to every data loader built below.
BATCH_SIZE = 16
# Presumably the number of training epochs; not used in the visible code — confirm.
EPOCHS = 10
class GPReviewDataset(data.Dataset):
    """Map-style dataset that tokenizes one review per item.

    Args:
        review: Indexable collection of review texts.
        target: Indexable collection of integer labels aligned with ``review``.
        tokenizer: Object exposing ``encode_plus`` (presumably a Hugging Face
            tokenizer — confirm against the caller).
        max_len: Token length every encoded review is padded or truncated to.
    """

    def __init__(self, review, target, tokenizer, max_len):
        self.review = review
        self.target = target
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Return the number of reviews in the dataset."""
        return len(self.review)

    def __getitem__(self, item):
        """Return the encoded review and its label at index ``item``.

        Returns:
            A dict with the raw text under ``'review'``, the flattened
            ``'input_ids'`` and ``'attention_mask'`` tensors, and the label
            as a ``torch.long`` tensor under ``'targets'``.
        """
        review = str(self.review[item])
        # Use the tokenizer handed to this instance, not a module-level
        # global that may be missing or different.
        encoding = self.tokenizer.encode_plus(
            text=review,
            max_length=self.max_len,
            add_special_tokens=True,
            padding='max_length',
            # Padding alone only lengthens short sequences. Without
            # truncation a long review keeps its full length, and the default
            # collate function cannot stack tensors of different sizes.
            truncation=True,
            return_attention_mask=True,
            return_token_type_ids=False,
            return_tensors='pt',
        )
        return {
            'review': review,
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'targets': torch.tensor(self.target[item], dtype=torch.long),
        }
# Hold out 20% of free_df; the remaining 80% becomes the training set
# (assuming scikit-learn's train_test_split semantics — confirm the import).
free_df_train, free_df_test = train_test_split(free_df, test_size=0.2)
# Split the held-out 20% in half: one half for validation, one half for testing.
free_df_val, free_df_test = train_test_split(free_df_test, test_size=0.5)
def create_data_loader(df, tokenizer, max_len, batch_size):
    """Wrap a review DataFrame in a GPReviewDataset and return a DataLoader.

    Args:
        df: Frame providing a ``content`` column (review text) and a
            ``score`` column (label).
        tokenizer: Tokenizer forwarded to the dataset.
        max_len: Token length forwarded to the dataset.
        batch_size: Number of samples per batch.

    Returns:
        A ``data.DataLoader`` over the dataset, loading in the main process
        (``num_workers=0``).
    """
    texts = df.content.to_numpy()
    labels = df['score'].to_numpy()
    dataset = GPReviewDataset(
        review=texts,
        target=labels,
        tokenizer=tokenizer,
        max_len=max_len,
    )
    loader = data.DataLoader(dataset, batch_size=batch_size, num_workers=0)
    return loader
# Build one loader per split; all share the same tokenizer, MAX_LEN and BATCH_SIZE.
train_data_loader = create_data_loader(free_df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(free_df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(free_df_test, tokenizer, MAX_LEN, BATCH_SIZE)
# Pull a single batch from the training loader.
# NOTE(review): this rebinds the name `data`, which GPReviewDataset and
# create_data_loader use as a module (`data.Dataset`, `data.DataLoader`;
# presumably torch.utils.data — import not visible). Any later call to
# create_data_loader would then fail on `data.DataLoader`; consider a
# different variable name here.
data = next(iter(train_data_loader))
后来我写了一个以 train_data_loader 为参数的训练函数,但在遍历它时抛出了 RuntimeError。看起来正确的解决方案是使用某种 collate_fn,但我不清楚具体应该如何使用它。
我的错误如下:
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<timed exec> in <module>
<ipython-input-26-8ba1e19dd195> in train_epoch(model, data_loader, loss_fn, optimizer, device, scheduler, n_examples)
4 correct_predictions = 0
5
----> 6 for i in data_loader:
7 input_ids = i['input_ids'].to(device)
8 attention_mask = i['attention_mask'].to(device)
~\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in __next__(self)
361
362 def __next__(self):
--> 363 data = self._next_data()
364 self._num_yielded += 1
365 if self._dataset_kind == _DatasetKind.Iterable and \
~\Anaconda3\lib\site-packages\torch\utils\data\dataloader.py in _next_data(self)
401 def _next_data(self):
402 index = self._next_index() # may raise StopIteration
--> 403 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
404 if self._pin_memory:
405 data = _utils.pin_memory.pin_memory(data)
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\fetch.py in fetch(self, possibly_batched_index)
45 else:
46 data = self.dataset[possibly_batched_index]
---> 47 return self.collate_fn(data)
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
72 return batch
73 elif isinstance(elem, container_abcs.Mapping):
---> 74 return {key: default_collate([d[key] for d in batch]) for key in elem}
75 elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
76 return elem_type(*(default_collate(samples) for samples in zip(*batch)))
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in <dictcomp>(.0)
72 return batch
73 elif isinstance(elem, container_abcs.Mapping):
---> 74 return {key: default_collate([d[key] for d in batch]) for key in elem}
75 elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
76 return elem_type(*(default_collate(samples) for samples in zip(*batch)))
~\Anaconda3\lib\site-packages\torch\utils\data\_utils\collate.py in default_collate(batch)
53 storage = elem.storage()._new_shared(numel)
54 out = elem.new(storage)
---> 55 return torch.stack(batch, 0, out=out)
56 elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
57 and elem_type.__name__ != 'string_':
RuntimeError: stack expects each tensor to be equal size, but got [160] at entry 0 and [376] at entry 5
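解决方案
报错信息中的 [160] 与 [376] 说明:有的评论分词后的长度超过了 MAX_LEN(160)。padding='max_length' 只会把较短的序列补齐到 160,不会截断较长的序列,因此同一批次中的张量长度不一致,default_collate 调用 torch.stack 时就会失败。这里不需要自定义 collate_fn,只要在 encode_plus 调用中加上 truncation=True 即可(同时建议在 __getitem__ 中使用 self.tokenizer,而不是全局的 tokenizer)。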
推荐阅读
- python - 如何理解 tensorflow 错误信息?
- mysql - 使用 Lambda 将 RDS 转储到 S3
- css - Angular Material + Theme Generator => `rgba($color, $alpha)` 的参数 `$color` 必须是颜色
- java - 使用 spring-social-twitter 1.1.2 直接发送消息
- r - 在 R 的 lapply() 中使用 eval() 定义变量
- c# - 从 Asp.Net Core _Layout 视图访问数据库
- python - Python读取windows磁盘扇区问题
- airflow - 哪个版本的 Apache Airflow 包含实验性 API?
- c# - 查找列中包含数据的最后一个单元格
- python - 在 python 中使用 PIL 以相同的裁剪大小裁剪整个图像