tensorflow - tf.nn.ctc_beam_search_decoder 返回奇怪的形状
问题描述
我正在尝试在变压器架构上使用 tf.nn.ctc_beam_search_decoder 但我无法理解发生了什么:(
我必须遵循代码:
# enc_output is matrix of size (batch, frames, target_vocab_size)
# and enc_gloss_loss is a keras dense layer
gloss_scores = self.enc_gloss_loss(enc_output) # result shape (32, 118, 1615)
gloss_probabilities = tf.nn.log_softmax(gloss_scores, axis=-1)
gloss_probabilities = tf.transpose(gloss_probabilities, perm=[1, 0, 2]) # result shape (118, 32, 1615)
seq = tf.fill(tf.shape(gloss_probabilities)[1], tf.shape(gloss_probabilities)[0])
# result shape (32,) where each element equals 118
ctc_decode, _ = tf.nn.ctc_beam_search_decoder(inputs = gloss_probabilities,
sequence_length = seq,
beam_width=1,
top_paths=1,
)
ctc_decode = ctc_decode[0]
> on the first batch, ctc_decode returns a shape (32, 164)
> on the second batch, ctc_decode returns a shape (32, 115) and
> on the third batch onward, it returns a shape (32, 1) ... WHY 1 ??????
以前有人经历过吗?使用 ctc_beam_search_decoder 时我做错了吗?
提前致谢
解决方案
推荐阅读
- python - 使用 sklearn 管道时出现 ValueError:数组不得包含 infs 或 NaN
- r - 如何按降序排列这个 ggplot2 图?
- java - 为什么我可以做清单
.toArray() 但不列出 .toArray() - python - 通过使用 python 从 yaml 文件中读取数据来创建 sql 文件
- pytorch - 使用 pytorch 张量进行维度扩展
- python - 通过记忆实现最小的硬币数量以进行更改?
- drake - 如何获得动态,我们可以在下一步中应用渐变(重新打开)
- spring - junit 测试无法捕获异常
- database - 使用飞镖将sqlite数据库加载到内存中?
- keycloak - 通过跳过登录页面进行 Keycloak 登录