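python - Manually setting trainable_variables weights in TensorFlow 2
Problem description
I need to set the values of trainable_variables in a TensorFlow model directly, without using an optimizer. Is there a function or method for this? Here is some sample code; I want to set the values of mnist_model.trainable_variables.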
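import tensorflow as tf

# mnist_model, mnist_images, mnist_labels, loss and loss_history are defined elsewhere
for epoch in range(0, 1):
    with tf.GradientTape() as tape:
        prediction = mnist_model(mnist_images, training=True)
        loss_value = loss(mnist_labels, prediction)
    variables = mnist_model.trainable_variables
    loss_history.append(loss_value.numpy())
    # gradients of the loss with respect to every trainable variable
    grads = tape.gradient(loss_value, variables)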
Solution
model.trainable_variables returns a list of the model's trainable variables. If you print them, you will see their shapes:
<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 16) dtype=float32>
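Using this shape, you can assign new weights with the .assign() method. Before you do that, you need to build() your model; otherwise TensorFlow will not have created any trainable variables yet.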
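A minimal sketch of the build requirement, using a hypothetical one-layer subclassed model Tiny (not part of the original answer): the variable list is empty until build() runs, and afterwards each entry reports the shape that an assigned value must match.

```python
import tensorflow as tf


class Tiny(tf.keras.Model):
    """Hypothetical one-layer model used only to illustrate build()."""

    def __init__(self):
        super().__init__()
        self.conv = tf.keras.layers.Conv2D(filters=16, kernel_size=(3, 3))

    def call(self, x):
        return self.conv(x)


model = Tiny()
# The Conv2D layer object exists, but its kernel and bias are not created yet.
n_before = len(model.trainable_variables)

# build() runs call() with the given input shape and creates the variables.
model.build(input_shape=(None, 28, 28, 1))
shapes = [tuple(v.shape) for v in model.trainable_variables]
print(n_before, shapes)
```

Calling the model once on a real batch, e.g. model(tf.zeros((1, 28, 28, 1))), also builds it.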
model.trainable_variables[0].assign(tf.fill((3, 3, 1, 16), .12345))
Out[3]:
<tf.Variable 'conv2d/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345]],
[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345]]
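A short sketch of the shape rule behind .assign(), using a standalone tf.Variable as a stand-in for model.trainable_variables[0] (an assumption made so the snippet runs without a model): a value with the same shape, whether a tensor or a NumPy array, is written in place, while a value with a different shape is rejected with a ValueError instead of being reshaped.

```python
import numpy as np
import tensorflow as tf

# Stand-in for model.trainable_variables[0], a conv kernel of shape (3, 3, 1, 16).
kernel = tf.Variable(tf.random.normal((3, 3, 1, 16)))

# Any value with the same shape and a compatible dtype can be assigned.
new_value = np.full((3, 3, 1, 16), 0.12345, dtype=np.float32)
kernel.assign(new_value)

# A value with a different shape is refused; the variable keeps its old value.
try:
    kernel.assign(tf.zeros((3, 3, 1, 8)))
    mismatch_rejected = False
except ValueError:
    mismatch_rejected = True

print(mismatch_rejected)
```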
Full working example:
import tensorflow as tf
from tensorflow.keras import Model
from tensorflow.keras.layers import Dense, Conv2D, MaxPool2D, Dropout, Flatten


class CNN(Model):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = Conv2D(filters=16, kernel_size=(3, 3), strides=(1, 1))
        self.maxp1 = MaxPool2D(pool_size=(2, 2))
        self.flat1 = Flatten()
        self.dens1 = Dense(64, activation='relu')
        self.drop1 = Dropout(5e-1)
        self.dens3 = Dense(10)

    def call(self, x, training=None, **kwargs):
        x = self.conv1(x)
        x = self.maxp1(x)
        x = self.flat1(x)
        x = self.dens1(x)
        x = self.drop1(x, training=training)  # dropout is only active when training=True
        x = self.dens3(x)
        return x


model = CNN()
# build() creates the variables; before this, model.trainable_variables is empty
model.build(input_shape=(1, 28, 28, 1))
print(model.trainable_variables[0])  # conv1 kernel, shape (3, 3, 1, 16)
# overwrite the kernel in place with a constant tensor of the same shape
model.trainable_variables[0].assign(tf.fill((3, 3, 1, 16), .12345))
print(model.trainable_variables[0])
Original weights:
<tf.Variable 'conv2d_2/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[-0.18103004, -0.18038717, -0.04171562, -0.14022854,
-0.00918788, 0.07348467, 0.07931305, -0.03991133,
0.12809007, -0.11934308, 0.11453925, 0.02502337,
-0.165835 , -0.14841306, 0.1911544 , -0.09917622]],
[[-0.0496967 , 0.13865136, -0.17599788, -0.18716624,
-0.03473145, -0.02006209, -0.00364855, -0.03497578,
0.05207129, 0.07728194, -0.11234754, 0.09303482,
0.17245303, -0.07428543, -0.19278058, 0.15201278]]]],
dtype=float32)>
Edited weights:
<tf.Variable 'conv2d_6/kernel:0' shape=(3, 3, 1, 16) dtype=float32, numpy=
array([[[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345]],
[[0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345, 0.12345, 0.12345,
0.12345, 0.12345, 0.12345, 0.12345]]]], dtype=float32)>
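Applied to the training loop in the question, the same mechanism lets you write updated weights yourself instead of calling an optimizer. The sketch below assumes plain gradient descent (w <- w - lr * grad, no momentum) with an arbitrarily chosen learning rate lr = 0.01, and replaces mnist_model and the MNIST batch with a hypothetical one-layer model and random data so that it runs on its own.

```python
import tensorflow as tf

# Hypothetical stand-ins for mnist_model, mnist_images and mnist_labels.
tf.random.set_seed(0)
model = tf.keras.Sequential([tf.keras.layers.Dense(10)])
model.build(input_shape=(None, 784))
images = tf.random.normal((32, 784))
labels = tf.random.uniform((32,), maxval=10, dtype=tf.int32)
loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

lr = 0.01  # assumed learning rate for plain gradient descent
loss_history = []

for epoch in range(5):
    with tf.GradientTape() as tape:
        prediction = model(images, training=True)
        loss_value = loss(labels, prediction)
    loss_history.append(float(loss_value))
    variables = model.trainable_variables
    grads = tape.gradient(loss_value, variables)
    # Manual update: write w - lr * grad back into each variable, no optimizer involved.
    for var, grad in zip(variables, grads):
        var.assign_sub(lr * grad)

print(loss_history)
```

Because the loop visits every entry of trainable_variables, the same pattern works for any values you compute yourself. To replace all of a model's weights at once from a list of NumPy arrays, Keras models also provide model.set_weights().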