首页 > 解决方案 > Keras渐变wrt别的东西

问题描述

我正在努力实现文章https://drive.google.com/file/d/1s-qs-ivo_fJD9BU_tM5RY8Hv-opK4Z-H/view中描述的方法。最终使用的算法在这里(在第 6 页):

算法1

这个想法是进行对抗训练,通过在网络对微小变化最敏感的方向修改数据,并使用修改后的数据但使用与原始数据相同的标签来训练网络。

我正在尝试使用 MNIST 数据集和 100 个小批量数据在 Keras 中实现此方法,但我无法理解梯度 wrt r 的计算(算法第三步的第一行) . 我不知道如何用 Keras 计算它。这是我的代码:

loss = losses.SparseCategoricalCrossentropy()

for epoch in range(5):
    print(f"Start of epoch {epoch}")
    for step, (xBatchTrain,yBatchTrain) in enumerate(trainDataset):
        #Generating the 100 unit vectors
        randomVectors = np.random.random(xBatchTrain.shape)
        U = randomVectors / np.linalg.norm(randomVectors,axis=1)[:, None]

        #Generating the r vectors
        Xi = 2
        R = tf.convert_to_tensor(U * Xi[:, None],dtype='float32')

        dataNoised = xBatchTrain + R

        with tf.GradientTape(persistent=True) as imTape:
            imTape.watch(R)
            #Geting the losses
            C = [loss(label,pred) for label, pred in zip(yBatchTrain,dumbModel(dataNoised,training=False))]

        #Getting the gradient wrt r for each images
        for l,r in zip(C,R):
            print(imTape.gradient(l,r))

" print" 行为每个样本返回 None。我应该返回一个包含 784 个值的向量,每个值对应一个像素?

(我很抱歉部分代码很丑,我是 Keras、tf 和深度学习的新手)

[编辑]

这是整个笔记本的要点:https ://gist.github.com/DridriLaBastos/136a8e9d02b311e82fe22ec1c2850f78

标签: python-3.xkerastensorflow2.0

解决方案


首先,移到dataNoised = xBatchTrain + R里面with tf.GradientTape(persistent=True) as imTape:去记录相关的操作R

其次,而不是使用:

for l,r in zip(C,R):
    print(imTape.gradient(l,r))

您应该使用imTape.gradient(C,R)来获取梯度集,因为zip会破坏 的张量中的操作依赖性,将其R打印出来将返回类似以下形状的内容xBatchTrain

tf.Tensor(
[[-1.4924371e-06  1.0490652e-05 -1.8195267e-05 ...  1.5640746e-05
   3.3767541e-05 -2.0983218e-05]
 [ 2.3668531e-02  1.9133706e-02  3.1396169e-02 ... -1.4431887e-02
   5.3144591e-03  6.2225698e-03]
 [ 2.0492254e-03  7.1049971e-04  1.6121448e-03 ... -1.0579333e-03
   2.4968456e-03  8.3572773e-04]
 ...
 [-4.5572519e-03  6.2278998e-03  6.8322839e-03 ... -2.1966733e-03
   1.0822206e-03  1.8687058e-03]
 [-6.3691144e-03 -4.1699030e-02 -9.3158096e-02 ... -2.9496195e-02
  -7.0264392e-02 -3.2520775e-02]
 [-1.4666058e-02  2.0758331e-02  2.9009990e-02 ... -3.2206681e-02
   3.1550713e-02  4.9267178e-03]], shape=(100, 784), dtype=float32)

推荐阅读