python-3.x - PyTorch:“TypeError:在 DataLoader 工作进程 0 中捕获 TypeError。”
问题描述
我正在尝试实现 RoBERTa 模型进行情绪分析。首先,我声明了 GPReviewDataset 来创建一个 PyTorch 数据集。
MAX_LEN = 160
class GPReviewDataset(Dataset):
def __init__(self, reviews, targets, tokenizer, max_len):
self.reviews = reviews
self.targets = targets
self.tokenizer = tokenizer
self.max_len = max_len
def __len__(self):
return len(self.reviews)
def __getitem__(self, item):
review = str(self.reviews[item])
target = self.targets[item]
encoding = self.tokenizer.encode_plus(
review,
add_special_tokens=True,
max_length=self.max_len,
return_token_type_ids=False,
pad_to_max_length=True,
return_attention_mask=True,
return_tensors='pt',
)
return {
'review_text': review,
'input_ids': encoding['input_ids'].flatten(),
'attention_mask': encoding['attention_mask'].flatten(),
'targets': torch.tensor(target, dtype=torch.long)
}
接下来,我实现create_data_loader
创建几个数据加载器。这是一个辅助函数:
def create_data_loader(df, tokenizer, max_len, batch_size):
ds = GPReviewDataset(
reviews=df.text.to_numpy(),
targets=df.sentiment.to_numpy(),
tokenizer=tokenizer,
max_len=max_len
)
return DataLoader(
ds,
batch_size=batch_size,
num_workers=4
)
BATCH_SIZE = 16
train_data_loader = create_data_loader(df_train, tokenizer, MAX_LEN, BATCH_SIZE)
val_data_loader = create_data_loader(df_val, tokenizer, MAX_LEN, BATCH_SIZE)
test_data_loader = create_data_loader(df_test, tokenizer, MAX_LEN, BATCH_SIZE)
dt = next(iter(train_data_loader))
但是,当我运行此代码时,它会停止并给出以下错误:
TypeError Traceback (most recent call last)
<ipython-input-35-a673c0794f60> in <module>()
----> 1 dt = next(iter(train_data_loader))
3 frames
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _next_data(self)
1083 else:
1084 del self._task_info[idx]
-> 1085 return self._process_data(data)
1086
1087 def _try_put_index(self):
/usr/local/lib/python3.6/dist-packages/torch/utils/data/dataloader.py in _process_data(self, data)
1109 self._try_put_index()
1110 if isinstance(data, ExceptionWrapper):
-> 1111 data.reraise()
1112 return data
1113
/usr/local/lib/python3.6/dist-packages/torch/_utils.py in reraise(self)
426 # have message field
427 raise self.exc_type(message=msg)
--> 428 raise self.exc_type(msg)
429
430
TypeError: Caught TypeError in DataLoader worker process 0.
Original Traceback (most recent call last):
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/worker.py", line 198, in _worker_loop
data = fetcher.fetch(index)
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in fetch
data = [self.dataset[idx] for idx in possibly_batched_index]
File "/usr/local/lib/python3.6/dist-packages/torch/utils/data/_utils/fetch.py", line 44, in <listcomp>
data = [self.dataset[idx] for idx in possibly_batched_index]
File "<ipython-input-18-1e537ce5a428>", line 25, in __getitem__
'targets': torch.tensor(target, dtype=torch.long)
TypeError: new(): invalid data type 'str'
我不明白为什么会这样,谁能帮我解释一下。
解决方案
您需要将您的类定义为整数。我假设您正在处理分类问题。但看起来您已将类定义为字符串。您需要将您的类从字符串转换为整数。例如,如果 df.sentiment 对应于正数,则必须用 0 表示,或者 df.sentiment 对应于负数,则需要在新列中用 1 表示。
def to_int_sentiment(label):
if label == "positive":
return 0
elif label == "negative":
return 1
df['int_sentiment'] = df.sentiment.apply(to_int_sentiment)
然后你应该使用列 df.int_sentiment 而不是 df.sentiment。所以你必须改变 create_data_loader 函数如下。
def create_data_loader(df, tokenizer, max_len, batch_size):
ds = GPReviewDataset(
reviews=df.text.to_numpy(),
targets=df.int_sentiment.to_numpy(),
tokenizer=tokenizer,
max_len=max_len
)
return DataLoader(
ds,
batch_size=batch_size,
num_workers=4
)
推荐阅读
- heroku - 如何修复 Multer-Gridfs-Storage 错误“创建存储引擎时出错。必须提供至少一个 url 或 db 选项”?
- python - 使用 Spyder 4.1.4 和 python 3.8 进行正确的验证循环
- python - Python:使用 xlrd 库从 excel 电子表格中读取数据给了我不正确的行数
- json - 使用 Scala 在 Spark 中将 JSON 对象更改为 JSON 列表
- javascript - 按下 esc 键时有没有办法滚动到顶部?
- javascript - 如何更改 React Native 中特定列的样式?
- flutter - 如何使用音量或电源按钮双击在后台打开颤振应用程序?
- maven - 无法解决来自 POM 中本地仓库的依赖关系
- rust - 如何读取前 N 行,然后读取字节?
- android-recyclerview - DiffUtil areContentsTheSame() 在 RecycleViewAdapter 中的 List 内容更新后总是返回 true