python - 注意力实现的预测对齐中的差异化问题
问题描述
我正在尝试根据这篇论文实现 local-p attention:https ://arxiv.org/pdf/1508.04025.pdf具体来说,等式(9)基于采用一些非线性函数的 sigmoid 导出对齐位置,并且然后将结果乘以时间步数。由于 sigmoid 返回的值介于 0 和 1 之间,因此这种乘法会产生介于 0 和时间步数之间的有效索引。我可以对其进行软舍入以推断预测位置,但是,由于 tf.cast() 不可微分,因此我找不到将其转换为整数以在切片/索引操作中使用的方法。另一个问题是派生位置的形状为 (B, 1),因此批次中的每个示例都有一个对齐的位置。请参阅下文以了解这些操作:
"""B = batch size, S = sequence length (num. timesteps), V = vocabulary size, H = number of hidden dimensions"""
class LocalAttention(Layer):
def __init__(self, size, window_width=None, **kwargs):
super(LocalAttention, self).__init__(**kwargs)
self.size = size
self.window_width = window_width # 2*D
def build(self, input_shape):
self.W_p = Dense(units=input_shape[2], use_bias=False)
self.W_p.build(input_shape=(None, None, input_shape[2])) # (B, 1, H)
self._trainable_weights += self.W_p.trainable_weights
self.v_p = Dense(units=1, use_bias=False)
self.v_p.build(input_shape=(None, None, input_shape[2])) # (B, 1, H)
self._trainable_weights += self.v_p.trainable_weights
super(Attention, self).build(input_shape)
def call(self, inputs):
sequence_length = inputs.shape[1]
## Get h_t, the current (target) hidden state ##
target_hidden_state = Lambda(function=lambda x: x[:, -1, :])(inputs) # (B, H)
## Get h_s, source hidden states ##
aligned_position = self.W_p(target_hidden_state) # (B, H)
aligned_position = Activation('tanh')(aligned_position) # (B, H)
aligned_position = self.v_p(aligned_position) # (B, 1)
aligned_position = Activation('sigmoid')(aligned_position) # (B, 1)
aligned_position = aligned_position * sequence_length # (B, 1)
假设aligned_position
张量具有元素 [24.2, 15.1, 12.3] 以简化批量大小 = B = 3。然后,源隐藏状态是从输入隐藏状态 (B=3, S, H) 推导出来的,因此对于第一个示例,我们从 24 开始采取时间步长,因此类似于first_batch_states = Lambda(function=lambda x: x[:, 24:, :])(inputs)
等。注意local-p attention的实现比这个要复杂一些,这里我简化了。因此,主要挑战是将 24.2 转换为 24 而不会失去可微性,或者使用某种掩码操作通过点积获取索引。掩码操作是首选,因为我们必须为每个示例批量执行此操作,并且在自定义 Keras 层内有一个循环并不整洁。您对如何完成此任务有任何想法吗?我将不胜感激任何答案和评论!
解决方案
我发现有两种方法可以解决这个问题。
- 将基于原始问题中显示的对齐位置的高斯分布应用于注意力权重,使过程可区分,正如@Siddhant 建议的那样:
gaussian_estimation = lambda s: tf.exp(-tf.square(s - aligned_position) /
(2 * tf.square(self.window_width / 2)))
gaussian_factor = gaussian_estimation(0)
for i in range(1, sequence_length):
gaussian_factor = Concatenate()([gaussian_factor, gaussian_estimation(i)])
# Adjust weights via gaussian_factor: (B, S*) to allow differentiability
attention_weights = attention_weights * gaussian_factor # (B, S*)
需要注意的是,这里不涉及硬切片操作,只是根据距离进行简单的调整。
- 按照@Vlad 的建议,保持前 n 个值并将其余值清零,如何实现仅保留前 n 个值并将其余所有值清零的自定义 keras 层?:
aligned_position = self.W_p(inputs) # (B, S, H)
aligned_position = Activation('tanh')(aligned_position) # (B, S, H)
aligned_position = self.v_p(aligned_position) # (B, S, 1)
aligned_position = Activation('sigmoid')(aligned_position) # (B, S, 1)
## Only keep top D values out of the sigmoid activation, and zero-out the rest ##
aligned_position = tf.squeeze(aligned_position, axis=-1) # (B, S)
top_probabilities = tf.nn.top_k(input=aligned_position,
k=self.window_width,
sorted=False) # (values:(B, D), indices:(B, D))
onehot_vector = tf.one_hot(indices=top_probabilities.indices,
depth=sequence_length) # (B, D, S)
onehot_vector = tf.reduce_sum(onehot_vector, axis=1) # (B, S)
aligned_position = Multiply()([aligned_position, onehot_vector]) # (B, S)
aligned_position = tf.expand_dims(aligned_position, axis=-1) # (B, S, 1)
source_hidden_states = Multiply()([inputs, aligned_position]) # (B, S*=S(D), H)
## Scale back-to approximately original hidden state values ##
aligned_position += 1 # (B, S, 1)
source_hidden_states /= aligned_position # (B, S*=S(D), H)
应该注意的是,这里我们将密集层应用于所有隐藏的源状态,以获得形状(B,S,1)
而不是(B,1)
for aligned_position
。我相信这与我们可以得到的论文所建议的一样接近。
任何试图实现注意力机制的人都可以查看我的 repo https://github.com/uzaymacar/attention-mechanisms。这里的层是为多对一序列任务而设计的,但可以通过细微的调整来适应其他形式。
推荐阅读
- eclipse - 为什么 Eclipse 遇到 NPE 更新我的 Maven 项目?
- c# - MailKit 附件写入 MemoryStream
- css - 如何在顺风中保持单个网格列的高度
- c++ - std::condition_variable 是否真的在阻塞之前解锁了给定的 unique_lock 对象?
- java - 具有嵌套分组依据的收集器(java 8)
- bluedata - 如何在一个集群中公开 HDFS 并在另一个集群/租户中使用 DataTap 访问它?
- swift - 使用正则表达式替换字符串中的组仅替换最后一个组
- memory-leaks - netty PoolChunk 中积累的内存
- javascript - 如何使用数字数组通过其 id 获取对象列表
- c# - 您如何正确等待所有信号量线程完成?