首页 > 解决方案 > 使用 numpy 进行 MXNet 参数序列化

问题描述

我想在 s390x 架构上使用预训练的 MXNet 模型,但它似乎不起作用。这是因为预训练的模型是小端的,而 s390x 是大端的。所以,我正在尝试使用适用于小端和大端的https://numpy.org/devdocs/reference/generated/numpy.lib.format.html 。

解决这个问题的一种方法是我发现在 x86 机器上加载模型参数,调用 asnumpy,通过 numpy 保存然后使用 numpy 在 s390x 机器上加载参数并将它们转换为 MXNet。但我不确定如何编码。任何人都可以帮我吗?

更新

似乎这个问题不清楚。因此,我添加了一个示例,可以通过 3 个步骤更好地解释我想要做什么 -

  1. 从 MXNet 加载预先存在的模型,如下所示 -
net = mx.gluon.model_zoo.vision.resnet18_v1(pretrained=True, ctx=mx.cpu())
  1. 导出模型。以下代码将模型参数保存在 .param 文件中。但是这个 .param 二进制文件有字节序问题。因此,我不想使用 mxnet API 直接保存模型,而是想使用 numpy - https://numpy.org/devdocs/reference/generated/numpy.lib.format.html保存参数文件。因为使用 numpy,会使二进制文件(.npy)端独立。我不确定如何将 MXNet 模型的参数转换为 numpy 格式并保存。
gluon.contrib.utils.export(net, path="./my_model")
  1. 加载模型。以下代码从 .param 文件加载模型。
net = gluon.contrib.utils.import(symbol_file="my_model-symbol.json",
                                     param_file="my_model-0000.params",
                                     ctx = 'cpu')

我想使用 numpy 加载我们在步骤 2 中创建的 .npy 文件,而不是使用 MXNet API 加载。加载 .npy 文件后,我们需要将其转换为 MXNet。所以,我终于可以在 MXNet 中使用该模型了。

标签: pythonnumpymxnets390x

解决方案


从另一个问题中发布的代码片段开始,使用 NumPy 保存/加载 MXNet 模型参数

似乎 mxnet 可以选择在内部将数据存储为 numpy 数组:

mx.npx.set_np(True, True)

不幸的是,这个选项并没有达到我希望的效果(我的 IPython 会话崩溃了)。

参数是一个dict实例mxnet.gluon.parameter.Parameter,每个实例都包含其他特殊数据类型的属性。解开它以便您可以将其存储为大量纯 numpy 数组(或它们在.npz文件中的集合)是一项无望的任务。

幸运的是,python 必须pickle将复杂的数据结构转换为或多或少可移植的东西:

# (mxnet/resnet setup skipped)
parameters = resnet.collect_params()

import pickle
with open('foo.pkl', 'wb') as f:
    pickle.dump(parameters, f)

恢复参数:

with open('foo.pkl', 'rb') as f:
    parameters_loaded = pickle.load(f)

本质上,它看起来就像在获取参数(使用)resnet.save_parameters()中定义的那样,并使用似乎是从 C 编译的自定义写入函数将它们写入文件(我没有检查细节)。mxnet/gluon/block.py_collect_parameters_with_prefix()

您可以pickle改为使用保存参数。

对于加载,load_parameters(也在util.py)包含此代码(删除了完整性检查):

for name in loaded:
    params[name]._load_init(loaded[name], ctx, cast_dtype=cast_dtype, dtype_source=dtype_source)

这里,loaded是从文件中加载的字典。通过检查代码,我没有完全掌握正在加载的确切内容 -params似乎是函数中不再使用的局部变量。但值得尝试从这里开始,通过为该load_parameters函数编写一个替换。您可以通过在类外部定义一个函数来将函数“猴子补丁”到现有类中,如下所示:

def my_load_parameters(self, ...):
   ... (put your modified implementation here)

mx.gluon.Block.load_parameters = my_load_parameters

免责声明/警告:

  • 即使您通过保存/加载pickle方式在单个大端系统上工作,也不能保证在不同端系统之间工作。pickle 协议本身是 endian-neutral 的,但是如果浮点值(在内部深处mxnet.gluon.parameter.Parameter被存储为机器端约定中的原始数据缓冲区,那么 pickle 不会神奇地猜测缓冲区中的 8 个字节组需要被颠倒我认为 numpy 数组在腌制时是字节序安全的。
  • 如果底层的类定义在酸洗和解酸之间发生变化,Pickle 就不是很健壮。
  • 永远不要解开不受信任的数据。

推荐阅读