python - 包含 keras 模型的 deepcopy 类
问题描述
在我的 python 脚本中,我创建了一个类,其中包含keras
如下模型:
from keras.layers import Input, Activation, Dense
from keras.models import Model
class Klass:
def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):
self.input_dims = input_dims
self.output_dims = output_dims
self.hidden_dims = hidden_dims
self.optimizer = optimizer
self.a = a
self.b = b
self.__build_nn()
def __build_nn(self):
inputs = Input(shape=(self.input_dims,))
net = inputs
for h_dim in self.hidden_dims:
net = Dense(h_dim, kernel_initializer='he_uniform')(net)
net = Activation("relu")(net)
outputs = Dense(self.output_dims)(net)
outputs = Activation("linear")(outputs)
self.nn1 = Model(inputs=inputs, outputs=outputs)
self.nn2 = Model(inputs=inputs, outputs=outputs)
self.nn1.compile(optimizer=self.optimizer, loss='mean_squared_error')
self.nn2.compile(optimizer=self.optimizer, loss='mean_squared_error')
创建Klass
实例后,我想对其进行深层复制:
import copy
obj = Klass(10, 10, (20, 20), Adam(), 1, 2)
obj_dc = copy.deepcopy(obj)
但是,这会引发TypeError: can't pickle _thread.RLock objects
. 我很确定该错误与keras
类对象中的模型有关,因为我能够在没有keras
模型的情况下获得类似类的深层副本。
不幸的是,我无法在互联网上找到解决方案,因为大多数关于深度复制keras
模型的问题都试图克隆keras
像这里这样的模型。
那么,如何获得包含keras
模型的类的深层副本?
编辑
这三个问题(1、2、3)在不同情况下都提到了类似的错误。然而,那里提供的解决方案不适用于我的情况。
编辑 2
正如评论中所建议的,我copy
在类中添加了一个方法。那会是一个可行的解决方案吗?
class Klass:
def __init__(self, input_dims, output_dims, hidden_dims, optimizer, a, b):
self.input_dims = input_dims
self.output_dims = output_dims
self.hidden_dims = hidden_dims
self.optimizer = optimizer
self.a = a
self.b = b
self.__build_nn()
# [...]
def copy(self):
new = Klass(self.input_dims, self.output_dims, self.hidden_dims,
self.optimizer, self.a, self.b)
new.nn1.set_weights(self.nn1.get_weights())
new.nn2.set_weights(self.nn2.get_weights())
return new
解决方案
在评论中解决:添加了copy
一种Klass
将权重从旧Klass
实例复制到新创建的实例的方法。
推荐阅读
- java - 使用字段值更新同一文档中的另一个字段
- .net-core - .NET Core 不支持 Nuget 包 Microsoft.AspNetCore.* 版本 5
- python - 如何打开一个目录(文件夹)然后在目录中打开一个文件 - discord.py
- c - BMI 计算程序的问题 - 如果 else 和 dev C++ 中的计算错误
- r - 减少 FluidRow 之间的空间
- python - 如何从 Colaboratory 中保存的检查点加载 TensorFlow Keras 模型?
- spring - 无法使用 PF4J、PF4J-Spring 和 Wicket 从应用程序上下文接收具体 bean
- javascript - 有没有人可以帮助我进行联系表单验证?
- r - 在 R 中,有没有办法通过除第 n 列以外的所有列的条件过滤小标题(我们不知道列数)?
- r - 我如何有一个实心条形图来表示数量的比率?