tensorflow - TensroFlow2 和 Keras 中不同的前向和后向传播
问题描述
我在前向传递中训练神经网络,随机一半时间使用不可微分激活,将激活四舍五入为 0 或 1(二进制),另一半使用类似于 Sigmoid 的可微分函数(饱和-确切地说是Sigmoid)然而,在后向传播中,我们使用关于可微函数的梯度,即使我们在前向传播中使用了不可微分的离散函数。我到目前为止的代码是:
diff_active = tf.math.maximum(sat_sigmoid_term1(feature), sat_sigmoid_term2(feature))
binary_masks = diff_active
rand_cond = tf.random.uniform([1,])
cond = tf.constant(rand_cond, shape=[1,])
if cond <0.5:
with tf.GradientTape() as tape:
non_diff_active = tf.grad_pass_through(tf.keras.layers.Lambda(lambda x: tf.where(tf.math.greater(x,0), x, tf.zeros_like(x))))(feature)
grads = tape.gradient(non_diff_active , feature)
binary_masks = non_diff_active
tf.math.multiply(binary_masks, feature)
我的直觉是,通过这种方式,始终应用可微激活(希望它的梯度始终包含在 bacl-prop 中),并且tf.grad_pass_through()
我可以应用不可微激活,同时用单位矩阵替换它的反向传播。但是,我不确定我tf.grad_pass_through()
对随机变量的使用或条件是否正确,以及行为是否符合预期?
解决方案
您可以tf.custom_gradient
为此使用:
import tensorflow as tf
@tf.function
def sigmoid_grad(x):
return tf.gradients(tf.math.sigmoid(x), x)[0]
@tf.custom_gradient
def sigmoid_or_bin(x, rand):
rand = tf.convert_to_tensor(rand)
out = tf.cond(rand > 0.5,
lambda: tf.math.sigmoid(x),
lambda: tf.dtypes.cast(x > 0, x.dtype))
return out, lambda y: (y * sigmoid_grad(x), None)
# Test
tf.random.set_seed(0)
x = tf.random.uniform([4], -1, 1)
tf.print(x)
# [-0.416049719 -0.586867094 0.0707814693 0.122514963]
with tf.GradientTape() as t:
t.watch(x)
y = tf.math.sigmoid(x)
tf.print(y)
# [0.397462428 0.357354015 0.517688 0.530590475]
tf.print(t.gradient(y, x))
# [0.239486054 0.229652107 0.249687135 0.249064222]
with tf.GradientTape() as t:
t.watch(x)
y = sigmoid_or_bin(x, 0.2)
tf.print(y)
# [0 0 1 1]
tf.print(t.gradient(y, x))
# [0.239486054 0.229652107 0.249687135 0.249064222]
with tf.GradientTape() as t:
t.watch(x)
y = sigmoid_or_bin(x, 0.8)
tf.print(y)
# [0.397462428 0.357354015 0.517688 0.530590475]
tf.print(t.gradient(y, x))
# [0.239486054 0.229652107 0.249687135 0.249064222]
推荐阅读
- ruby-on-rails - 如何将 Tailwind 2.0 安装到现有项目 (rails 6.1)
- matplotlib - Matplotlib 自定义 bbox 样式与 Underline + Facecolor 可能吗?
- excel - VBA 将 .XLSX 文件转换为批量 .CSV 时出现问题
- r - 如何在 dagitty 图中将变量圈到观察(非潜在)变量
- python - PySpark。spark.yarn.executor.memoryOverhead 和 spark.executor.pyspark.memory 之间的相关性
- python - 使用 PyQt5,如何在布局的矩形区域上添加阴影效果
- c# - TextWriter:需要从数据库中获取数据到excel
- python - 使用python将字符串中的一些单词大写
- java - Hibernate getOne() 不使用后台线程
- python-3.x - 使用 Boto3 基于标签过滤器停止所有区域中的 RDS 实例的 Lambda 函数