keras - 将模型保存在keras中是否有先决条件?
问题描述
我想在训练后保存我的 keras 模型。model.fit 函数有效,但不幸的是 model.save('path') 或 model.save_weights('path') 命令不起作用。
我也尝试使用 pickle 或 np.save 保存模型,但它也不起作用。
我的模型构建如下:
model_resnet = Model(inputs=RESNET.input, outputs=RESNET.output)
model = Sequential()
model.add(model_resnet)
model.add(BatchNormalization())
model.add(Reshape((1,256)))
model.add(Bidirectional(GRU(512,return_sequences=True)))
model.add(Bidirectional(GRU(512)))
model.add(Dense(11,activation='softmax'))
其中 RESNET 是使用 keras 功能 API 定义的 3D resnet32 模型。相同的代码可以这样写:
model_ = Sequential()
model_.add(BatchNormalization())
model_.add(Reshape((1,256)))
model_.add(Bidirectional(GRU(512,return_sequences=True)))
model_.add(Bidirectional(GRU(512)))
model_.add(Dense(11,activation='softmax'))
model = Model(input = RESNET.input, outputs = model_(RESNET.output))
我正在尝试使用以下代码进行保存:
model.save(root_dir+'\\models\\model.h5')
我也试过:
x = model.get_weights()
with open(root_dir+'\\models\\model.pickle', 'wb') as f:
pickle.dump(x, f)
这些方法都不起作用。
使用 keras 保存功能我有以下错误:(不要介意错误中模型名称的名称)
File ".../train.py", line 110, in <module>
model_video.save(root_dir+'\\models\\model_video.h5')
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\keras\engine\network.py", line 1090, in save
save_model(self, filepath, overwrite, include_optimizer)
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\keras\engine\saving.py", line 382, in save_model
_serialize_model(model, f, include_optimizer)
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\keras\engine\saving.py", line 114, in _serialize_model
layer_group[name] = val
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\keras\utils\io_utils.py", line 218, in __setitem__
dataset = self.data.create_dataset(attr, val.shape, dtype=val.dtype)
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\h5py\_hl\group.py", line 136, in create_dataset
dsid = dataset.make_new_dset(self, shape, dtype, data, **kwds)
File "...\anaconda3\envs\tensorflow_env\lib\site-packages\h5py\_hl\dataset.py", line 117, in make_new_dset
dtype = numpy.dtype(dtype)
TypeError: data type not understood
用泡菜我有以下错误:
Traceback (most recent call last):
File ".../train.py", line 113, in <module>
pickle.dump(x, f)
_pickle.PicklingError: Can't pickle <class 'numpy.ndarray'>: it's not the same object as numpy.ndarray
解决方案
“conda install numpy”解决了这个问题。
推荐阅读
- wordpress - 如何修复滑块文本与徽标的对齐方式?
- python - Flake8 linter 不突出 Sublime Text 中的错误
- php - 如何通过wpdb使用数组从数据库中获取结果
- php - 如何使用递归函数获取查询结果?
- javascript - 使用 SIP.js(版本 0.13.7)的多个呼叫
- c - 无法打开或关闭绿色/蓝色 LED STM32F429ZI - Nucleo 板?
- vb.net - 字符串未被识别为有效的 DateTime Datetimepicker vb.net
- game-maker - 从字符串中删除标签
- pyspark - 如何将列表的 RDD 转换为压缩列表的 RDD?
- python-3.x - 运行放置在 PATH (cygwin) 中的 Python 脚本