tensorflow - 为什么 get_tensor_by_name 无法正确获取 tf.keras.layers 定义的层权重
问题描述
我尝试tf.keras.layers
通过使用get_tensor_by_name
in获得层的权重tensorflow
。代码如下
# encoding: utf-8
import tensorflow as tf
x = tf.placeholder(tf.float32, (None,3))
h = tf.keras.layers.dense(3)(x)
y = tf.keras.layers.dense(1)(h)
for tn in tf.trainable_variables():
print(tn.name)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
w = tf.get_default_graph().get_tensor_by_name("dense/kernel:0")
print(sess.run(w))
重量的名称是dense/kernel:0
。但是,输出sess.run(w)
很奇怪
[( 10,) ( 44,) ( 47,) (106,) (111,) ( 98,) ( 58,) (108,) (111,) ( 99,)
( 97,) (108,) (104,) (111,) (115,) (116,) ( 47,) (114,) (101,)
... ]
这不是一个浮点数组。事实上,如果我用它tf.layers.dense
来定义网络,一切都很好。所以我的问题是如何tf.keras.layers
通过正确使用张量名称来获得定义的层的权重。
解决方案
您可以get_weights()
在图层上使用来获取特定图层的权重值。这是您的案例的示例代码:
import tensorflow as tf
input_x = tf.placeholder(tf.float32, [None, 3], name='x')
dense1 = tf.keras.Dense(3, activation='relu')
l1 = dense1(input_x)
dense2 = tf.keras.Dense(1)
y = dense2(l1)
weights = dense1.get_weights()
可以使用 Keras API 以更简单的方式完成,如下所示:
def mymodel():
i = Input(shape=(3, ))
x = Dense(3, activation='relu')(i)
o = Dense(1)(x)
model = Model(input=i, output=o)
return model
model = mymodel()
names = [weight.name for layer in model.layers for weight in layer.weights]
weights = model.get_weights()
for name, weight in zip(names, weights):
print(name, weight.shape)
此示例获取模型每一层的权重矩阵。
推荐阅读
- twilio - 通过 SMS 自动化发送唯一代码
- android - android studio 应用程序文件夹丢失。如何再次显示?
- asp.net-core - 如何对 Azure AD 中特定区域的用户进行身份验证和授权?
- mysql - 在 MySQL 源上启用 CDC 时,AWS DMS“不支持或注释掉 DDL”
- database - 基于名称类别在 ms 访问中创建新列
- c - 矩阵乘法函数不会在主函数内编译
- python - matplotlib.pyplot 不显示图形?
- javascript - 如何将对象映射到所有输入元素?
- python - Python Tkinter 循环使用每个新列表值更新 3 个标签
- php - 检查结帐页面,如果登录用户已经购买了产品