python - 如何以 CNN 中使用的类似格式将权重保存在 .npy 文件中?
问题描述
我正在使用一个 github 存储库,其中包含经过训练的 CNN,其权重参数在.npy
文件中给出。模型正在加载权重并使用如下模型参数:-
model = CNN_Model(batch_size)
filename = "weight_file.npy"
dtype = torch.FloatTensor
model.load_state_dict(load_weights(model, weight_file, dtype))
并load_weights
定义为:-
def load_weights(model, filename, dtype):
model_params = model.state_dict()
data_dict = np.load(filename, encoding='latin1').item()
model_params["conv1.weight"] = torch.from_numpy(data_dict["conv1"] ["weights"]).type(dtype).permute(3,2,0,1)
model_params["conv1.bias"] = torch.from_numpy(data_dict["conv1"]["biases"]).type(dtype)
model_params["bn1.weight"] = torch.from_numpy(data_dict["bn_conv1"]["scale"]).type(dtype)
model_params["bn1.bias"] = torch.from_numpy(data_dict["bn_conv1"]["offset"]).type(dtype)
return model_params
我已经添加了一个训练模块并尝试微调我自己的数据集的权重。训练后,我想将新的权重保存在与之前加载的权重.npy
文件中相同的索引的文件中,这样我就可以再次将它们用于 CNN 模型。data_dict
在使用以下方法保存 data_dict 数组之前,我应该如何使用相似名称进行索引:
np.save("trained_weight_file.npy", data_dict)
编辑 1:- 所以在@ad 的推荐下我做了
data_dict = model.state_dict()
它所做的是它保存了索引为 的所有权重model_params
。的输出print data_dict
是: -
OrderedDict([('conv1.weight', tensor([[[[....]]]])), ('conv1.bias', tensor([....])), , ('bn1.weight', tensor([....])), ('bn1.bias', tensor([....]))])
但我需要的是存储在data_dict
索引中,这样我就可以使用相同的算法从.npy
文件中读取它。我也尝试从定义中返回data_dict
,然后尝试使用,但它在 `model.load_state_dict(load_weights(model, weight_file, dtype))' 行上给了我错误,即:-model_params
load_weights
data_dict = model.state_dict()
回溯(最后一次调用):model.load_state_dict(load_weights(model, weight_file, dtype)) state_dict = state_dict.copy() AttributeError: 'tuple' object has no attribute 'copy'
解决方案
我会做类似的事情data_dict = model.state_dict()
。
state_dict()
您可以阅读官方文档以及此处的输出示例。有一个github 存储库,它是 github 存储库的基础,您可以从中获取代码。此存储库也用于model.state_dict()
存储值。
推荐阅读
- jmx - 如何在同一进程中使用 Spring 读取 jmx mbean
- reactjs - 如何在 AOR 中使用 Loopback 的关系,如 hasAndBelongsToMany(REST 上的管理员)
- python - Pandas Dataframe 替换系列中的值
- c# - 在多宿主 Windows 10 计算机上接收 UDP 多播消息
- docker - 未知:访问被拒绝:频道 [] creator org [Org1MSP] - docker-compose 中的 Hyperledger
- office365 - Office 加载项中的拖放功能
- python - 将 django 过滤器编译成变量并在运行时执行?
- javascript - 在不刷新网页的情况下为进度条设置动画
- sql - Sql 触发器在可以引用之前更新表
- string - 在 Alteryx 中将填充字符串转换为固定十进制