python - 使用 CNN 模型进行文本分类的预测脚本出错
问题描述
我正在尝试为教程编写脚本的预测部分: https ://mxnet.incubator.apache.org/tutorials/nlp/cnn.html
import mxnet as mx
from collections import Counter
import os
import re
import threading
import sys
import itertools
import numpy as np
from collections import namedtuple
SENTENCES_DIR = 'C:/code/mxnet/sentences'
CURRENT_DIR = 'C:/code/mxnet'
def clean_str(string):
string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
string = re.sub(r"\'s", " \'s", string)
string = re.sub(r"\'ve", " \'ve", string)
string = re.sub(r"n\'t", " n\'t", string)
string = re.sub(r"\'re", " \'re", string)
string = re.sub(r"\'d", " \'d", string)
string = re.sub(r"\'ll", " \'ll", string)
string = re.sub(r",", " , ", string)
string = re.sub(r"!", " ! ", string)
string = re.sub(r"\(", " \( ", string)
string = re.sub(r"\)", " \) ", string)
string = re.sub(r"\?", " \? ", string)
string = re.sub(r"\s{2,}", " ", string)
return string.strip().lower()
def load_data_sentences(filename):
sentences_file = open( filename, "r")
# Tokenize
x_text = [line.decode('Latin1').strip() for line in sentences_file.readlines()]
x_text = [clean_str(sent).split(" ") for sent in x_text]
return x_text
def pad_sentences(sentences, padding_word=""):"
sequence_length = max(len(x) for x in sentences)
padded_sentences = []
for i in range(len(sentences)):
sentence = sentences[i]
num_padding = sequence_length - len(sentence)
new_sentence = sentence + [padding_word] * num_padding
padded_sentences.append(new_sentence)
return padded_sentences
def build_vocab(sentences):
word_counts = Counter(itertools.chain(*sentences))
vocabulary_inv = [x[0] for x in word_counts.most_common()]
vocabulary = {x: i for i, x in enumerate(vocabulary_inv)}
return vocabulary, vocabulary_inv
def build_input_data(sentences, vocabulary):
x = np.array([
[vocabulary[word] for word in sentence]
for sentence in sentences])
return x
def predict(mod, sen):
mod.forward(Batch(data=[mx.nd.array(sen)]))
prob = mod.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1]
for i in a[0:5]:
print('probability=%f' %(prob[i]))
sentences = load_data_sentences( os.path.join( SENTENCES_DIR, 'test-pos-1.txt') )
sentences_padded = pad_sentences(sentences)
vocabulary, vocabulary_inv = build_vocab(sentences_padded)
x = build_input_data(sentences_padded, vocabulary)
Batch = namedtuple('Batch', ['data'])
sym, arg_params, aux_params = mx.model.load_checkpoint( os.path.join( CURRENT_DIR, 'cnn'), 19)
mod = mx.mod.Module(symbol=sym, context=mx.cpu(), label_names = None)
mod.bind(for_training=False, data_shapes=[('data', (50,56))], label_shapes=mod._label_shapes)
mod.set_params(arg_params, aux_params, allow_missing=True)
predict(mod, x)
但我得到了错误:
infer_shape 错误。参数:数据:(50, 26L) Traceback(最近一次调用最后一次):文件“C:\code\mxnet\test2.py”,第 152 行,在 predict(mod, x) 文件“C:\code\mxnet\ test2.py",第 123 行,在预测 mod.forward(Batch(data=[mx.nd.array(sen)])) ...
MXNetError:运算符 reshape0 中的错误:[16:20:21] c:\projects\mxnet-distro-win\mxnet-build\src\operator\tensor./matrix_op-inl.h:187:检查失败:oshape.Size () == dshape.Size() (840000 vs. 390000) 目标形状大小与源不同。目标:[50,1,56,300] 来源:[50,26,300]
源是包含 50 个句子的文本文件
不幸的是,我在 Internet 上没有找到任何帮助。请看一下。操作系统:Windows 10。Python 2.7 谢谢。
解决方案
我相信您遇到的错误是因为您输入句子的填充与模型期望的不同。pad_sentences 的工作方式是将句子填充到传入的最长句子的长度,因此,如果您使用不同的数据集,您几乎肯定会得到与模型的填充(即 56)不同的填充。在这种情况下,您似乎得到了 26 的填充(来自错误消息“来源:[50,26,300]”)。
通过如下修改 pad_sentence 并使用 sequence_length=56 运行它以匹配模型,我能够让您的代码成功运行。
def pad_sentences(sentences, sequence_length, padding_word=""):
padded_sentences = []
for i in range(len(sentences)):
sentence = sentences[i]
num_padding = sequence_length - len(sentence)
new_sentence = sentence + [padding_word] * num_padding
padded_sentences.append(new_sentence)
return padded_sentences
注意,当您成功运行时,您会遇到错误,因为 prob[i] 不是浮点数。
def predict(mod, sen):
mod.forward(Batch(data=[mx.nd.array(sen)]))
prob = mod.get_outputs()[0].asnumpy()
prob = np.squeeze(prob)
a = np.argsort(prob)[::-1]
for i in a[0:5]:
print('probability=%f' %(prob[i])) << prob is a numpy.ndarray, not a float.
维沙尔
推荐阅读
- r - R:返回所有简单路径的边列表
- netsuite - 每种语言使用两个不同的搜索
- python - Google Places API 是否提供了一种获取 EV 端口和可用性数据的方法?
- python - Pandas:确定列是否匹配
- vue.js - Vue 如何在我创建的自定义组件中自动向我的 q-input 添加属性?
- java - Gremlin 返回不好的结果
- java - Java 命令行注释选项
- php - Laravel HTTP 客户端 - 将 XML 文件发布为 application/octet-stream
- spring - 没有fluent api的Spring Cloud Gateway自定义路由
- mysql - 如何修复此错误才能使用 bootrun