首页 > 解决方案 > Pickle 文件从头开始执行

问题描述

我对酸洗的概念很陌生。据我了解,在酸洗 Python 对象时,状态会以二进制文件的形式保存到磁盘中,以后可以重新加载。所以我需要澄清一下,当你重新加载pickle文件时,它应该从它保存的状态开始执行。我腌制了一个python文件,但它从头开始执行。为了更清楚,请参阅我的代码。

试用_1.py

import pandas as pd
import numpy as np
import re
import string
punct = string.punctuation

import nltk
nltk.download('punkt')
nltk.download('stopwords')

from nltk.tokenize import word_tokenize

from nltk.corpus import stopwords
stopWords = stopwords.words('english')

from nltk.stem import WordNetLemmatizer
nltk.download('wordnet')
wordnet_lemmatizer = WordNetLemmatizer()

from sklearn.feature_extraction.text import TfidfVectorizer

from sklearn.model_selection import train_test_split

from sklearn.pipeline import Pipeline
from sklearn.preprocessing import FunctionTransformer

from sklearn.linear_model import LogisticRegression
from sklearn.naive_bayes import MultinomialNB
from sklearn.tree import DecisionTreeClassifier
from sklearn.neighbors import KNeighborsClassifier
from sklearn import svm
from sklearn.ensemble import RandomForestClassifier

import pickle


def text_pre_processing(data):
    remove_tags = re.compile('<.*?>|&([a-z0-9]+|#[0-9]{1,6}|#x[0-9a-f]{1,6});')
    corpus = []

    for message in data['MESSAGE']:
        token = str(message).replace('\n', ' ')
        token = token.replace('\t', ' ')
        token = token.replace('©', ' ')
        token = token.replace('/b', ' ')
        token = re.sub('(https?:\/\/)([\w]+.)*', ' ', token)  # to remove url
        token = re.sub('(www.)([\w]+.)*', ' ', token)
        token = re.sub(remove_tags, ' ', token)  # remove html tags
        token = "".join([word for word in token if word not in punct])
        token = re.sub('([\d])*', '', token)  # remove numbers
        token = token.lower()
        token = word_tokenize(token)
        token = " ".join(
            [wordnet_lemmatizer.lemmatize(word) for word in token if not word in set(stopWords)])
        corpus.append(token)

    return corpus


class CompareModels():

    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)
        data = pd.read_csv('data\\input.csv')
        print(data.iloc[:2, 1:2])


        text_transformer = FunctionTransformer(text_pre_processing)

        pipeline_lr = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('LRClassifier', LogisticRegression(solver='saga', max_iter=1500, random_state=0))
        ]
        )

        pipeline_nb = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('NaiveBayes', MultinomialNB())
        ]
        )

        pipeline_dt = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('DecisionTree', DecisionTreeClassifier())
        ]
        )

        pipeline_knn = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('KNNClassifier', KNeighborsClassifier())
        ]
        )

        pipeline_svm = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('SVM', svm.SVC())
        ]
        )

        pipeline_rfc = Pipeline([
            ('pre-processing', text_transformer),
            ('tfidf', TfidfVectorizer()),
            ('Random Forest Classifier', RandomForestClassifier(n_estimators=400, random_state=0))
        ]
        )

        pipelines = [pipeline_lr, pipeline_nb, pipeline_dt, pipeline_knn, pipeline_svm, pipeline_rfc]

        pipe_dict = {0: 'Logistic Regression',
                     1: 'Naive Bayes',
                     2: 'Decision Tree',
                     3: 'K Neighbor Classifier',
                     4: 'Support Vector Classifier',
                     5: 'Random Forest Classifier'
                     }

        print("Splitting input and output")
        X = data.iloc[:, 1:2]
        y = data.iloc[:, 0:1]

        print("Splitting as training and testing data")
        xtrain, xtest, ytrain, ytest = train_test_split(X, y, test_size=0.3, random_state=42)

        print("Training the models")
        models = []
        for i in pipelines:
            print(i)
            models.append(i.fit(xtrain, np.ravel(ytrain)))

        accuracy = []
        for j, model in enumerate(pipelines):
            print("{} Test Accuracy : {}".format(pipe_dict[j], model.score(xtest, ytest) * 100))
            accuracy.append(model.score(xtest, ytest) * 100)

        index = accuracy.index(max(accuracy))

        self.best_model = models[index]


app = CompareModels()
model = app.best_model
print(model)

pickle.dump(model, open("model.pkl", "wb"))

试用_2.py

import tkinter as tk
from tkinter import filedialog
import pickle
from trial_1 import text_pre_processing
import pandas as pd

class SpamEmailPredictionApp(tk.Tk):

    def __init__(self, *args, **kwargs):

        super().__init__(*args, **kwargs)
        self.geometry("1500x750")

        main_frame = tk.Frame(self, width=200, height=50, highlightbackground="black",
                              highlightthickness=1,background='white')
        main_frame.pack(side='top', fill='both', expand='True')

        main_frame.grid_rowconfigure(0, weight=1)
        main_frame.grid_columnconfigure(0, weight=1)

        # Display image on a Label widget.
        img = tk.PhotoImage(file='images\\4.png')
        lbl = tk.Label(main_frame, image=img)
        lbl.img = img  # Keep a reference in case this code put is in a function.
        lbl.place(relx=0.5, rely=0.5, anchor='center')  # Place label in center of parent.

        window = tk.Frame(main_frame,width = 200,height=200, highlightbackground = 'black',highlightthickness=1,background='white')
        window.grid(row =0, column = 0)

        title = tk.Label(window,text = 'Spam Email Prediction',font=("Helvetica", 40),bg='white')
        title.grid(row=0,column=0)

        submit_button = tk.Button(window, text="Browse File", font=("Helvetica", 20), bg="white",
                                  command=lambda: browse_file())
        submit_button.grid(row=5, column=0, padx=10, pady=10)

        def browse_file():

            self.filename = filedialog.askopenfilename(
                initialdir="C:\\Users\\write\\PycharmProjects\\EmailSpamPrediction\\data",
                title='Choose a File',
                filetypes = (("csv files","*.csv"),("all files","*.*"))
            )
            print('Unpickling started')
            model = pickle.load(open("model.pkl", "rb"))
            test_email = pd.read_csv(self.filename)
            test_email.rename(columns={'Body': 'MESSAGE'}, inplace=True)
            print('Prediction started')
            prediction = model.predict(test_email)
            if(prediction):
                print('Spam')
                text = tk.Label(window,text="Spam", font=("Helvetica", 30),fg ='red')
                text.grid(row=10,column=0,padx=10,pady=10)
            else:
                print('Ham')
                text = tk.Label(window, text="Ham", font=("Helvetica", 30),fg = 'green')
                text.grid(row=10, column=0, padx=10, pady=10)

app = SpamEmailPredictionApp()
app.mainloop()

trial_1.py 是要腌制的文件,trial_2.py 是 gui 文件。

标签: python-3.xtkinterpycharmpickle

解决方案


推荐阅读