python - All gradient values are computed as None when using BCE loss manually
Problem description
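I am working on a multi-output model, and I need to weight all of the output losses before computing the overall loss. To do this, I have a custom training loop for model.fit().
Since I need to compute the per-sample losses for all four outputs and combine them after applying the weights, I customized the standard code. The loss is now computed per sample, but when the gradients are computed, every gradient value comes out as None. I also tried adding tape.watch(loss), but it did not help. Please help me solve this.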
class CustomModel(keras.Model):
    def train_step(self, data):
        print(tf.executing_eagerly())
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        alpha = 0.1
        loss = 0
        y_pred_all = []
        with tf.GradientTape() as tape:
            bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
            for spl in range(1 if np.shape(x)[0] == None else np.shape(x)[0]):
                tape.watch(loss)
                tape.watch(loss_mean)
                tape.watch(loss_element)
                x_spl = np.reshape(x[spl], (1, np.shape(x)[1], np.shape(x)[2], np.shape(x)[3]))
                y_pred = self(x_spl, training=True)  # Forward pass
                y_pred_all.append(y_pred)
                loss_element = bce(y[spl], y_pred)
                loss_mean = [np.mean(loss_element[0]), np.mean(loss_element[1]), np.mean(loss_element[2]), np.mean(loss_element[3])]
                id = np.argmin(loss_mean)
                for i, ele in enumerate(loss_mean):
                    if i == id:
                        loss_mean[i] *= 1
                    else:
                        loss_mean[i] *= alpha
                loss = loss + np.sum(loss_mean)
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred_all)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
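Update: Following @rvinas's suggestion, I made some changes. The gradients are now computed without any error, but I am not sure whether the changes I made are correct: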
class CustomModel(keras.Model):
    def train_step(self, data):
        # print(tf.executing_eagerly())
        # Unpack the data. Its structure depends on your model and
        # on what you pass to `fit()`.
        x, y = data
        alpha = 0.1
        loss = tf.Variable(0, dtype='float32')
        y_pred_all = []
        with tf.GradientTape() as tape:
            bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
            for spl in tf.range(1 if tf.shape(x)[0] == None else tf.shape(x)[0]):
                loss_mean = tf.convert_to_tensor([])
                x_spl = tf.reshape(x[spl], (1, tf.shape(x)[1], tf.shape(x)[2], tf.shape(x)[3]))
                y_pred = self(x_spl, training=True)  # Forward pass
                y_pred_all.append(y_pred)
                loss_element = bce(y[spl], y_pred)
                loss_mean = [tf.reduce_mean(loss_element[0]), tf.reduce_mean(loss_element[1]), tf.reduce_mean(loss_element[2]), tf.reduce_mean(loss_element[3])]
                id = tf.argmin(loss_mean)
                for i, ele in enumerate(loss_mean):
                    if i == id:
                        loss_mean[i] = tf.multiply(loss_mean[i], 1)
                    else:
                        loss_mean[i] = tf.multiply(loss_mean[i], alpha)
                loss = tf.add(loss, tf.add(tf.add(tf.add(loss_mean[0], loss_mean[1]), loss_mean[2]), loss_mean[3]))
        # Compute gradients
        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)
        # Update weights
        self.optimizer.apply_gradients(zip(gradients, trainable_vars))
        # Update metrics (includes the metric that tracks the loss)
        self.compiled_metrics.update_state(y, y_pred_all)
        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}
Solution
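The None gradients occur because you are using NumPy operations (e.g. np.sum, np.reshape, ...), which disconnect the computation graph: the gradient tape only records TensorFlow operations, so it cannot trace the loss back to the trainable variables. Implement the logic using TensorFlow operations only.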
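A minimal sketch of the disconnect, assuming eager execution and a toy variable `w` that is not part of the original question. Once a value passes through NumPy, the tape no longer links it to the variable, so the gradient comes back as None, whereas the equivalent TensorFlow op keeps the link. The NumPy result is wrapped back into a tensor here only so that `tape.gradient` accepts it as a target.

```python
import numpy as np
import tensorflow as tf

# Toy variable standing in for a model weight (hypothetical, for illustration).
w = tf.Variable([1.0, 2.0, 3.0])


def grad_with_numpy():
    with tf.GradientTape() as tape:
        y = w * w
        # Leaving TensorFlow: np.mean produces a plain number that the tape
        # has no record of, so the resulting tensor is unconnected to w.
        loss = tf.constant(np.mean(y.numpy()))
    return tape.gradient(loss, w)


def grad_with_tf():
    with tf.GradientTape() as tape:
        y = w * w
        # Staying in TensorFlow: the tape records reduce_mean.
        loss = tf.reduce_mean(y)
    return tape.gradient(loss, w)


grad_np = grad_with_numpy()
grad_tf = grad_with_tf()
print("NumPy path:", grad_np)
print("TensorFlow path:", grad_tf)
```

Calling tape.watch on the NumPy-derived value does not help, because the missing link is between that value and the variables, not between the tape and the value.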
For example, the weighting described in the comments section can be implemented as follows:
bce = tf.keras.losses.BinaryCrossentropy(reduction=tf.keras.losses.Reduction.NONE)
with tf.GradientTape() as tape:
    # Compute element-wise losses
    y_pred = self(x, training=True)
    losses = bce(y, y_pred)  # Shape=(bs, 4)
    # Find maximum loss for each sample
    idx_max = tf.argmax(losses, axis=-1)  # Shape=(bs,)
    idx_max_onehot = tf.one_hot(idx_max, depth=y.shape[-1])  # Shape=(bs, 4)
    # Create weights tensor
    weight_max = 1
    weight_others = 0.1
    weights = idx_max_onehot * weight_max + (1 - idx_max_onehot) * weight_others
    # Aggregate losses
    losses = tf.reduce_sum(weights * losses, axis=-1)
    loss = tf.reduce_mean(losses)
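The weighting step above can also be tried in isolation. The following is only a sketch: the helper name `weighted_output_loss` and the sample loss values are made up for illustration, and it assumes the per-output losses already have shape (batch, outputs). It also checks that a gradient flows back through the weighted sum.

```python
import tensorflow as tf


def weighted_output_loss(losses, weight_max=1.0, weight_others=0.1):
    """Weight the largest per-output loss of each sample by weight_max and
    the remaining ones by weight_others, then average over the batch."""
    idx_max = tf.argmax(losses, axis=-1)                      # Shape=(bs,)
    onehot = tf.one_hot(idx_max, depth=losses.shape[-1],
                        dtype=losses.dtype)                   # Shape=(bs, n_out)
    weights = onehot * weight_max + (1.0 - onehot) * weight_others
    per_sample = tf.reduce_sum(weights * losses, axis=-1)     # Shape=(bs,)
    return tf.reduce_mean(per_sample)


# Made-up per-output losses for a batch of two samples and four outputs.
sample_losses = tf.Variable([[0.2, 0.9, 0.4, 0.1],
                             [1.0, 0.5, 0.3, 0.2]])

with tf.GradientTape() as tape:
    loss = weighted_output_loss(sample_losses)

grads = tape.gradient(loss, sample_losses)
print("loss:", float(loss))
print("gradient:", grads.numpy())
```

With these values the first sample contributes 0.9 + 0.1 * (0.2 + 0.4 + 0.1) = 0.97 and the second 1.0 + 0.1 * (0.5 + 0.3 + 0.2) = 1.1, so the batch loss is about 1.035. To emphasize the smallest loss instead, as the loop in the question does, replace tf.argmax with tf.argmin.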