python - “请将您的损失计算包装在零参数`lambda`中”是什么意思?方法?
问题描述
我编写了一个将 L2 损失添加到主要损失函数的代码:
def add_l2(model, penalty=0.001):
for layer in model.layers:
if "conv" in layer.name:
model.add_loss(penalty * tf.reduce_sum(tf.square(layer.trainable_variables[0])))
return
## training
@tf.function
def train_one_step(model, x, y, optimizer):
with tf.GradientTape() as tape:
logits = model(x, training=True)
loss = _criterion(y_true=y, y_pred=logits)
add_l2(model, 0.001)
loss += sum(model.losses)
grads = tape.gradient(loss, model.trainable_variables)
optimizer.apply_gradients(zip(grads, model.trainable_variables))
return loss, logits
当我开始训练时,出现如下错误:
ValueError : 期望一个符号张量或一个可调用的损失值。请将您的损失计算包装在零参数
lambda
中。
这个错误是什么意思?我该如何治疗?
解决方案
您的损失引用了模型层之一的变量 ( layer.trainable_variables[0]
),因此需要将您的损失包装在一个零参数的 lambda 中,以使其可调用。
model.add_loss(lambda: penalty * tf.reduce_sum(tf.square(layer.trainable_variables[0])
有关更多详细信息,请在官方文档中查看这里
推荐阅读
- python - 为什么scrapy不抓取我的链接以进行提取
- r - 在 ggplot 中使用 geom_function()
- c++ - 静态铸造工作和动态铸造段错误
- reactjs - 如何声明一个 TypeScript 属性来接受 React 类组件?
- c++ - c++ std::basic_string 模板类实现背后的逻辑是什么?
- python - 得到 KeyError 1,不知道出了什么问题
- java - “Java 运行时环境内存不足,无法继续” - maven build - Intelllij
- flutter - Flutter:如何为我的 PaginatedDataTable 获取数据?
- flutter - 有条件地使用 useFuture
- ruby-on-rails - 我的“不允许的参数::passenger”Rails 错误来自哪里?