首页 > 解决方案 > AWS Sagemaker 对原始图像输入的自定义 PyTorch 模型推断

问题描述

我是 AWS Sagemaker 的新手。我在本地有自定义 CV PyTorch 模型并将其部署到 Sagemaker 端点。我使用自定义inference.py代码来定义 model_fn、input_fn、output_fn 和 predict_fn 方法。因此,我能够对 json 输入生成预测,其中包含图像的 url,代码非常简单:

def input_fn(request_body, content_type='application/json'):

    logging.info('Deserializing the input data...')

    image_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    if content_type:
        if content_type == 'application/json':
            input_data = json.loads(request_body)
            url = input_data['url']
            logging.info(f'Image url: {url}')
            image_data = Image.open(requests.get(url, stream=True).raw)

        return image_transform(image_data)
    raise Exception(f'Requested unsupported ContentType in content_type {content_type}')

然后我可以使用代码调用端点:

client = boto3.client('runtime.sagemaker')
inp = {"url":url}
inp = json.loads(json.dumps(inp))
 
response = client.invoke_endpoint(EndpointName='ENDPOINT_NAME',
                                  Body=json.dumps(inp),
                                  ContentType='application/json')

我看到的问题是,与 Sagemaker 上的图像数组相比,本地 url 请求返回的图像数组略有不同。这就是为什么在同一个 URL 上我得到的预测略有不同的原因。为了检查至少模型权重是否相同,我想对图像本身生成预测,在本地下载并下载到 Sagemaker。但我未能尝试将图像作为输入到端点。例如:

def input_fn(request_body, content_type='application/json'):

    logging.info('Deserializing the input data...')

    image_transform = transforms.Compose([
            transforms.Resize(size=(224, 224)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    if content_type == 'application/x-image':
        image_data = request_body

        return image_transform(image_data)
    raise Exception(f'Requested unsupported ContentType in content_type {content_type}')

调用端点我遇到错误:

ParamValidationError:参数验证失败:参数正文的类型无效,值:{'img':<PIL.JpegImagePlugin.JpegImageFile 图像模式=RGB size=630x326 at 0x7F78A61461D0>},类型:<class 'dict'>,有效类型:< class 'bytes'>, <class 'bytearray'>, 类文件对象

有人知道如何通过 Pytorch 模型对图像生成 Sagemaker 预测吗?

标签: imagepytorchamazon-sagemakerinference

解决方案


与往常一样,在询问后我找到了解决方案。实际上,正如错误提示的那样,我必须将输入转换为字节或字节数组。对于那些可能需要解决方案的人:

from io import BytesIO

img = Image.open(open(PATH, 'rb'))
img_byte_arr = BytesIO()
img.save(img_byte_arr, format=img.format)
img_byte_arr = img_byte_arr.getvalue()

client = boto3.client('runtime.sagemaker')
 
response = client.invoke_endpoint(EndpointName='ENDPOINT_NAME
                                  Body=img_byte_arr,
                                  ContentType='application/x-image')
response_body = response['Body'] 
print(response_body.read())


推荐阅读