tensorflow - 使用估计作为特征的多任务学习
问题描述
我有具有 3 个分类头的多任务网络[A, B, C]
。我想使用 headA
的输出作为第一个密集层的输入B and C
。
是否应该为反向传播做一些特别的事情,因为我认为从不B and C
应该流回的梯度A
,因为它已经被计算过并且应该作为常数处理。
有没有人有这样的代码示例?
解决方案
你可以试试:
A_layer = tf.keras.layers.Dense(5)(x)
A_head= tf.keras.layers.Dense(5)(A_layer)
A_logic = tf.keras.layers.Dense(1)(A_head)
A_loss = tf.losses.sigmoid_cross_entropy(A_y,A_logic)
B_layer = tf.keras.layers.Dense(5)(tf.stop_gradient(A_logic))
B_head= tf.keras.layers.Dense(5)(B_layer)
B_logic = tf.keras.layers.Dense(1)(B_head)
B_loss = tf.losses.sigmoid_cross_entropy(B_y,B_logic)
C_layer = tf.keras.layers.Dense(5)(tf.stop_gradient(A_logic))
C_head= tf.keras.layers.Dense(5)(C_layer)
C_logic = tf.keras.layers.Dense(1)(C_head)
C_loss = tf.losses.sigmoid_cross_entropy(C_y,C_logic)
total_loss = A_loss + B_loss + C_loss
train_op = tf.train.AdamOptimizer().minimize(total_loss)
推荐阅读
- javascript - 数据表自定义过滤器可在单击按钮时删除具有重复数据的行
- c# - 将 .net Framework 4.5 MVC 应用程序重定向到 /authorize Microsoft Oauth 2.0 端点以检索授权代码
- android - 当运行 ionic cordova run android 时,命令提示符上显示 tranformclasses 错误
- django - 将生产 Django 服务器 + Postgres 数据库复制到相同的 Linode 服务器作为备份
- arrays - 使用 .each 将数组转换为哈希数组
- python - Python从循环中的数据框创建字典
- javascript - 获取javascript中选定选项的值
- c# - 在默认 OS 查看器应用程序中打开从 CEFSharp 下载的文件
- mysql - 用该字母中的 2 个替换 3+ 个重复字母 - mySQL
- python - 我的按钮操作似乎行为不端。该怎么做才能修复它或让它变得更好?