python - UnparsedFlagAccessError:在解析标志之前尝试访问标志 --preserve_unused_tokens。BERT
问题描述
我想使用 Bert 语言模型来训练多类文本分类任务。以前我使用 LSTM 训练没有任何错误,但 Bert 给了我这个错误。我得到这个错误如下,我真的不知道如何解决它,有人可以帮助我吗?
不幸的是,在 keras 库中使用 Bert 的文档很少。
!wget --quiet https://raw.githubusercontent.com/tensorflow/models/master/official/nlp/bert/tokenization.py
import tensorflow_hub as hub
from bert import tokenization
module_url = 'https://tfhub.dev/tensorflow/bert_en_uncased_L-12_H-768_A-12/2'
bert_layer = hub.KerasLayer(module_url, trainable=True)
vocab_file = bert_layer.resolved_object.vocab_file.asset_path.numpy()
do_lower_case = bert_layer.resolved_object.do_lower_case.numpy()
tokenizer = tokenization.FullTokenizer(vocab_file, do_lower_case)
def bert_encode(texts, tokenizer, max_len=512):
all_tokens = []
all_masks = []
all_segments = []
for text in texts:
text = tokenizer.tokenize(text)
text = text[:max_len-2]
input_sequence = ["[CLS]"] + text + ["[SEP]"]
pad_len = max_len - len(input_sequence)
tokens = tokenizer.convert_tokens_to_ids(input_sequence) + [0] * pad_len
pad_masks = [1] * len(input_sequence) + [0] * pad_len
segment_ids = [0] * max_len
all_tokens.append(tokens)
all_masks.append(pad_masks)
all_segments.append(segment_ids)
return np.array(all_tokens), np.array(all_masks), np.array(all_segments)
def build_model(bert_layer, max_len=512):
input_word_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_word_ids")
input_mask = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="input_mask")
segment_ids = tf.keras.Input(shape=(max_len,), dtype=tf.int32, name="segment_ids")
pooled_output, sequence_output = bert_layer([input_word_ids, input_mask, segment_ids])
clf_output = sequence_output[:, 0, :]
net = tf.keras.layers.Dense(64, activation='softmax')(clf_output)
net = tf.keras.layers.Dropout(0.2)(net)
net = tf.keras.layers.Dense(32, activation='softmax')(net)
net = tf.keras.layers.Dropout(0.2)(net)
out = tf.keras.layers.Dense(3, activation='softmax')(net)
model = tf.keras.models.Model(inputs=[input_word_ids, input_mask, segment_ids], outputs=out)
model.compile(tf.keras.optimizers.Adam(lr=1e-5), loss='categorical_crossentropy', metrics=['accuracy'])
return model
max_len = 150
train_input = bert_encode(data.text_cleaned, tokenizer, max_len=max_len)
错误如下:
UnparsedFlagAccessError Traceback (most recent call last)
<ipython-input-175-fd64df42591d> in <module>()
1 import sys
2 max_len = 150
----> 3 train_input = bert_encode(o.text_cleaned, tokenizer, max_len=max_len)
4 frames
/usr/local/lib/python3.7/dist-packages/absl/flags/_flagvalues.py in __getattr__(self, name)
496 # get too much noise.
497 logging.error(error_message)
--> 498 raise _exceptions.UnparsedFlagAccessError(error_message)
499
500 def __setattr__(self, name, value):
UnparsedFlagAccessError: Trying to access flag --preserve_unused_tokens before flags were parsed.
解决方案
基于这个问题,您必须将 bert-tensorflow 降级到 1.0.1。检查此答案以找到解决方案。如果您按照本教程降级 bert-tensorflow 并按照建议使用,因为在 python 代码中!wget --quiet https://raw.githubusercontent.com/tensorflow/models/master/official/nlp/bert/tokenization.py
,作者已将. 之后代码编译成功。如果你想要别的,请联系我。tf.gfile.GFile(vocab_file, "r")
tf.io.gfile.Gfile(vocab_file, "r")
推荐阅读
- javascript - 如何让 JQuery json 函数在 javascript 之前运行
- html - 在 Codeignitor 上加载样式表
- angular - 无法从 Kibana/Elasticsearch 上的日志行或脚本字段获取定时值
- sql-server - SQL Server 2012 无法使用 ssms 远程连接到命名实例
- msf4j - msf4j 对 cors 的支持
- r - 使用 dplyr 计算多个分组变量
- android - 即时运行 Android Studio 的目标设备 API 级别(API 1)太低
- jsp - Tomcat需要JDK吗
- java - PySpark 中不存在方法 showString([class java.lang.Integer, class java.lang.Integer, class java.lang.Boolean])
- javascript - 选定的复选框项目移动到另一个列表并在 li 标签内添加输入字段?