python - 在 TensorFlow2 中结合来自不同“网络”的梯度
问题描述
我正在尝试将一些“网络”组合成一个最终的损失函数。我想知道我所做的是否“合法”,截至目前我似乎无法完成这项工作。我正在使用张量流概率:
主要问题在这里:
# Get gradients of the loss wrt the weights.
gradients = tape.gradient(loss, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights])
# Update the weights of our linear layer.
optimizer.apply_gradients(zip(gradients, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights])
这给了我无渐变并抛出应用渐变:
AttributeError:“列表”对象没有属性“设备”
完整代码:
univariate_gmm = tfp.distributions.MixtureSameFamily(
mixture_distribution=tfp.distributions.Categorical(probs=phis_true),
components_distribution=tfp.distributions.Normal(loc=mus_true,scale=sigmas_true)
)
x = univariate_gmm.sample(n_samples, seed=random_seed).numpy()
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(buffer_size=1024).batch(64)
m_phis = keras.layers.Dense(2, activation=tf.nn.softmax)
m_mus = keras.layers.Dense(2)
m_sigmas = keras.layers.Dense(2, activation=tf.nn.softplus)
def neg_log_likelihood(y, phis, mus, sigmas):
a = tfp.distributions.Normal(loc=mus[0],scale=sigmas[0]).prob(y)
b = tfp.distributions.Normal(loc=mus[1],scale=sigmas[1]).prob(y)
c = np.log(phis[0]*a + phis[1]*b)
return tf.reduce_sum(-c, axis=-1)
# Instantiate a logistic loss function that expects integer targets.
loss_fn = neg_log_likelihood
# Instantiate an optimizer.
optimizer = tf.keras.optimizers.SGD(learning_rate=1e-3)
# Iterate over the batches of the dataset.
for step, y in enumerate(dataset):
yy = np.expand_dims(y, axis=1)
# Open a GradientTape.
with tf.GradientTape() as tape:
# Forward pass.
phis = m_phis(yy)
mus = m_mus(yy)
sigmas = m_sigmas(yy)
# Loss value for this batch.
loss = loss_fn(yy, phis, mus, sigmas)
# Get gradients of the loss wrt the weights.
gradients = tape.gradient(loss, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights])
# Update the weights of our linear layer.
optimizer.apply_gradients(zip(gradients, [m_phis.trainable_weights, m_mus.trainable_weights, m_sigmas.trainable_weights]))
# Logging.
if step % 100 == 0:
print("Step:", step, "Loss:", float(loss))
解决方案
有两个不同的问题需要考虑。
1.梯度是None
:
通常,如果在GradientTape
. 具体来说,这涉及到函数np.log
中的计算。neg_log_likelihood
如果你用 替换np.log
,tf.math.log
梯度应该计算。尽量不要在“内部”张量流组件中使用 numpy 可能是一个好习惯,因为这样可以避免这样的错误。对于大多数 numpy 操作,有一个很好的 tensorflow 替代品。
2.apply_gradients
对于多个可训练对象:
这主要与apply_gradients
期望的输入有关。你有两个选择:
第一种选择:调用apply_gradients
3 次,每次使用不同的可训练对象
optimizer.apply_gradients(zip(m_phis_gradients, m_phis.trainable_weights))
optimizer.apply_gradients(zip(m_mus_gradients, m_mus.trainable_weights))
optimizer.apply_gradients(zip(m_sigmas_gradients, m_sigmas.trainable_weights))
另一种方法是创建一个元组列表,如tensorflow 文档中所示(引用:“grads_and_vars:(梯度,变量)对列表。”)。这意味着调用类似的东西
optimizer.apply_gradients(
[
zip(m_phis_gradients, m_phis.trainable_weights),
zip(m_mus_gradients, m_mus.trainable_weights),
zip(m_sigmas_gradients, m_sigmas.trainable_weights),
]
)
这两个选项都要求您拆分渐变。您可以通过计算梯度并分别索引它们来做到这一点(gradients[0],...
),或者您可以简单地单独计算梯度。请注意,这可能需要persistent=True
在您的GradientTape
.
# [...]
# Open a GradientTape.
with tf.GradientTape(persistent=True) as tape:
# Forward pass.
phis = m_phis(yy)
mus = m_mus(yy)
sigmas = m_sigmas(yy)
# Loss value for this batch.
loss = loss_fn(yy, phis, mus, sigmas)
# Get gradients of the loss wrt the weights.
m_phis_gradients = tape.gradient(loss, m_phis.trainable_weights)
m_mus_gradients = tape.gradient(loss, m_mus.trainable_weights)
m_sigmas_gradients = tape.gradient(loss, m_sigmas .trainable_weights)
# Update the weights of our linear layer.
optimizer.apply_gradients(
[
zip(m_phis_gradients, m_phis.trainable_weights),
zip(m_mus_gradients, m_mus.trainable_weights),
zip(m_sigmas_gradients, m_sigmas.trainable_weights),
]
)
# [...]
推荐阅读
- docker - 如何有条件地拉取 Docker 镜像的最新标签,而不是使用缓存版本?
- cv2 - 如何消除错误:“ImportError:DLL 加载失败:找不到指定的模块。” 在 jupyter notebook 中导入 cv2 时
- c# - Page.OnAppearing 中的 Xamarin.Forms Page.DisplayAlert
- c# - 处理 Blazor 时,在 CefSharp WinForms 中处理自定义标题的最佳方法是什么?
- testing - 它是运行 infinispan-server 10 的新 arquillian 容器吗?
- sql-server - SSIS派生列从字符串中删除文本,只保留数值
- intellij-idea - 包结构的 IntelliJ 设置
- xml - 如何使用 XPath 提取部分文本并将它们作为键值对放入 woocommerce 属性中?
- r - 在 R 中查找重复数据的样本均值
- html - 滑入式菜单的 CSS onclick 事件(无 JS)