python - Tensorflow 2 + Keras 的知识蒸馏损失
问题描述
我正在尝试实现一个非常简单的 keras 模型,该模型使用来自另一个模型的知识蒸馏 [1]。粗略地说,我需要用另一个模型的预测L(y_true, y_pred)
来L(y_true, y_pred)+L(y_teacher_pred, y_pred)
替换原始损失。y_teacher_pred
我试过做
def create_student_model_with_distillation(teacher_model):
inp = tf.keras.layers.Input(shape=(21,))
model = tf.keras.models.Sequential()
model.add(inp)
model.add(...)
model.add(tf.keras.layers.Dense(units=1))
teacher_pred = teacher_model(inp)
def my_loss(y_true,y_pred):
loss = tf.keras.losses.mean_squared_error(y_true, y_pred)
loss += tf.keras.losses.mean_squared_error(teacher_pred, y_pred)
return loss
model.compile(loss=my_loss, optimizer='adam')
return model
但是,当我尝试调用fit
我的模型时,我得到了
TypeError: An op outside of the function building code is being passed
a "Graph" tensor. It is possible to have Graph tensors
leak out of the function building context by including a
tf.init_scope in your function building code.
我该如何解决这个问题?
参考文献
解决方案
实际上,这篇博文是对您问题的回答:keras blog
但简而言之 - 您应该使用新的 TF2 API 并在块predict
之前调用教师:tf.GradientTape()
def train_step(self, data):
# Unpack data
x, y = data
# Forward pass of teacher
teacher_predictions = self.teacher(x, training=False)
with tf.GradientTape() as tape:
# Forward pass of student
student_predictions = self.student(x, training=True)
# Compute losses
student_loss = self.student_loss_fn(y, student_predictions)
distillation_loss = self.distillation_loss_fn(
tf.nn.softmax(teacher_predictions / self.temperature, axis=1),
tf.nn.softmax(student_predictions / self.temperature, axis=1),
)
loss = self.alpha * student_loss + (1 - self.alpha) * distillation_loss
推荐阅读
- sql - 执行 sp_executesql 时 where 子句中的参数错误
- javascript - 控制器没有被读取角度js
- json - 如何在 gcm 推送通知有效负载中获取 gcm.notification.body 字段
- c# - 通过 REST api 在谷歌日历中添加范围
- javascript - ` 中的重音符号
.vue.html` 显示为 � - git - 以前的 git revert 正在删除新合并中的文件
- java - 在springboot应用的application.properties中配置map
- java - 拆分字符串后出现 NumberFormatException
- vba - 如何在工作表之间复制在vba中,我试图录制宏但它不起作用
- javascript - 输入点击区域大于输入,导致按钮不起作用