python - Tensorflow 2变量不可训练
问题描述
我在 tf2 中创建了一个简单的模型,它将输入“a”乘以变量“b”(初始化为 1)并返回输出“c”。然后我尝试在简单的数据集 a=1, c=5 上对其进行训练。我希望它能够学习 b=5。
import tensorflow as tf
from tensorflow.keras.models import Model
a = Input(shape=(1,))
b = tf.Variable(1., trainable=True)
c = a*b
model = Model(a,c)
loss = tf.keras.losses.MeanAbsoluteError()
model.compile(optimizer='adam', loss=loss)
model.fit([1.],[5.],batch_size=1, epochs=1)
但是,tf2 并不认为变量“b”是可训练的。摘要显示没有可训练的参数。
Model: "model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
input_1 (InputLayer) [(None, 1)] 0
_________________________________________________________________
tf_op_layer_mul (TensorFlowO [(None, 1)] 0
=================================================================
Total params: 0
Trainable params: 0
Non-trainable params: 0
_________________________________________________________________
为什么变量'b'没有训练?
解决方案
Keras 模型是Layer类的包装器。您必须将此变量包装为 keras 层,以便将其显示为模型中的可训练参数。
您可以像这样创建一个微小的自定义层:
class MyLayer(tf.keras.layers.Layer):
def __init__(self):
super(MyLayer, self).__init__()
#your variable goes here
self.variable = tf.Variable(1., trainable=True, dtype=tf.float64)
def call(self, inputs, **kwargs):
# your mul operation goes here
x = inputs * self.variable
return x
这里call
方法将进行乘法运算。我们可以像使用输出模型中的任何其他层一样使用这一层。在这里,我正在创建一个添加 aboce 乘法运算作为模型层的序列模型。
model = tf.keras.models.Sequential()
mylayer_object = MyLayer()
model.add(mylayer_object)
loss = tf.keras.losses.MeanAbsoluteError()
model.compile("adam", loss)
model.fit([1.],[5.],batch_size=1, epochs=1)
model.summary()
'''
Train on 1 samples
1/1 [==============================] - 0s 426ms/sample - loss: 4.0000
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
my_layer (MyLayer) multiple 1
=================================================================
Total params: 1
Trainable params: 1
Non-trainable params: 0
_________________________________________________________________
'''
在此之后,如果您可以列出模型的可训练参数。
print(model.trainable_variables)
# [<tf.Variable 'Variable:0' shape=() dtype=float64, numpy=1.0009999968852092>]
推荐阅读
- java - 如何加入两个数据集
在 Spark Java 中?
- kotlin - 如何从 API 访问列表以显示数据?(科特林)
- firebase - 无法将 DialogFlow 与 Firestore 集成
- javascript - 如何实现从 FlatList 中删除项目的方法?
- angular - 绑定来自云 Firestore 的 Angular 表单输入字段值
- html - 如何在bootstarp4中使大图像不被绘制到我们想要的高度
- mysql - 像codeigniter中的查询
- sql - 从窗口函数中获取最频繁的值
- cfml - 将 WriteOutput 转换为可用的 CFOutput 变量时出现问题
- python - 在我的 python 代码中,如何捕获导入的请求模块引发的 HTTP 500 异常?