python-3.x - 无法使用自定义数据集:AttributeError: 'list' object has no attribute 'keys'
问题描述
我正在尝试使用 Huggingface Transformers 使用自定义数据集训练分类模型,但我不断收到错误。最后一个错误似乎可以解决,但我不明白如何解决。我究竟做错了什么?
我用
# Checkpoint identifier; both the tokenizer here and the classification model
# created later in this post are loaded from it.
model_name = "dbmdz/bert-base-italian-uncased"

# do_lower_case=True is forwarded to from_pretrained exactly as before.
tokenizer = AutoTokenizer.from_pretrained(
    model_name,
    do_lower_case=True,
)
def encode_data(texts):
    """Tokenize a batch of texts with the module-level ``tokenizer``.

    Args:
        texts: The texts to encode. Callers in this file pass a list built
            from a DataFrame ``text`` column.

    Returns:
        The object returned by ``tokenizer.batch_encode_plus`` for the
        options below. The sample printed later in this post shows the keys
        ``input_ids``, ``token_type_ids`` and ``attention_mask``, each
        holding a tensor.
    """
    # Collect the encoding options in one place so they are easy to audit.
    encode_options = {
        "add_special_tokens": True,
        "return_attention_mask": True,
        "padding": True,
        "truncation": True,
        "max_length": 200,
        "return_tensors": 'pt',
    }
    return tokenizer.batch_encode_plus(texts, **encode_options)
然后我创建我的数据集
import torch
class my_Dataset(torch.utils.data.Dataset):
    """Map-style dataset that pairs pre-computed encodings with labels.

    Args:
        encodings: Mapping-like object exposing ``items()`` whose values can
            be indexed by sample position (for example the output of
            ``encode_data``).
        labels: Per-sample labels; converted once with ``torch.tensor``.
    """

    def __init__(self, encodings, labels):
        self.encodings = encodings
        self.labels = torch.tensor(labels)

    def __getitem__(self, idx):
        """Return sample ``idx`` as a dict of features plus a ``labels`` entry."""
        item = {key: val[idx] for key, val in self.encodings.items()}
        item['labels'] = self.labels[idx]
        # The debugging print of every fetched sample was removed: it wrote
        # the full tensors to stdout on each access during training.
        return item

    def __len__(self):
        """Return the number of samples, taken from the label tensor."""
        return len(self.labels)
所以我有
# Each split is tokenized separately and then wrapped together with its labels.

# Training split.
encoded_data_train = encode_data(df_train['text'].tolist())
dataset_train = my_Dataset(encoded_data_train, df_train['labels'].tolist())

# Validation split.
encoded_data_val = encode_data(df_val['text'].tolist())
dataset_val = my_Dataset(encoded_data_val, df_val['labels'].tolist())

# Test split.
encoded_data_test = encode_data(df_test['text'].tolist())
dataset_test = my_Dataset(encoded_data_test, df_test['labels'].tolist())
然后我初始化 Trainer
from transformers import AutoConfig, TrainingArguments, DataCollatorWithPadding, Trainer
# Run configuration: output locations first, then schedule and optimisation.
training_args = TrainingArguments(
    output_dir='/trial',
    logging_dir="./logs",
    do_train=True,
    do_eval=True,
    evaluation_strategy='epoch',
    num_train_epochs=10,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    learning_rate=1e-6,
    warmup_steps=0,
    weight_decay=0.2,
)

# The number of classes comes from label_dict, which is defined outside the
# code shown in this post.
num_labels = len(label_dict)
model = AutoModelForSequenceClassification.from_pretrained(
    model_name,
    num_labels=num_labels,
)

# NOTE(review): the traceback in this post fails inside tokenizer.pad() when
# it calls .keys() on the collator input, so this collator presumably needs
# dict-like samples from the dataset -- confirm against the installed version.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset_train,
    eval_dataset=dataset_val,
    tokenizer=tokenizer,
    data_collator=DataCollatorWithPadding(tokenizer),
)
最后我训练
# Start training; this is the call at the top of the traceback shown below.
trainer.train()
这是我得到的错误
AttributeError                            Traceback (most recent call last)
<ipython-input-22-5d018b4b061d> in <module>
----> 1 trainer.train()
/opt/conda/lib/python3.8/site-packages/transformers/trainer.py in train(self, resume_from_checkpoint, trial, **kwargs)
1032 self.control = self.callback_handler.on_epoch_begin(self.args, self.state, self.control)
1033
-> 1034 for step, inputs in enumerate(epoch_iterator):
1035
1036 # Skip past any already trained steps if resuming training
/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py in __next__(self)
433 if self._sampler_iter is None:
434 self._reset()
--> 435 data = self._next_data()
436 self._num_yielded += 1
437 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.8/site-packages/torch/utils/data/dataloader.py in _next_data(self)
473 def _next_data(self):
474 index = self._next_index() # may raise StopIteration
--> 475 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
476 if self._pin_memory:
477 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.8/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
45 else:
46 data = self.dataset[possibly_batched_index]
---> 47 return self.collate_fn(data)
/opt/conda/lib/python3.8/site-packages/transformers/data/data_collator.py in __call__(self, features)
116
117 def __call__(self, features: List[Dict[str, Union[List[int], torch.Tensor]]]) -> Dict[str, torch.Tensor]:
--> 118 batch = self.tokenizer.pad(
119 features,
120 padding=self.padding,
/opt/conda/lib/python3.8/site-packages/transformers/tokenization_utils_base.py in pad(self, encoded_inputs, padding, max_length, pad_to_multiple_of, return_attention_mask, return_tensors, verbose)
2558 if self.model_input_names[0] not in encoded_inputs:
2559 raise ValueError(
-> 2560 "You should supply an encoding or a list of encodings to this method"
2561 f"that includes {self.model_input_names[0]}, but you provided {list(encoded_inputs.keys())}"
2562 )
AttributeError: 'list' object has no attribute 'keys'
我做错了什么?我也尝试过使用
import torch
from torch.utils.data import TensorDataset
# Alternative attempt: wrap the already-encoded tensors in TensorDataset.
# NOTE(review): TensorDataset returns each sample as a tuple of tensors, not
# a dict; that is presumably why tokenizer.pad() fails on `.keys()` in the
# traceback above -- confirm.
dataset_train = TensorDataset(
    encoded_data_train['input_ids'],
    encoded_data_train['attention_mask'],
    torch.tensor(df_train['labels'].tolist()),
)
dataset_test = TensorDataset(
    encoded_data_test['input_ids'],
    encoded_data_test['attention_mask'],
    torch.tensor(df_test['labels'].tolist()),
)
dataset_val = TensorDataset(
    encoded_data_val['input_ids'],
    encoded_data_val['attention_mask'],
    torch.tensor(df_val['labels'].tolist()),
)
得到同样的错误。我正在使用 torch == 1.7.1 和 transformers == 4.4.2
在第一条评论之后的编辑:下面是一个包含 5 个样本的 encoded_data_train 示例
{'input_ids': tensor([[ 102, 927, 9534, 30936, 2729, 29505, 123, 11805, 7427, 10587,
9703, 927, 9534, 30936, 2719, 10118, 2321, 784, 366, 113,
3627, 7763, 9433, 223, 148, 30937, 4051, 3400, 4011, 20005,
6079, 784, 366, 7809, 11967, 192, 3497, 784, 366, 7809,
11967, 192, 3497, 784, 366, 7809, 11967, 192, 3497, 784,
366, 7809, 11967, 192, 3497, 714, 927, 9534, 30936, 2729,
29505, 123, 11805, 7427, 260, 480, 1556, 152, 7113, 20734,
151, 143, 784, 366, 113, 3627, 7763, 19638, 159, 1233,
1674, 5442, 119, 9433, 223, 148, 30937, 135, 642, 829,
2250, 223, 743, 151, 143, 14572, 13799, 1767, 28915, 12057,
12342, 784, 366, 113, 9703, 927, 9534, 30936, 9480, 10125,
8418, 3726, 8379, 2955, 119, 1006, 30946, 8897, 123, 6423,
115, 1601, 544, 30938, 3013, 160, 30941, 137, 124, 14118,
30936, 193, 2701, 19214, 1457, 2701, 1864, 409, 19727, 13305,
6423, 115, 10389, 13908, 127, 4092, 14079, 1601, 2009, 24286,
23419, 103],
[ 102, 10587, 2130, 182, 8022, 2719, 10118, 132, 30976, 30943,
17961, 5123, 3292, 3627, 11532, 2719, 10118, 132, 30976, 30943,
17961, 5123, 3292, 3627, 11532, 2719, 10118, 201, 17961, 5123,
3292, 3627, 11532, 6354, 480, 1556, 28951, 17586, 113, 12699,
135, 480, 1556, 7347, 677, 135, 3110, 103, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 102, 2719, 10118, 6729, 6530, 10754, 11752, 10272, 11752, 119,
4200, 209, 30944, 19919, 2201, 5754, 642, 838, 15657, 6156,
30941, 148, 30937, 2201, 7305, 642, 6331, 3348, 30937, 170,
148, 30937, 2463, 103, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 102, 780, 30938, 18834, 2336, 2719, 10118, 8823, 784, 366,
113, 135, 1543, 2080, 1233, 20734, 316, 1767, 1542, 2771,
152, 25899, 119, 8823, 119, 4472, 784, 366, 113, 137,
1031, 510, 7763, 123, 21478, 3200, 111, 985, 119, 1670,
4999, 290, 30941, 119, 6951, 12042, 106, 1542, 135, 245,
30942, 26609, 199, 983, 119, 261, 28040, 8142, 148, 30937,
150, 143, 917, 1621, 7161, 111, 26609, 8217, 3723, 12510,
290, 30941, 119, 8886, 30934, 9798, 106, 204, 30942, 5807,
155, 1176, 213, 12057, 189, 387, 4953, 214, 2643, 4429,
123, 11224, 3096, 193, 143, 8823, 387, 2353, 2009, 193,
982, 176, 18789, 299, 8292, 553, 9798, 8886, 30934, 20853,
490, 4802, 19222, 642, 3829, 1455, 26321, 167, 148, 30937,
11498, 123, 103, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0],
[ 102, 10587, 491, 5462, 7664, 22790, 2719, 10118, 8498, 408,
24484, 112, 491, 5462, 7664, 22790, 3671, 135, 341, 1011,
299, 18239, 113, 143, 575, 8498, 265, 669, 113, 3850,
16465, 480, 283, 28951, 810, 21223, 103, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0]]), 'token_type_ids': tensor([[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]])}
以及对应的 dataset_train.__getitem__(0) 的结果
{'input_ids': tensor([ 102, 927, 9534, 30936, 2729, 29505, 123, 11805, 7427, 10587,
9703, 927, 9534, 30936, 2719, 10118, 2321, 784, 366, 113,
3627, 7763, 9433, 223, 148, 30937, 4051, 3400, 4011, 20005,
6079, 784, 366, 7809, 11967, 192, 3497, 784, 366, 7809,
11967, 192, 3497, 784, 366, 7809, 11967, 192, 3497, 784,
366, 7809, 11967, 192, 3497, 714, 927, 9534, 30936, 2729,
29505, 123, 11805, 7427, 260, 480, 1556, 152, 7113, 20734,
151, 143, 784, 366, 113, 3627, 7763, 19638, 159, 1233,
1674, 5442, 119, 9433, 223, 148, 30937, 135, 642, 829,
2250, 223, 743, 151, 143, 14572, 13799, 1767, 28915, 12057,
12342, 784, 366, 113, 9703, 927, 9534, 30936, 9480, 10125,
8418, 3726, 8379, 2955, 119, 1006, 30946, 8897, 123, 6423,
115, 1601, 544, 30938, 3013, 160, 30941, 137, 124, 14118,
30936, 193, 2701, 19214, 1457, 2701, 1864, 409, 19727, 13305,
6423, 115, 10389, 13908, 127, 4092, 14079, 1601, 2009, 24286,
23419, 103]), 'token_type_ids': tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
0, 0, 0, 0, 0, 0, 0, 0]), 'attention_mask': tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
1, 1, 1, 1, 1, 1, 1, 1]), 'labels': tensor(5)}
解决方案
推荐阅读
- bash - curl 命令在 shell 脚本中使用时无法读取变量会引发错误 curl: (26) 无法从文件/应用程序打开/读取本地数据
- javascript - 在 p5js 中使用 mousePressed() 函数
- python - SymPy 已安装但默认 Python 版本未找到
- javascript - 将数组和对象转换为自定义对象
- python - 如何在序列化为 pandas DataFrame 时展平嵌套的 dataclass?
- javascript - 在 TypeScript 中正确使用 async(Angular)
- amazon-ec2 - 在 InfluxDB v2 中接收 Gatling 结果
- github - 了解存储库 gpt 转换器
- javascript - Chart.js 属性 'type' 的类型不兼容。类型 'string' 不可分配给类型 '"line" | "bar" | "scatter"'
- bash - Intellij IDEA 无法打开本地终端 java.util.concurrent.ExecutionException: 无法在 linux 中启动