python - Python:BERT 的文本分类停留在 0%
问题描述
作为文本分类的新手,我正在使用名为的kaggle 数据imdb_folds.csv
,然后在显示此消息时我的代码运行卡住了:
0%| | 0/1250 [00:00<?, ?it/s]
我的代码是:
## Import Packages
import tez
import torch
import torch.nn as nn
import transformers
from transformers import AdamW, get_linear_schedule_with_warmup
from sklearn import metrics
import pandas as pd
## Create a data loader using a class named BERTDataset
class BERTDataset:
def __init__(self, texts, targets, max_len = 64):
self.texts = texts
self.targets = targets
self.tokenizer = transformers.BertTokenizer.from_pretrained(
"bert-base-uncased", #model name
do_lower_case = False
)
self.max_len = max_len
# length function
def __len__(self):
return len(self.texts)
def __getitem__(self, idx):
text = str(self.texts[idx])
inputs = self.tokenizer.encode_plus(
texts,
None,
add_special_tokens = True,
max_length = self.max_len,
padding = "max_length",
truncation = True
)
resp = {
"ids": torch.tensor(inputs["input_ids"], dtype = torch.long),
"mask": torch.tensor(inputs["attention_mask"], dtype = torch.long),
"token_type_ids": torch.tensor(inputs["token_type_ids"], dtype = torch.long),
"targets": torch.tensor(self.targets[idx], dtype = torch.float),
## for multiclass classification, convert change dtype from torch.float to torch.long
#"targets": torch.tensor(self.targets[idx], dtype = torch.long),
}
return resp
## Build the model
class TextModel(tez.Model):
def __init__(self, num_classes, num_train_steps):
super().__init__()
self.bert = transformers.BertModel.from_pretrained(
"bert-base-uncased", return_dict = False
)
self.bert_drop = nn.Dropout(0.3)
self.out = nn.Linear(768, num_classes) # num_classes is 1 or > 1 if it is binary multiclass classification respectively
self.num_train_steps = num_train_steps
# optimizer
def fetch_optimizer(self):
opt = AdamW(self.parameters(), lr = 1e-4)
return opt
# scheduler
def fetch_scheduler(self):
sch = get_linear_schedule_with_warmup(
self.optimizer, num_warmup_steps = 0, num_training_steps = self.num_train_steps
)
return sch
# loss
def loss(self, outputs, targets):
return nn.BCEWithLogitsLoss()(outputs, targets.view(-1, 1))
## include the next line if you have multiclass classification
# return nn.CrossEntropyLoss()(outputs, targets)
# calculate accuracy
def monitor_metrics(self, outputs, targets):
outputs = torch.sigmoid(outputs).cpu().detach().numpy() >= 0.5
targets = targets.cpu().detach().numpy()
return {"accuracy": metrics.accuracy_score(targets, outputs)}
# forward function
def forward(self, ids, mask, token_type_ids, targets = None):
_, x = self.bert(ids, attention_mask = mask, token_type_ids = token_type_ids)
x = self.bert_drop(x)
x = self.out(x)
if targets is not None:
loss = self.loss(outputs, targets)
met = self.monitor_metrics(outputs, targets)
return x, loss, met
return x, 0, {} # if there is no target, return 0
## Read Dataset
def train_model(fold):
df = pd.read_csv("imdb_folds.csv") # read file
df_train = df[df.kfold != fold].reset_index(drop=True)
df_valid = df[df.kfold == fold].reset_index(drop=True)
train_dataset = BERTDataset(df_train.review.values, df_train.sentiment.values)
valid_dataset = BERTDataset(df_valid.review.values, df_valid.sentiment.values)
# n_train_steps = int(len(df_train) / TRAIN_BS * EPOCHS)
n_train_steps = int(len(df_train) / 32 * 10)
model = TextModel(num_classes = 1, num_train_steps = n_train_steps)
es = tez.callbacks.EarlyStopping(monitor = "valid_loss", patience = 3, model_path = "model.bin")
model.fit(
train_dataset,
valid_dataset = valid_dataset,
device = "cuda",
epochs = 10,
train_bs = 32,
callbacks = [es],
)
if __name__=="__main__":
train_model(fold = 0)
我在stackoverflow上发现了类似的问题,但无法调整解决方案。任何解决此问题的帮助将不胜感激。
解决方案
推荐阅读
- python - Why does this print nothing in Python 3?
- c++ - Using libcurl in g++
- python - Looking for Words in a List with Similar Letters
- google-play - Create New Release is greyed out in Google Play Console
- ios - Local files not opening on iOS simulator after restart
- flutter - type 'Future
' is not a subtype of type 'Widget' - python - "TypeError: 'type' object is not subscriptable" in a function signature
- c - Copying specific number of characters from a string to another
- python - Random seed not performing as expected
- c++ - 具有不同值的 Rcpp 函数填充矩阵