tensorflow - 在 keras 密集层设置预训练的权重
问题描述
我在使用 Keras 工作时遇到了一个可怕的问题,这是一个问题;
encoder_path = "my own encoder path"
if (os.path.exists(encoder_path)):
encoder = tf.keras.models.load_model(encoder_path, compile=False)
# encoder.summary()
print("Encoder model exist & loaded ...")
else:
print("There is no file! Check " + encoder_path + ' ...')
## Making latent vector layer code ##
loc_z_mean = len(encoder.layers) - 11
loc_z_log_var = len(encoder.layers) - 10
z_mean = encoder.layers[loc_z_mean]
z_log_var = encoder.layers[loc_z_log_var]
print(z_mean.get_weights())
z_mean_weights = z_mean.get_weights()[0]
z_mean_bias = z_mean.get_weights()[1]
print(np.shape(z_mean_weights))
print(np.shape(z_mean_bias))
z_log_var_weights = z_log_var.get_weights()[0]
z_log_var_bias = z_log_var.get_weights()[1]
z_weights = z_mean_weights + np.exp(0.5 * z_log_var_weights)
z_bias = z_mean_bias + np.exp(0.5 * z_log_var_bias)
# z_weights_init = z_weights.numpy()
# z_bias_init = z_bias.numpy()
z = tf.keras.layers.Dense(16, name="latent_z").set_weights([z_weights, z_bias])
# z = tf.keras.layers.Dense(16, kernel_initializer=z_weights_init, bias_initializer=z_bias_init, name="latent_z")
# z.trainable = False # Freeze layer
print(z)
我正在尝试从以前的模型中制作一个新的重量。但是尝试时不起作用
z = tf.keras.layers.Dense(16, name="latent_z").set_weights([z_weights, z_bias])
出现此错误;
ValueError: You called `set_weights(weights)` on layer "latent_z" with a weight list of length 2, but the layer was expecting 0 weights. Provided weights: [array([[0.85919297, 0.39330506, 1.4273021 , 0.780...
我分别设置了形状z_weights
和z_bias
大小,(16, 16)
因为(16,)
这些大小与第一次加载的重量完全相同,但它不起作用。
有什么解决办法吗?
提前致谢。
解决方案
您需要为具有所需 input_shape 的层调用 build() 才能设置权重
z = tf.keras.layers.Dense(16, name="latent_z")
# set shape
z.build(input_shape=(100,)
# using random weights and bias for now
bias = np.random.randn(16)
weight = np.random.randn(100,16)
z.set_weights([weight, bias])
如果不调用 build(),则不会定义层权重,因为权重的形状取决于输入的形状。在上面的示例代码中,输入形状为 100,因此权重形状为 [100,16]。
推荐阅读
- javascript - Laravel,无法弄清楚如何使用下拉菜单显示特定年份的主题
- python-3.x - 如何修复出现在我自己定义的函数中的 NameError?
- android - 当声音在两个 ImageButtons 中播放完毕时,将暂停图标更改为播放图标
- bash - 使用 bash 增加 .txt 文件中的变量
- python - 使用链接/按钮以 HTML 格式打开 .py 文件
- java - 为什么 Intellij Heap Memory Leak on Hello World with Thread.Sleep?
- python - Django - 从 HttpResponse 对象中检索 url?
- python - ValueError: int() 以 10 为基数的无效文字:在 python 中导入请求模块时出现“错误”
- forms - 我收到错误:将 admob 添加到我的 xamarin 表单项目时,包 com.google.ads.mediation.customevent 不存在错误
- database - 用于图形数据库的 Sqlite(可嵌入)