python - How to initialize the weights of one network with the weights of another network?
Question
I want to merge two networks into a single network while keeping the weights of the original networks.
I saved the weights as NumPy arrays like this:
weights = {}
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    weights[i.name] = i.eval()  # eval() needs a default session
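I can't find a way to load these weights into the variables of the new network. Is there a way to load weights into all variables?
I tried the following, but got an error: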
for i in tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES):
    i.initializer = weights[i.name]
Error:
AttributeError: can't set attribute
Solution
You can write these two functions:
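def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
    # Evaluate every variable in `collection`, keyed by its graph name
    return {v.name: sess.run(v) for v in tf.get_collection(collection)}

def load_from_dict(sess, data):
    # Assign the saved value to every variable whose name appears in `data`
    for v in tf.global_variables():
        if v.name in data:
            sess.run(v.assign(data[v.name]))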
The trick is simply to iterate over all variables and check whether each one exists in the dictionary. For example:
import tensorflow as tf
import numpy as np

def save_to_dict(sess, collection=tf.GraphKeys.GLOBAL_VARIABLES):
    # Evaluate every variable in `collection`, keyed by its graph name
    return {v.name: sess.run(v) for v in tf.get_collection(collection)}

def load_from_dict(sess, data):
    # Assign the saved value to every variable whose name appears in `data`
    for v in tf.global_variables():
        if v.name in data:
            sess.run(v.assign(data[v.name]))

def network(x):
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc0')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc1')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc2')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc3')
    x = tf.layers.dense(x, 512, activation=tf.nn.relu, name='fc4')
    return x

element = np.random.randn(8, 10)
weights = None

# first session
with tf.Session() as sess:
    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())
    # first evaluation
    expected = sess.run(y, {x: element})
    # dump as dict
    weights = save_to_dict(sess)

# destroy session and graph (must run outside the `with` block)
tf.reset_default_graph()

# second session
with tf.Session() as sess:
    x = tf.placeholder(dtype=tf.float32, shape=[None, 10])
    y = network(x)
    sess.run(tf.global_variables_initializer())
    # use randomly initialized parameters
    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) > 0  # should NOT match
    # load previous parameters
    load_from_dict(sess, weights)
    actual = sess.run(y, {x: element})
    assert np.sum(np.abs(actual - expected)) == 0  # should match
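Because `load_from_dict` silently skips every variable whose name is not in the dictionary, it can be useful to check beforehand what will actually be loaded. Below is a minimal NumPy-only sketch of that matching step. `plan_load` and its `target_shapes` argument are hypothetical helpers, not TensorFlow API; the sketch assumes the new graph's variable shapes are collected separately (in TF1, for example, with `{v.name: tuple(v.shape.as_list()) for v in tf.global_variables()}`).

```python
import numpy as np


def plan_load(saved, target_shapes):
    """Hypothetical helper: decide which saved weights fit the new graph.

    saved:         dict of variable name -> np.ndarray (e.g. from save_to_dict)
    target_shapes: dict of variable name -> shape tuple for the new graph
    Returns (to_load, missing, mismatched, unused).
    """
    to_load = {}
    mismatched = []
    for name, shape in target_shapes.items():
        if name not in saved:
            continue  # no saved value: the variable keeps its fresh initialization
        if saved[name].shape != tuple(shape):
            mismatched.append(name)  # assigning this value would fail on shape
        else:
            to_load[name] = saved[name]
    # target variables that have no saved value at all
    missing = sorted(name for name in target_shapes if name not in saved)
    # saved entries that match no variable (often a naming/prefix problem)
    unused = sorted(name for name in saved if name not in target_shapes)
    return to_load, missing, sorted(mismatched), unused


saved = {"fc0/kernel:0": np.zeros((10, 512)),
         "fc0/bias:0": np.zeros(512),
         "old/w:0": np.zeros(3)}
target = {"fc0/kernel:0": (10, 512),
          "fc0/bias:0": (256,),
          "fc1/kernel:0": (512, 512)}
to_load, missing, mismatched, unused = plan_load(saved, target)
print(sorted(to_load), missing, mismatched, unused)
# ['fc0/kernel:0'] ['fc1/kernel:0'] ['fc0/bias:0'] ['old/w:0']
```

Only `to_load` would then be handed to `load_from_dict`; a non-empty `unused` list usually means the names differ between the two graphs.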
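With this approach, you can simply remove some parameters from the dictionary, modify the weights before loading them, or even rename parameters.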
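Renaming is what makes the original goal, merging two networks, work. A minimal sketch, assuming that in the merged graph each original network is built inside its own `tf.variable_scope` (the scope names `net_a` and `net_b` are hypothetical), so a variable saved as `fc0/kernel:0` is named `net_a/fc0/kernel:0` in the new graph. The helpers `add_prefix`, `drop_matching` and `merge` are illustrative plain-Python functions, not part of TensorFlow; the two input dicts stand in for the output of `save_to_dict` from two separate graphs.

```python
import numpy as np


def add_prefix(weights, prefix):
    """Return a copy of `weights` with every name moved under `prefix/`."""
    return {prefix + "/" + name: value for name, value in weights.items()}


def drop_matching(weights, substring):
    """Return a copy of `weights` without entries whose name contains `substring`."""
    return {name: value for name, value in weights.items() if substring not in name}


def merge(*weight_dicts):
    """Combine several weight dicts, refusing to silently overwrite a name."""
    merged = {}
    for weights in weight_dicts:
        for name, value in weights.items():
            if name in merged:
                raise KeyError("duplicate variable name: " + name)
            merged[name] = value
    return merged


# Stand-ins for two dicts saved from two separate graphs
weights_a = {"fc0/kernel:0": np.ones((10, 512)), "fc0/bias:0": np.zeros(512)}
weights_b = {"fc0/kernel:0": np.full((10, 512), 2.0), "fc0/bias:0": np.ones(512)}

# Keep all of network A, but only the kernels of network B
combined = merge(add_prefix(weights_a, "net_a"),
                 add_prefix(drop_matching(weights_b, "bias"), "net_b"))
print(sorted(combined))
# ['net_a/fc0/bias:0', 'net_a/fc0/kernel:0', 'net_b/fc0/kernel:0']
```

Under the scope assumption above, `load_from_dict(sess, combined)` would then fill both sub-networks, leaving `net_b`'s biases at their fresh initialization. Raising on duplicate names is a deliberate choice: two networks with identical layer names would otherwise overwrite each other without warning.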