python-3.x - Keras渐变wrt别的东西
问题描述
我正在努力实现文章https://drive.google.com/file/d/1s-qs-ivo_fJD9BU_tM5RY8Hv-opK4Z-H/view中描述的方法。最终使用的算法在这里(在第 6 页):
- d 是单位向量
- xhi 是一个非空数
- D 是损失函数(在我的例子中是稀疏交叉熵)
这个想法是进行对抗训练,通过在网络对微小变化最敏感的方向修改数据,并使用修改后的数据但使用与原始数据相同的标签来训练网络。
我正在尝试使用 MNIST 数据集和 100 个小批量数据在 Keras 中实现此方法,但我无法理解梯度 wrt r 的计算(算法第三步的第一行) . 我不知道如何用 Keras 计算它。这是我的代码:
loss = losses.SparseCategoricalCrossentropy()
for epoch in range(5):
print(f"Start of epoch {epoch}")
for step, (xBatchTrain,yBatchTrain) in enumerate(trainDataset):
#Generating the 100 unit vectors
randomVectors = np.random.random(xBatchTrain.shape)
U = randomVectors / np.linalg.norm(randomVectors,axis=1)[:, None]
#Generating the r vectors
Xi = 2
R = tf.convert_to_tensor(U * Xi[:, None],dtype='float32')
dataNoised = xBatchTrain + R
with tf.GradientTape(persistent=True) as imTape:
imTape.watch(R)
#Geting the losses
C = [loss(label,pred) for label, pred in zip(yBatchTrain,dumbModel(dataNoised,training=False))]
#Getting the gradient wrt r for each images
for l,r in zip(C,R):
print(imTape.gradient(l,r))
" print
" 行为每个样本返回 None。我应该返回一个包含 784 个值的向量,每个值对应一个像素?
(我很抱歉部分代码很丑,我是 Keras、tf 和深度学习的新手)
[编辑]
这是整个笔记本的要点:https ://gist.github.com/DridriLaBastos/136a8e9d02b311e82fe22ec1c2850f78
解决方案
首先,移到dataNoised = xBatchTrain + R
里面with tf.GradientTape(persistent=True) as imTape:
去记录相关的操作R
其次,而不是使用:
for l,r in zip(C,R):
print(imTape.gradient(l,r))
您应该使用imTape.gradient(C,R)
来获取梯度集,因为zip
会破坏 的张量中的操作依赖性,将其R
打印出来将返回类似以下形状的内容xBatchTrain
:
tf.Tensor(
[[-1.4924371e-06 1.0490652e-05 -1.8195267e-05 ... 1.5640746e-05
3.3767541e-05 -2.0983218e-05]
[ 2.3668531e-02 1.9133706e-02 3.1396169e-02 ... -1.4431887e-02
5.3144591e-03 6.2225698e-03]
[ 2.0492254e-03 7.1049971e-04 1.6121448e-03 ... -1.0579333e-03
2.4968456e-03 8.3572773e-04]
...
[-4.5572519e-03 6.2278998e-03 6.8322839e-03 ... -2.1966733e-03
1.0822206e-03 1.8687058e-03]
[-6.3691144e-03 -4.1699030e-02 -9.3158096e-02 ... -2.9496195e-02
-7.0264392e-02 -3.2520775e-02]
[-1.4666058e-02 2.0758331e-02 2.9009990e-02 ... -3.2206681e-02
3.1550713e-02 4.9267178e-03]], shape=(100, 784), dtype=float32)
推荐阅读
- reactjs - 故事书迁移到 CSF 的打字稿问题
- java - Selenium 中的 NullPointerException 与 java
- xslt - 在忽略空白值并在 XSLT 中添加 LineAmount 时删除重复项
- python - Pipenv shell 命令创建新的 venv 而不是加载现有的
- javascript - 在中使用 window.history.replaceState() 是否安全?
- javascript - 如何从不是数据库中的列的实体中返回额外的字段?
- node.js - 获取“被 CORS 策略阻止:请求的资源上不存在‘Access-Control-Allow-Origin’标头。” 使用 Axios 使用 MERN 堆栈
- database-design - 如何将自定义日志存储到数据库
- flutter - 解码飞镖中的音频缓冲区
- sql - 使用来自jsonb的milis将间隔添加到时间戳