python - 使用bidaf模型提取答案的置信度得分
问题描述
我正在做一个深度学习项目,我使用双向注意力流模型(allennlp 预训练模型)来制作问答系统。它使用小队数据集。bidaf 模型从段落中提取答案范围。有什么方法可以确定模型提取的答案的置信度得分(准确性)或任何其他指标?我使用了 allennlp 包中的子命令评估,但它仅在测试后确定模型的分数。我希望有一种更简单的方法可以使用其他此类命令来解决问题。在下面附上代码和终端输出。
from rake_nltk import Rake
from string import punctuation
from nltk.corpus import stopwords
from allennlp.predictors.predictor import Predictor
import spacy
import wikipedia
import re
import requests
from requests_html import HTMLSession
from bs4 import BeautifulSoup
import traceback
from nltk.stem import SnowballStemmer
from nltk.util import ngrams
from math import log10
from flask import Flask, request, jsonify, render_template
from gevent.pywsgi import WSGIServer
import time
import multiprocessing as mp
from gtts import gTTS
import os
NLP = spacy.load('en_core_web_md')
stop = stopwords.words('english')
symbol = r"""!#$%^&*();:\n\t\\\"!\{\}\[\]<>-\?"""
stemmer = SnowballStemmer('english')
wikipedia.set_rate_limiting(True)
session = HTMLSession()
results = 5
try:
predictor = Predictor.from_path("bidaf-model-2017.09.15-charpad.tar.gz")
except:
predictor = Predictor.from_path("https://storage.googleapis.com/allennlp-public-models/bidaf-elmo-model-2018.11.30-charpad.tar.gz")
try:
srl = Predictor.from_path('srl-model-2018.05.25.tar.gz')
except:
srl = Predictor.from_path('https://s3-us-west-2.amazonaws.com/allennlp/models/bert-base-srl-2019.06.17.tar.gz')
key = Rake(min_length=1, stopwords=stop, punctuations=punctuation, max_length=6)
wh_words = "who|what|how|where|when|why|which|whom|whose|explain".split('|')
stop.extend(wh_words)
session = HTMLSession()
output = mp.Queue()
def termFrequency(term, doc):
normalizeTermFreq = re.sub('[\[\]\{\}\(\)]', '', doc.lower()).split()
normalizeTermFreq = [stemmer.stem(i) for i in normalizeTermFreq]
dl = len(normalizeTermFreq)
normalizeTermFreq = ' '.join(normalizeTermFreq)
term_in_document = normalizeTermFreq.count(term)
#len_of_document = len(normalizeTermFreq )
#normalized_tf = term_in_document / len_of_document
normalized_tf = term_in_document
return normalized_tf, normalizeTermFreq, dl#, n_unique_term
def inverseDocumentFrequency(term, allDocs):
num_docs_with_given_term = 0
for doc in allDocs:
if term in doc:
num_docs_with_given_term += 1
if num_docs_with_given_term > 0:
total_num_docs = len(allDocs)
idf_val = log10(((total_num_docs+1) / num_docs_with_given_term))
term_split = term.split()
if len(term_split) == 3:
if len([term_split[i] for i in [0, 2] if term_split[i] not in stop]) == 2:
return idf_val*1.5
return idf_val
return idf_val
else:
return 0
def sent_formation(question, answer):
tags_doc = NLP(question)
tags_doc_cased = NLP(question.title())
tags_dict_cased = {i.lower_:i.pos_ for i in tags_doc_cased}
tags_dict = {i.lower_:i.pos_ for i in tags_doc}
question_cased = []
for i in question[:-1].split():
if tags_dict[i] == 'PROPN' or tags_dict[i] == 'NOUN':
question_cased.append(i.title())
else:
question_cased.append(i.lower())
question_cased.append('?')
question_cased = ' '.join(question_cased)
#del tags_dict,tags_doc, tags_doc_cased
pre = srl.predict(question_cased)
verbs = []
arg1 = []
for i in pre['verbs']:
verbs.append(i['verb'])
if 'B-ARG1' in i['tags']:
arg1.append((i['tags'].index('B-ARG1'), i['tags'].count('I-ARG1'))\
if not pre['words'][i['tags'].index('B-ARG1')].lower() in wh_words else \
(i['tags'].index('B-ARG2'), i['tags'].count('I-ARG2')))
arg1 = arg1[0] if arg1 else []
if not arg1:
verb_idx = pre['verbs'][0]['tags'].index('B-V')
verb = pre['words'][verb_idx] if pre['words'][verb_idx] != answer.split()[0].lower() else ''
subj_uncased = pre['words'][verb_idx+1:] if pre['words'][-1] not in symbol else \
pre['words'][verb_idx+1:-1]
else:
verb = ' '.join(verbs)
subj_uncased = pre['words'][arg1[0]:arg1[0]+arg1[1]+1]
conj = ''
if question.split()[0].lower() == 'when':
conj = ' on' if len(answer.split()) > 1 else ' in'
subj = []
for n, i in enumerate(subj_uncased):
if tags_dict_cased[i.lower()] == 'PROPN' and tags_dict[i.lower()] != 'VERB' or n == 0:
subj.append(i.title())
else:
subj.append(i.lower())
subj[0] = subj[0].title()
print(subj)
print(pre)
subj = ' '.join(subj)
sent = "{} {}{} {}.".format(subj, verb, conj, answer if answer[-1] != '.' else answer[:-1])
return sent
class extractAnswer:
def __init__(self):
self.wiki_error = (wikipedia.exceptions.DisambiguationError,
wikipedia.exceptions.HTTPTimeoutError,
wikipedia.exceptions.WikipediaException)
self.article_title = None
# symbol = """!#$%^&*();:\n\t\\\"!\{\}\[\]<>-\?"""
def extractAnswer_model(self, passage, question, s=0.4, e=0.3, wiki=False):
if type(passage) == list:
passage = " ".join(passage)
if not question[-1] == '?':
question = question+'?'
pre = predictor.predict(passage=passage, question=question)
if wiki:
if max(pre['span_end_probs']) > 0.5:
s = 0.12
elif max(pre['span_end_probs']) > 0.4:
s = 0.13
elif max(pre['span_end_probs']) > 0.35:
s = 0.14
if max(pre['span_start_probs']) > 0.5:
e = 0.12
elif max(pre['span_start_probs']) > 0.4:
e = 0.14
elif max(pre['span_start_probs']) > 0.3:
e = 0.15
if max(pre['span_start_probs']) > s and max(pre['span_end_probs']) > e:
key.extract_keywords_from_text(question)
ques_key = [stemmer.stem(i) for i in ' '.join(key.get_ranked_phrases())]
key.extract_keywords_from_text(passage)
pass_key = [stemmer.stem(i) for i in ' '.join(key.get_ranked_phrases())]
l = len(ques_key)
c = 0
for i in ques_key:
if i in pass_key:
c += 1
if c >= l/2:
print(max(pre['span_start_probs']),
max(pre['span_end_probs']))
if wiki:
return pre['best_span_str'], max(pre['span_start_probs']) + max(pre['span_end_probs'])
try:
ans = sent_formation(question, pre['best_span_str'])
except:
ans = pre['best_span_str']
print(traceback.format_exc())
return ans
print(ques_key, c, l)
print(max(pre['span_start_probs']), max(pre['span_end_probs']))
return 0, 0
else:
print(max(pre['span_start_probs']), max(pre['span_end_probs']), pre['best_span_str'])
return 0, 0
def wiki_search_api(self, query):
article_list = []
try:
article_list.extend(wikipedia.search(query, results=results))
print(article_list)
return article_list
except self.wiki_error:
params = {'search': query, 'profile': 'engine_autoselect',
'format': 'json', 'limit': results}
article_list.extend(requests.get('https://en.wikipedia.org/w/api.php?action=opensearch',
params=params).json()[1])
return article_list
except:
print('Wikipedia search error!')
print(traceback.format_exc())
return 0
def wiki_passage_api(self, article_title, article_list, output):
# Disambiguation_title = {}
try:
passage = wikipedia.summary(article_title)
output.put((article_title, self.passage_pre(passage)))
except wikipedia.exceptions.DisambiguationError as e:
print(e.options[0], e.options)
Disambiguation_pass = {}
for p in range(2 if len(e.options) > 1 else len(e.options)):
params = {'search':e.options[p], 'profile':'engine_autoselect', 'format':'json'}
article_url = requests.get('https://en.wikipedia.org/w/api.php?action=opensearch',
params=params).json()
if not article_url[3]:
continue
article_url = article_url[3][0]
r = session.get(article_url)
soup = BeautifulSoup(r.html.raw_html)
print(soup.title.string)
article_title_dis = soup.title.string.rsplit('-')[0].strip()
if article_title_dis in article_list:
print('continue')
continue
try:
url = "https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&exintro&explaintext&redirects=1&titles={}".format(article_title_dis)
passage = requests.get(url).json()['query']['pages']
for i in passage.keys():
if 'extract' in passage[i]:
Disambiguation_pass[article_title_dis] = self.passage_pre(passage[i]['extract'])
except wikipedia.exceptions.HTTPTimeoutError:
passage = wikipedia.summary(article_title_dis)
Disambiguation_pass[article_title_dis] = self.passage_pre(passage)
except:
Disambiguation_pass[article_title_dis] = ''
continue
output.put((article_title, Disambiguation_pass))
except:
output.put((article_title, ''))
print(traceback.format_exc())
def sorting(self, article, question, topic):
processes = [mp.Process(target=self.wiki_passage_api, args=(article[x], article, output))\
for x in range(len(article))]
for p in processes:
p.start()
for p in processes:
p.join(timeout=3)
results_p = [output.get() for p in processes]
article_list = []
passage_list = []
for i, j in results_p:
if type(j) != dict and j:
article_list.append(i)
passage_list.append(j)
elif type(j) == dict and j:
for k, l in j.items():
if l:
article_list.append(k)
passage_list.append(l)
normalize_passage_list = []
start = time.time()
keywords = " ".join(self.noun+self.ques_key+[topic.lower()])
keywords = re.sub('[{0}]'.format(symbol), ' ', keywords).split()
question = question+' '+topic
ques_tokens = [stemmer.stem(i.lower()) for i in question.split() \
if i.lower() not in wh_words]
print(ques_tokens)
keywords_bigram = [' '.join(i) for i in list(ngrams(ques_tokens, 2)) \
if i[0] not in stop and i[1] not in stop]
if len(ques_tokens) > 3:
keywords_trigram = [' '.join(i) for i in list(ngrams(ques_tokens, 3)) \
if (i[0] in stop) + (i[2] in stop) + (i[1] in stop) < 3]
else:
keywords_trigram = []
if len(ques_tokens) > 5:
keywords_4gram = [' '.join(i) for i in list(ngrams(ques_tokens, 4)) \
if (i[0] in stop) + (i[2] in stop) +(i[1] in stop)+(i[3] in stop) < 4]
else:
keywords_4gram = []
keywords_unigram = list(set([stemmer.stem(i.lower()) for i in keywords \
if i.lower() not in stop]))
keywords = keywords_unigram+list(set(keywords_bigram))+keywords_trigram+keywords_4gram
tf = []
if not passage_list:
return 0
pass_len = []
#n_u_t=[]
#key_dict = {i: keywords.count(i) for i in keywords}
print('Extraction complete')
#remove_pass={}
#for n,i in enumerate(passage_list):
#if len(i)<200 or not i:
#remove_pass[article_list[n]]=i
#print(n, article_list[n])
#passage_list=[i for i in passage_list if i not in remove_pass.values()]
#article_list=[i for i in article_list if i not in remove_pass.keys()]
passage_list_copy = passage_list.copy()
article_list_copy = article_list.copy()
for i in range(len(passage_list_copy)):
if passage_list.count(passage_list_copy[i]) > 1:
passage_list.remove(passage_list_copy[i])
article_list.remove(article_list_copy[i])
print('Copy:', article_list_copy[i])
del passage_list_copy
del article_list_copy
for n, i in enumerate(passage_list):
temp_tf = {}
c = 0
for j in keywords:
temp_tf[j], temp_pass, temp_len = termFrequency(j, i + ' ' + article_list[n])
if temp_tf[j]:
c += 1
normalize_passage_list.append(temp_pass)
pass_len.append(temp_len)
temp_tf['key_match'] = c
tf.append(temp_tf)
print(pass_len)
print(keywords)
idf = {}
for i in keywords:
idf[i] = inverseDocumentFrequency(i, normalize_passage_list)
#print(tf, idf)
tfidf = []
#b=0.333 #for PLN
b, k = 0.75, 1.2 #for BM25
avg_pass_len = sum(pass_len)/len(pass_len)
#pivot=sum(n_u_t)/len(n_u_t)
for n, i in enumerate(tf):
tf_idf = 0
#avg_tf=sum(i.values())/len(i)
key_match_ratio = i['key_match']/len(keywords)
for j in keywords:
#tf_idf+=idf[j]*((log(1+log(1+i[j])))/(1-b+(b*pass_len[n]/avg_pass_len))) #PLN
tf_idf += idf[j]*(((k+1)*i[j])/(i[j]+k*(1-b+(b*pass_len[n]/avg_pass_len)))) #BM25
tfidf.append(tf_idf*key_match_ratio)
tfidf = [i/sum(tfidf)*100 for i in tfidf if any(tfidf)]
if not tfidf:
return 0, 0, 0, 0, 0
print(tfidf)
print(article_list, len(passage_list))
if len(passage_list) > 1:
sorted_tfidf = sorted(tfidf, reverse=1)
idx1 = tfidf.index(sorted_tfidf[0])
passage1 = passage_list[idx1]
#article_title=
tfidf1 = sorted_tfidf[0]
idx2 = tfidf.index(sorted_tfidf[1])
passage2 = passage_list[idx2]
article_title = (article_list[idx1], article_list[idx2])
tfidf2 = sorted_tfidf[1]
else:
article_title = 0
tfidf2 = 0
if passage_list:
passage1 = passage_list[0]
tfidf1 = tfidf[0]
passage2 = 0
else:
passage1 = 0
passage2 = 0
tfidf1, tfidf2 = 0, 0
end = time.time()
print('TFIDF time:', end-start)
return passage1, passage2, article_title, tfidf1, tfidf2
def passage_pre(self, passage):
#passage=re.findall("[\da-zA-z\.\,\'\-\/\–\(\)]*", passage)
passage = re.sub('\n', ' ', passage)
passage = re.sub('\[[^\]]+\]', '', passage)
passage = re.sub('pronunciation', '', passage)
passage = re.sub('\\\\.+\\\\', '', passage)
passage = re.sub('{.+}', '', passage)
passage = re.sub(' +', ' ', passage)
return passage
def wiki(self, question, topic=''):
if not question:
return 0
question = re.sub(' +', ' ', question)
question = question.title()
key.extract_keywords_from_text(question)
self.ques_key = key.get_ranked_phrases()
doc = NLP(question)
self.noun = [str(i).lower() for i in doc.noun_chunks if str(i).lower() not in wh_words]
print(self.ques_key, self.noun)
question = re.sub('[{0}]'.format(symbol), ' ', question)
if not self.noun + self.ques_key:
return 0
article_list = None
question = question.lower()
if self.noun:
if len(self.noun) == 2 and len(" ".join(self.noun).split()) < 6:
#question1=question
self.noun = " ".join(self.noun).split()
if self.noun[0] in stop:
self.noun.pop(0)
self.noun = question[question.index(self.noun[0]):question.index(self.noun[-1]) \
+len(self.noun[-1])+1].split()
#del question1
print(self.noun)
article_list = self.wiki_search_api(' '.join(self.noun))
if self.ques_key and not article_list:
article_list = self.wiki_search_api(self.ques_key[0])
if not article_list:
article_list = self.wiki_search_api(' '.join(self.ques_key))
if not article_list:
print('Article not found on wikipedia.')
return 0, 0
article_list = list(set(article_list))
passage1, passage2, article_title, tfidf1, tfidf2 = self.sorting(article_list,
question, topic)
if passage1:
ans1, conf1 = self.extractAnswer_model(passage1, question, s=0.20, e=0.20, wiki=True)
else:
ans1, conf1 = 0, 0
if ans1:
conf2 = 0
if len(ans1) > 600:
print(ans1)
print('Repeat')
ans1, conf1 = self.extractAnswer_model(ans1, question, s=0.20, e=0.20, wiki=True)
threshhold = 0.3 if not ((tfidf1- tfidf2) <= 10) else 0.2
if round(tfidf1- tfidf2) < 5:
threshhold = 0
if (tfidf1- tfidf2) > 20:
threshhold = 0.35
if (tfidf1- tfidf2) > 50:
threshhold = 1
if (passage2 and conf1 < 1.5) or (tfidf1 - tfidf2) < 10:
ans2, conf2 = self.extractAnswer_model(passage2, question, s=0.20, e=0.20,
wiki=True) if passage2 else (0, 0)
title = 0
if round(conf1, 2) > round(conf2, 2) - threshhold:
print('ans1')
ans = ans1
title = article_title[0] if article_title else 0
else:
print('ans2')
title = article_title[1] if article_title else 0
ans = ans2
if not question[-1] == '?':
question = question+'?'
try:
ans = sent_formation(question, ans)
except:
print(traceback.format_exc())
print(ans, '\n', '\n', article_title)
return ans, title
extractor = extractAnswer()
app = Flask(__name__)
@app.route("/", methods=["POST", "get"])
@app.route("/ans")
def ans():
start = time.time()
question = request.args.get('question')
topic = request.args.get('topic')
passage = request.args.get('passage')
if not question:
return render_template('p.html')
if not topic:
topic = ''
if passage:
answer = extractor.extractAnswer_model(passage, question)
else:
answer, title = extractor.wiki(question, topic)
end = time.time()
if answer:
mytext = str(answer)
language = 'en'
myobj = gTTS(text=mytext, lang=language, slow=False)
myobj.save("welcome.mp3")
# prevName = 'welcome.mp3'
#newName = 'static/welcome.mp3'
#os.rename(prevName,newName)
return render_template('pro.html', answer=answer)
else:
return jsonify(Status='E', Answer=answer, Time=end-start)
@app.route("/audio_del/", methods=["POST", "get"])
def audio_del():
return render_template('p.html');
@app.route("/audio_play/", methods=["POST", "get"])
def audio_play():
os.system("mpg321 welcome.mp3")
return render_template('white.html')
if __name__ == "__main__":
PORT = 7091
HTTP_SERVER = WSGIServer(('0.0.0.0', PORT), app)
print('Running on',PORT, '...')
HTTP_SERVER.serve_forever()
![Output in the terminal for a question I've asked](https://i.stack.imgur.com/6pyv5.jpg)
解决方案
推荐阅读
- javascript - 错误:RelayParser:遇到一个或多个文档的重复定义
- tensorflow - why MSE on test set is very low and doesn't seem to evolve (not increasing after increasing epochs)
- asp.net-core - AspNet WebApi + Angular 7 项目架构
- python - 使用 Shopify assets API 修改 theme.liquid
- junit5 - TempDirectory 扩展:是否可以替代 createInCustomDirectory() 方法?
- android - 捕获电子邮件验证单击并打开 android 应用程序
- android - 如何在 MPAndroidChart 中显示 StackedBar 的标签和值?
- javascript - 为什么 display === "none" 没有按预期给出结果
- lisp - 为什么 LET 在我的 quine 中不能很好地与 EVAL 配合使用?
- android - 为什么我的 SharedPreferences 在按钮单击时不保存?