image - AWS Sagemaker 对原始图像输入的自定义 PyTorch 模型推断
问题描述
我是 AWS Sagemaker 的新手。我在本地有自定义 CV PyTorch 模型并将其部署到 Sagemaker 端点。我使用自定义inference.py
代码来定义 model_fn、input_fn、output_fn 和 predict_fn 方法。因此,我能够对 json 输入生成预测,其中包含图像的 url,代码非常简单:
def input_fn(request_body, content_type='application/json'):
logging.info('Deserializing the input data...')
image_transform = transforms.Compose([
transforms.Resize(size=(224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
if content_type:
if content_type == 'application/json':
input_data = json.loads(request_body)
url = input_data['url']
logging.info(f'Image url: {url}')
image_data = Image.open(requests.get(url, stream=True).raw)
return image_transform(image_data)
raise Exception(f'Requested unsupported ContentType in content_type {content_type}')
然后我可以使用代码调用端点:
client = boto3.client('runtime.sagemaker')
inp = {"url":url}
inp = json.loads(json.dumps(inp))
response = client.invoke_endpoint(EndpointName='ENDPOINT_NAME',
Body=json.dumps(inp),
ContentType='application/json')
我看到的问题是,与 Sagemaker 上的图像数组相比,本地 url 请求返回的图像数组略有不同。这就是为什么在同一个 URL 上我得到的预测略有不同的原因。为了检查至少模型权重是否相同,我想对图像本身生成预测,在本地下载并下载到 Sagemaker。但我未能尝试将图像作为输入到端点。例如:
def input_fn(request_body, content_type='application/json'):
logging.info('Deserializing the input data...')
image_transform = transforms.Compose([
transforms.Resize(size=(224, 224)),
transforms.ToTensor(),
transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])
if content_type == 'application/x-image':
image_data = request_body
return image_transform(image_data)
raise Exception(f'Requested unsupported ContentType in content_type {content_type}')
调用端点我遇到错误:
ParamValidationError:参数验证失败:参数正文的类型无效,值:{'img':<PIL.JpegImagePlugin.JpegImageFile 图像模式=RGB size=630x326 at 0x7F78A61461D0>},类型:<class 'dict'>,有效类型:< class 'bytes'>, <class 'bytearray'>, 类文件对象
有人知道如何通过 Pytorch 模型对图像生成 Sagemaker 预测吗?
解决方案
与往常一样,在询问后我找到了解决方案。实际上,正如错误提示的那样,我必须将输入转换为字节或字节数组。对于那些可能需要解决方案的人:
from io import BytesIO
img = Image.open(open(PATH, 'rb'))
img_byte_arr = BytesIO()
img.save(img_byte_arr, format=img.format)
img_byte_arr = img_byte_arr.getvalue()
client = boto3.client('runtime.sagemaker')
response = client.invoke_endpoint(EndpointName='ENDPOINT_NAME
Body=img_byte_arr,
ContentType='application/x-image')
response_body = response['Body']
print(response_body.read())
推荐阅读
- matlab - 如何获取第一个非单维的索引和大小?
- javascript - 如何为水平卡片拖动添加触摸事件?
- git - Git 子模块命令在 Windows 10 上非常慢
- mysql - 如何在 SQL 中创建 3 个表(2 对 1)之间的关系表
- css - min-height pushes the content outside of div instead of growing inside div
- apache-spark - Neo4j 认为密码是数据库
- node.js - 部署节点弹性 beanstalk 应用程序时出错
- typescript - 如何检查 Puppeteer 页面当前是否处于导航状态?
- android - 无法重现 android 订阅的帐户保留状态
- matrix - 具有 3D 矩阵的方程