python - 没有急切执行的张量流中的反复损失
问题描述
我有以下非常简单的损失示例(这可能没有意义)
import tensorflow as tf
class Loss:
def __init__(self):
self.last_output = tf.constant([0.5,0.5])
def recurrent_loss(self, model_output):
now = 0.9*self.last_output + 0.1*model_output
self.last_output = now
return tf.reduce_mean(now)
仅评估reduced_mean
model_output 与最后一个model_output 的组合(比例为 9 比 1)。所以例如
>> l = Loss()
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.55, shape=(), dtype=float32)
>> l.recurrent_loss(tf.constant([1.,1.]))
tf.Tensor(0.595, shape=(), dtype=float32)
如果我正确理解 tf 是如何工作的,这只是可能的,因为默认情况下 tf 正在急切地执行(tf.executing_eagerly() == True
)。这应该是我可以用新的张量覆盖 self.last_output 变量以实现循环结构的原因。
我的问题:我怎样才能在不使用急切执行的 tf 图中实现相同的循环结构?
解决方案
在图形模式下,您必须使用仅在第一次执行函数时创建的 tf.Variable,例如:
class Loss:
def __init__(self):
self.last_output = None
@tf.function
def recurrent_loss(self, model_output):
if self.last_output is None:
self.last_output = tf.Variable([0.5,0.5])
now = 0.9*self.last_output + 0.1*model_output
self.last_output.assign(now)
return tf.reduce_mean(now)
推荐阅读
- node.js - 如何监控 NPM 包许可类型的变化?
- android - Android API 21 中的 Webview 为空白/白色
- python - 网页抓取文本返回 0
- css - 获取 Bootstrap 5 网格行以垂直跨越其整个内容
- javascript - 一个函数可以只属于它声明的块作用域吗?
- javascript - OnChange 后反应表单字段更新更新
- reactjs - 我的默认道具有一个带有数据集的对象,我该如何设置数据?
- api - 如何通过rest api获取Bamboo中的linkedrepository信息?
- python - 在 Keras 中使用自定义损失函数进行模型训练时出错
- c++ - OpenGL GLFW + GLAD程序不绘制三角形