How to perform a grid search with multinomial Naive Bayes for a message classification task in Python?

Problem description

I would like to know how to run a grid search with a multinomial Naive Bayes classifier.

Here is my multinomial classifier:

import numpy as np
from collections import Counter
from sklearn.grid_search import GridSearchCV 
from nltk.corpus import stopwords
from nltk.tokenize import word_tokenize
from nltk.text import Text
from nltk.stem import WordNetLemmatizer
from nltk.stem import PorterStemmer
from nltk.tokenize import RegexpTokenizer
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB
from sklearn import metrics

train = np.load("/kaggle/input/ift3395-ift6390-reddit-comments/data_train.pkl", allow_pickle=True)
test = np.load("/kaggle/input/ift3395-ift6390-reddit-comments/data_test.pkl", allow_pickle=True)


original_train = train[0]
originaly_train = train[1]

# Append the test texts after the 70000 training texts, so the
# TF-IDF vectorizer is fitted on the full corpus.
original_train[70000:] = test

# Map each string label to an integer, and build the reverse mapping.
label_to_number_dict = {key: n for n, key in enumerate(set(originaly_train[0:70000]))}
number_to_label_dict = {v: k for k, v in label_to_number_dict.items()}

# Convert the training labels to integers.
y_train_as_number = [0] * len(originaly_train)
for i in range(len(originaly_train)):
    y_train_as_number[i] = label_to_number_dict.get(originaly_train[i])

ps = PorterStemmer()
lem = WordNetLemmatizer()

# Stem and lemmatize each document, then vectorize with TF-IDF.
token = RegexpTokenizer(r'[a-zA-Z0-9]+')
tf = TfidfVectorizer(lowercase=True, stop_words='english', ngram_range=(1, 1), tokenizer=token.tokenize)
text_tf_train = tf.fit_transform([lem.lemmatize(ps.stem(x)) for x in original_train])

X_train = text_tf_train[0:70000]
X_test = text_tf_train[70000:]

clf = MultinomialNB().fit(X_train, y_train_as_number)
predicted = clf.predict(X_test)
print("MultinomialNB Accuracy:", metrics.accuracy_score(temp_y_label, predicted))

# Map the predicted integers back to string labels.
y_predicted = [""] * len(predicted)
for i in range(len(predicted)):
    y_predicted[i] = number_to_label_dict.get(predicted[i])
    #print(y_predicted[i])

I was thinking of something like this:

from sklearn.model_selection import GridSearchCV
parameters = {
    'alpha': (1, 0.1, 0.01, 0.001, 0.0001, 0.00001)
}
grid_search = GridSearchCV(clf, parameters)
grid_search.fit(X_train, y_train_as_number)

It gives me an error saying: `GridSearchCV is not defined`

So my first question is how to fix that. I would also like to search over the parameters of the data-processing step. How can I do that?

Tags: python, scikit-learn

Solution
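The `NameError` most likely comes from the import at the top of the script: the `sklearn.grid_search` module was deprecated and removed (it no longer exists as of scikit-learn 0.20), so `from sklearn.grid_search import GridSearchCV` fails and the name is never defined. Importing `GridSearchCV` from `sklearn.model_selection` instead fixes the first question. Below is a minimal, self-contained sketch of that usage; the tiny corpus and labels are placeholders for illustration, not the Reddit dataset from the question.

```python
from sklearn.model_selection import GridSearchCV  # not sklearn.grid_search
from sklearn.naive_bayes import MultinomialNB
from sklearn.feature_extraction.text import TfidfVectorizer

# Placeholder corpus standing in for the question's original_train / y_train_as_number.
docs = ["free money now", "meeting at noon", "win cash prize",
        "project deadline today", "claim your prize", "lunch meeting tomorrow"]
labels = [1, 0, 1, 0, 1, 0]

# Vectorize once, then grid-search over the smoothing parameter alpha.
X = TfidfVectorizer().fit_transform(docs)

parameters = {"alpha": (1.0, 0.1, 0.01, 0.001)}
grid_search = GridSearchCV(MultinomialNB(), parameters, cv=3)
grid_search.fit(X, labels)
print(grid_search.best_params_)
```

Passing an already-fitted `clf` to `GridSearchCV` also works, since the search clones the estimator and refits it for every parameter combination, but passing a fresh `MultinomialNB()` makes that explicit.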

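For the second question, the usual way to grid-search preprocessing parameters is to wrap the vectorizer and the classifier in a `Pipeline` and feed the raw texts to `GridSearchCV`; parameters of a step are then addressed as `<step name>__<parameter>`. A sketch under the same placeholder data, with the step names `tfidf` and `clf` being arbitrary choices:

```python
from sklearn.model_selection import GridSearchCV
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.naive_bayes import MultinomialNB

# Placeholder corpus; in the question this would be original_train / y_train_as_number.
docs = ["free money now", "meeting at noon", "win cash prize",
        "project deadline today", "claim your prize", "lunch meeting tomorrow"]
labels = [1, 0, 1, 0, 1, 0]

# Raw texts go into the pipeline, so vectorization is refitted inside each CV fold.
pipeline = Pipeline([
    ("tfidf", TfidfVectorizer(lowercase=True)),
    ("clf", MultinomialNB()),
])

# "<step name>__<parameter>" lets one search preprocessing and classifier jointly.
parameters = {
    "tfidf__ngram_range": [(1, 1), (1, 2)],
    "tfidf__use_idf": (True, False),
    "clf__alpha": (1.0, 0.1, 0.01),
}

grid_search = GridSearchCV(pipeline, parameters, cv=3)
grid_search.fit(docs, labels)
print(grid_search.best_params_)
```

A side benefit of the pipeline is that the TF-IDF vocabulary is learned only from the training portion of each cross-validation split, avoiding the leakage that comes from fitting the vectorizer on the combined train-and-test corpus as in the original script.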
