python - Model() 为参数“nr_class”获取了多个值 - SpaCy 多分类模型(BERT 集成)
问题描述
嗨,我正在使用新的 SpaCy 模型实现多分类模型(5 类)en_pytt_bertbaseuncased_lg
。新管道的代码在这里:
nlp = spacy.load('en_pytt_bertbaseuncased_lg')
textcat = nlp.create_pipe(
'pytt_textcat',
config={
"nr_class":5,
"exclusive_classes": True,
}
)
nlp.add_pipe(textcat, last = True)
textcat.add_label("class1")
textcat.add_label("class2")
textcat.add_label("class3")
textcat.add_label("class4")
textcat.add_label("class5")
培训代码如下,基于此处的示例(https://pypi.org/project/spacy-pytorch-transformers/):
def extract_cat(x):
for key in x.keys():
if x[key]:
return key
# get names of other pipes to disable them during training
n_iter = 250 # number of epochs
train_data = list(zip(train_texts, [{"cats": cats} for cats in train_cats]))
dev_cats_single = [extract_cat(x) for x in dev_cats]
train_cats_single = [extract_cat(x) for x in train_cats]
cats = list(set(train_cats_single))
recall = {}
for c in cats:
if c is not None:
recall['dev_'+c] = []
recall['train_'+c] = []
optimizer = nlp.resume_training()
batch_sizes = compounding(1.0, round(len(train_texts)/2), 1.001)
for i in range(n_iter):
random.shuffle(train_data)
losses = {}
batches = minibatch(train_data, size=batch_sizes)
for batch in batches:
texts, annotations = zip(*batch)
nlp.update(texts, annotations, sgd=optimizer, drop=0.2, losses=losses)
print(i, losses)
所以我的数据结构如下所示:
[('TEXT TEXT TEXT',
{'cats': {'class1': False,
'class2': False,
'class3': False,
'class4': True,
'class5': False}}), ... ]
我不确定为什么会出现以下错误:
TypeError Traceback (most recent call last)
<ipython-input-32-1588a4eadc8d> in <module>
21
22
---> 23 optimizer = nlp.resume_training()
24 batch_sizes = compounding(1.0, round(len(train_texts)/2), 1.001)
25
TypeError: Model() got multiple values for argument 'nr_class'
编辑:
如果我取出 nr_class 参数,我会在这里得到这个错误:
ValueError: operands could not be broadcast together with shapes (1,2) (1,5)
我实际上认为这会发生,因为我没有指定 nr_class 参数。那是对的吗?
解决方案
这是我们发布的最新版本中的回归spacy-pytorch-transformers
。为此表示歉意!
根本原因是,这又是一个邪恶的例子**kwargs
。我期待着改进 spaCy API 以防止将来出现这些问题。
您可以在此处看到违规行:https ://github.com/explosion/spacy-pytorch-transformers/blob/c1def95e1df783c69bff9bc8b40b5461800e9231/spacy_pytorch_transformers/pipeline/textcat.py#L71 。我们提供nr_class
位置参数,它与您在配置期间传入的显式参数重叠。
为了解决这个问题,您可以简单地从您传入nr_class
的 dict 中删除密钥。config
spacy.create_pipe()
推荐阅读
- asp.net-core - 将 Azure AD 电子邮件记录到 ASP.NET 身份表中
- c++ - 如何在遍历向量时分配变量?
- c# - 尝试通过可视化代码托管静态站点,但构建后出现错误
- python - 从 Google Cloud Function python3.7 连接到 Google Cloud SQL
- javascript - iPhone Safari中iframe内的锚链接不起作用
- python - 将 Pandas 时间序列数据帧转换为 3D 数组
- parsing - Rust 编译器如何标记泛型中的“>”和“>>”?
- laravel - 如何从 IONIC 4 中的服务中接收价值?
- java - add(index, element) 方法如何使用 LinkedList 在幕后工作?
- angular - Angular HTTP 调用根据状态将值返回到两个属性中