首页 > 解决方案 > 从 S3 存储桶加载 pytorch 模型

问题描述

我想model.pt从 S3 存储桶加载 pytorch 模型 ()。我写了以下代码:

from smart_open import open as smart_open
import io

load_path = "s3://serial-no-images/yolo-models/model4/model.pt"
with smart_open(load_path) as f:
    buffer = io.BytesIO(f.read())
    model.load_state_dict(torch.load(buffer))

这会导致以下错误:

UnicodeDecodeError: 'utf-8' codec can't decode byte 0x80 in position 64: invalid start byte

一种解决方案是在本地下载模型,但我想避免这种情况并直接从 S3 加载模型。不幸的是,我在网上找不到一个好的解决方案。有人可以帮我吗?

标签: amazon-s3pytorchtorch

解决方案


AFAIKtorch.load期望文件名作为参数 - 而不是文件的内容。你buffer的结果是否可能已经等同于torch.loading 文件的本地副本?
如果您尝试这样做会发生什么model.load_state_dict(buffer)


推荐阅读