python - 使用 numpy 进行 MXNet 参数序列化
问题描述
我想在 s390x 架构上使用预训练的 MXNet 模型,但它似乎不起作用。这是因为预训练的模型是小端的,而 s390x 是大端的。所以,我正在尝试使用适用于小端和大端的https://numpy.org/devdocs/reference/generated/numpy.lib.format.html 。
解决这个问题的一种方法是我发现在 x86 机器上加载模型参数,调用 asnumpy,通过 numpy 保存然后使用 numpy 在 s390x 机器上加载参数并将它们转换为 MXNet。但我不确定如何编码。任何人都可以帮我吗?
更新
似乎这个问题不清楚。因此,我添加了一个示例,可以通过 3 个步骤更好地解释我想要做什么 -
- 从 MXNet 加载预先存在的模型,如下所示 -
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=mx.cpu())
- 导出模型。以下代码将模型参数保存在 .param 文件中。但是这个 .param 二进制文件有字节序问题。因此,我不想使用 mxnet API 直接保存模型,而是想使用 numpy - https://numpy.org/devdocs/reference/generated/numpy.lib.format.html保存参数文件。因为使用 numpy,会使二进制文件(.npy)端独立。我不确定如何将 MXNet 模型的参数转换为 numpy 格式并保存。
gluon.contrib.utils.export(net, path="./my_model")
- 加载模型。以下代码从 .param 文件加载模型。
net = gluon.contrib.utils.import(symbol_file="my_model-symbol.json",
param_file="my_model-0000.params",
ctx = 'cpu')
我想使用 numpy 加载我们在步骤 2 中创建的 .npy 文件,而不是使用 MXNet API 加载。加载 .npy 文件后,我们需要将其转换为 MXNet。所以,我终于可以在 MXNet 中使用该模型了。
解决方案
从另一个问题中发布的代码片段开始,使用 NumPy 保存/加载 MXNet 模型参数:
似乎 mxnet 可以选择在内部将数据存储为 numpy 数组:
mx.npx.set_np(True, True)
不幸的是,这个选项并没有达到我希望的效果(我的 IPython 会话崩溃了)。
参数是一个dict
实例mxnet.gluon.parameter.Parameter
,每个实例都包含其他特殊数据类型的属性。解开它以便您可以将其存储为大量纯 numpy 数组(或它们在.npz
文件中的集合)是一项无望的任务。
幸运的是,python 必须pickle
将复杂的数据结构转换为或多或少可移植的东西:
# (mxnet/resnet setup skipped)
parameters = resnet.collect_params()
import pickle
with open('foo.pkl', 'wb') as f:
pickle.dump(parameters, f)
恢复参数:
with open('foo.pkl', 'rb') as f:
parameters_loaded = pickle.load(f)
本质上,它看起来就像在获取参数(使用)resnet.save_parameters()
中定义的那样,并使用似乎是从 C 编译的自定义写入函数将它们写入文件(我没有检查细节)。mxnet/gluon/block.py
_collect_parameters_with_prefix()
您可以pickle
改为使用保存参数。
对于加载,load_parameters
(也在util.py
)包含此代码(删除了完整性检查):
for name in loaded:
params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)
这里,loaded
是从文件中加载的字典。通过检查代码,我没有完全掌握正在加载的确切内容 -params
似乎是函数中不再使用的局部变量。但值得尝试从这里开始,通过为该load_parameters
函数编写一个替换。您可以通过在类外部定义一个函数来将函数“猴子补丁”到现有类中,如下所示:
def my_load_parameters(self, ...):
... (put your modified implementation here)
mx.gluon.Block.load_parameters = my_load_parameters
免责声明/警告:
- 即使您通过保存/加载
pickle
方式在单个大端系统上工作,也不能保证在不同端系统之间工作。pickle 协议本身是 endian-neutral 的,但是如果浮点值(在内部深处mxnet.gluon.parameter.Parameter
被存储为机器端约定中的原始数据缓冲区,那么 pickle 不会神奇地猜测缓冲区中的 8 个字节组需要被颠倒我认为 numpy 数组在腌制时是字节序安全的。 - 如果底层的类定义在酸洗和解酸之间发生变化,Pickle 就不是很健壮。
- 永远不要解开不受信任的数据。
推荐阅读
- scala - Spark + cassandra - 只插入一条记录
- ffmpeg - ffmpeg 使用 mp4 段将 hls 保存到 m3u8
- sql - 在 SQL 中使用 where 条件子句
- openlayers - geomondrian 的免费 Solap 客户端
- reactjs - webpack 配置没有在构建文件夹中发出 css
- pentaho - 用于检查/取消选中 PDI ETL 元数据注入中的复选框的值类型
- javascript - 获取 spotify api 403 错误。未捕获(承诺)
- java - 获取格式错误的firebase ServerValue.TimeSTAMP,如何转换为日期?
- c# - 如何在asp.net mvc angularjs中将日期时间格式转换为日期
- javascript - 使用输入按钮提交表单