python - 对 Huggingface Transformers 序列分类的 predict() 输出感到困惑
问题描述
下面的大部分代码都取自这个 huggingface 文档页面,用于 tensorflow 代码选择。让我感到困惑的是,在对几个新句子的预训练模型进行微调并predict
在两个测试集句子上运行之后,我得到predict()
的输出是 16x2 数组。
x2 是有道理的,因为我有两个类 (0,1),但是当我将 2 个(不是 16 个)序列的测试集传递给“SequenceClassification”模型时,为什么长度为 16?如何获得两个测试集序列的预测类?(ps我从logits转换为预测概率没有问题,只是对输出的形状感到困惑)。
下面的可重现代码示例。也可以随意在 google colab 环境中单步执行代码
from transformers import DistilBertTokenizerFast
from transformers import TFDistilBertForSequenceClassification
import tensorflow as tf
# set up arbitrary example data
train_txt = ['this sentence is about dinosaurs', 'this also mentions dinosaurs', 'this does not']
test_txt = ['the land before time was cool', 'alligators are basically dinosaurs']
train_labels = [1,1,0]
test_labels = [1,1]
# convert sentence lists to Distilbert Encodings and then TF Datasets
tokenizer = DistilBertTokenizerFast.from_pretrained('distilbert-base-uncased')
train_encodings = tokenizer([str(s) for s in train_txt], truncation=True, padding=True)
test_encodings = tokenizer([str(s) for s in test_txt], truncation=True, padding=True)
train_dataset = tf.data.Dataset.from_tensor_slices((
dict(train_encodings),
train_labels
))
test_dataset = tf.data.Dataset.from_tensor_slices((
dict(test_encodings),
test_labels
))
# Fine-tune pretrained Distilbert Classifier on our data
model = TFDistilBertForSequenceClassification.from_pretrained('distilbert-base-uncased')
optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
model.compile(optimizer=optimizer, loss=model.compute_loss) # can also use any keras loss fn
model.fit(train_dataset.shuffle(1000).batch(3), epochs=3, batch_size=3)
# Generate test-set predictions
test_preds = model.predict(test_dataset)
test_preds
输出:
>test_preds
TFSequenceClassifierOutput([('logits', array([[ 0.1527334 , 0.17010647],
[ 0.10007463, 0.15664947],
[-0.10294056, 0.18813357],
[-0.05231615, 0.1587314 ],
[-0.11520502, 0.16303074],
[ 0.00855697, 0.13974288],
[-0.17962483, 0.12381783],
[ 0.05765227, 0.04970012],
[ 0.1527334 , 0.17010647],
[-0.12754977, 0.11164709],
[-0.00847345, 0.12885672],
[-0.01731028, 0.13520113],
[-0.08433925, 0.16828224],
[-0.20086896, 0.08963215],
[ 0.05765227, 0.04970012],
[ 0.02467203, 0.15794128]], dtype=float32))])
解决方案
结果将是每个数组中最大值的索引
例子 :
**[ 0.1527334 , 0.17010647] --> 1(0.17010647的索引)
[ 0.05765227, 0.04970012] --> 0(0.05765227 的索引)**
因此,您可以使用下一行代码运行:
class_preds = np.argmax(test_preds["logits"])
推荐阅读
- java - 如何手动填充表格/网格
- javascript - 如何在 Select 标签中放置 2000 多个项目?
- html - 为 CSS 动画添加悬停效果
- c# - 如何在c#中反序列化来自http发布的json文件的数据
- javascript - 为什么 classList.contain() 总是返回 false?
- reactjs - 将类组件函数转换为函数组件函数
- r - 当统计测试使用二进制变量时,将 p 值输出保存到数据帧
- javascript - 为什么 document.addEventListener('DOMContentLoaded', Store.DisplayBooks) 不起作用
- python - 尝试使用来自 csv 文件的 4 个输入来模拟一个 FMU 时出错
- google-cloud-platform - 谷歌数据融合如何使用条件插件触发分支?