首页 > 解决方案 > 为什么我不能使用字节或文件作为有效负载调用 sagemaker 端点

问题描述

我在 Sagemaker 上部署了一个线性回归模型。现在我想编写一个 lambda 函数来对输入数据进行预测。首先从 S3 中提取文件。一些预处理已经完成,最终的输入是一个pandas数据框。根据 boto3 sagemaker 文档,有效负载可以是类似字节的,也可以是文件。所以我尝试使用这篇文章中的代码将数据帧转换为字节数组

# Convert pandas dataframe to byte array
pred_np = pred_df.to_records(index=False)
pred_str = pred_np.tostring()

# Start sagemaker prediction
sm_runtime = aws_session.client('runtime.sagemaker')
response = sm_runtime.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT,
    Body=pred_str,
    ContentType='text/csv',
    Accept='Accept')

我打印出pred_str对我来说确实像一个字节数组。 在此处输入图像描述

但是,当我运行它时,我得到以下Algorithm Error原因UnicodeDecodeError

Caused by: 'utf8' codec can't decode byte 0xed in position 9: invalid continuation byte

回溯显示 python 2.7 不确定为什么会这样:

Traceback (most recent call last):
  File "/opt/amazon/lib/python2.7/site-packages/ai_algorithms_sdk/serve.py", line 465, in invocations
    data_iter = get_data_iterator(payload, **content_parameters)
  File "/opt/amazon/lib/python2.7/site-packages/ai_algorithms_sdk/io/serve_helpers.py", line 99, in iterator_csv_dense_rank_2
    payload = payload.decode("utf8")
  File "/opt/amazon/python2.7/lib/python2.7/encodings/utf_8.py", line 16, in decode
    return codecs.utf_8_decode(input, errors, True)

是默认解码器utf_8吗?我应该使用什么正确的解码器?为什么抱怨位置 9?

此外,我还尝试将数据帧保存到 csv 文件并将其用作有效负载

pred_df.to_csv('pred.csv', index=False)
with open('pred.csv', 'rb') as f:
    payload = f.read()
response = sm_runtime.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT,
    Body=payload,
    ContentType='text/csv',
    Accept='Accept')

但是,当我运行它时,出现以下错误:

Customer Error: Unable to parse payload. Some rows may have more columns than others and/or non-numeric values may be present in the csv data.

再一次,回溯正在调用 python 2.7:

Traceback (most recent call last):
  File "/opt/amazon/lib/python2.7/site-packages/ai_algorithms_sdk/serve.py", line 465, in invocations
    data_iter = get_data_iterator(payload, **content_parameters)
  File "/opt/amazon/lib/python2.7/site-packages/ai_algorithms_sdk/io/serve_helpers.py", line 123, in iterator_csv_dense_rank_2

这根本没有意义,因为它是标准的 6x78 数据帧。所有行都有相同的列数。另外,没有一列是非数字的。 在此处输入图像描述 如何解决这个 sagemaker 问题?

标签: arrayspandasutf-8boto3amazon-sagemaker

解决方案


我终于能够使用以下代码使其工作:

payload = io.StringIO()
pred_df.to_csv(payload, header=None, index=None)

sm_runtime = aws_session.client('runtime.sagemaker')
response = sm_runtime.invoke_endpoint(
    EndpointName=SAGEMAKER_ENDPOINT,
    Body=payload.getvalue(),
    ContentType='text/csv',
    Accept='Accept')

getvalue()在调用端点时调用有效载荷的函数非常重要。希望这可以帮助


推荐阅读