首页 > 解决方案 > 如何在 sagemaker 自定义部署端点脚本中加载文件

问题描述

我正在尝试将 sagemaker 上的情绪分析模型部署到端点,以实时预测输入文本的情绪。该模型将单个文本字符串作为输入并返回情绪。

为了训练 xgboost 模型,我按照这个笔记本直到第 23 步。这将 model.tar.gz 上传到 s3 存储桶。我还将sklearn的CountVectorizer(创建词袋)生成的词汇表也上传到了s3桶。

为了部署这个预训练模型,我可以使用这个方法并提供一个入口点 python 文件 predict.py。

sklearn_model = SKLearnModel(model_data="s3://bucket/model.tar.gz", role="SageMakerRole", entry_point="predict.py")

文档说我必须仅提供 model.tar.gz 作为参数,它将被加载到 model_fn 中。但是如果我正在编写自己的model_fn,那么我该如何加载模型呢?如果我将其他文件放在与 S3 中的 model.tar.gz 相同的目录中,我也可以加载它们吗?

现在要进行分类,我必须在调用方法 predict_fn 中的 model.predict(bow_vector) 之前对输入文本进行矢量化。为此,我需要在预处理训练数据期间准备并写入 s3 的 word_dict。

我的问题是如何在 model_fn 中获取 word_dict?我可以从 s3 加载它吗?下面是 predict.py 的代码。

import os
import re
import pickle
import numpy as np
import pandas as pd
import nltk
nltk.download("stopwords")
from nltk.corpus import stopwords
from nltk.stem.porter import *
from bs4 import BeautifulSoup
import sagemaker_containers

from sklearn.feature_extraction.text import CountVectorizer



def model_fn(model_dir):

    #TODO How to load the word_dict.
    #TODO How to load the model.
    return model, word_dict

def predict_fn(input_data, model):
    print('Inferring sentiment of input data.')
    trained_model, word_dict = model
    if word_dict is None:
        raise Exception('Model has not been loaded properly, no word_dict.')

    #Process input_data so that it is ready to be sent to our model.

    input_bow_csv = process_input_text(word_dict, input_data)
    prediction = trained_model.predict(input_bow_csv)
    return prediction


def process_input_text(word_dict, input_data):

    words = text_to_words(input_data);
    vectorizer = CountVectorizer(preprocessor=lambda x: x, tokenizer=lambda x: x, word_dict)
    bow_array = vectorizer.transform([words]).toarray()[0]
    bow_csv = ",".join(str(bit) for bit in bow_array)
    return bow_csv

def text_to_words(text):
    """
    Uses the Porter Stemmer to stem words in a review
    """
    #instantiate stemmer
    stemmer = PorterStemmer()
    text_nohtml = BeautifulSoup(text, "html.parser").get_text() # Remove HTML tags
    text_lower = re.sub(r"[^a-zA-Z0-9]", " ", text_nohtml.lower()) # Convert to lower case
    words = text_lower.split() # Split string into words
    words = [w for w in words if w not in stopwords.words("english")] # Remove stopwords
    words = [PorterStemmer().stem(w) for w in words] # stem
    return words

def input_fn(input_data, content_type):
    return input_data;

def output_fn(prediction_output, accept):
    return prediction_output;

标签: pythonamazon-web-servicesscikit-learnaws-sdkamazon-sagemaker

解决方案


推荐阅读