python - 广播和连接不规则张量
问题描述
我有一个参差不齐的尺寸张量[BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]
。我想用另一个形状张量的数据来增加最后一个轴[BATCH_SIZE, AUG_DIM]
。给定示例的每个时间步都会增加相同的值。
TIME_STEPS
如果每个示例的张量都没有参差不齐,我可以简单地重塑第二个张量,tf.repeat
然后使用tf.concat
:
import tensorflow as tf
# create data
# shape: [BATCH_SIZE, TIME_STEPS, EMBEDDING_DIM]
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
# shape: [BATCH_SIZE, 1, AUG_DIM]
aug = tf.constant([[[8]], [[9]]])
# concat
aug = tf.repeat(aug, emb.shape[1], axis=1)
emb_aug = tf.concat([emb, aug], axis=-1)
这在衣衫褴褛时不起作用,emb
因为emb.shape[1]
它是未知的并且因示例而异:
# rag and remove padding
emb = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))
# reshape for augmentation - this doesn't work
aug = tf.repeat(aug, emb.shape[1], axis=1)
ValueError:尝试将具有不受支持的类型 (<class 'NoneType'>) 的值 (None) 转换为张量。
目标是创建一个参差不齐的张量emb_aug
,如下所示:
<tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3 ,9]]]>
有任何想法吗?
解决方案
最简单的方法是通过使用使您的参差不齐的张量成为常规张量tf.RaggedTensor.to_tensor()
,然后执行其余的解决方案。我假设您需要张量保持参差不齐。关键是row_lengths
在你的参差不齐的张量中找到每批的 ,然后使用这些信息来使你的增广张量参差不齐。
示例:
import tensorflow as tf
# data
emb = tf.constant([[[1, 2, 3], [4, 5, 6]], [[1, 2, 3], [0, 0, 0]]])
aug = tf.constant([[[8]], [[9]]])
# make embeddings ragged for testing
emb_r = tf.RaggedTensor.from_tensor(emb, padding=(0, 0, 0))
print(emb_r.shape)
# (2, None, 3)
在这里,我们将使用row_lengths
和的组合sequence_mask
来创建一个新的不规则张量。
# find the row lengths of the embeddings
rl = emb_r.row_lengths()
print(rl)
# tf.Tensor([2 1], shape=(2,), dtype=int64)
# find the biggest row length
max_rl = tf.math.reduce_max(rl)
print(max_rl)
# tf.Tensor(2, shape=(), dtype=int64)
# repeat the augmented data `max_rl` number of times
aug_t = tf.repeat(aug, repeats=max_rl, axis=1)
print(aug_t)
# tf.Tensor(
# [[[8]
# [8]]
#
# [[9]
# [9]]], shape=(2, 2, 1), dtype=int32)
# create a mask
msk = tf.sequence_mask(rl)
print(msk)
# tf.Tensor(
# [[ True True]
# [ True False]], shape=(2, 2), dtype=bool)
从这里我们可以tf.ragged.boolean_mask
用来使增强的数据参差不齐
# make the augmented data a ragged tensor
aug_r = tf.ragged.boolean_mask(aug_t, msk)
print(aug_r)
# <tf.RaggedTensor [[[8], [8]], [[9]]]>
# concatenate!
output = tf.concat([emb_r, aug_r], 2)
print(output)
# <tf.RaggedTensor [[[1, 2, 3, 8], [4, 5, 6, 8]], [[1, 2, 3, 9]]]>
您可以在此处找到支持不规则张量的 tensorflow 方法列表
推荐阅读
- excel - 依赖下拉excel
- hadoop - 在配置单元中更改列名后,列的值变为 NULL
- python - 将两个数据框的列与公差进行比较
- sql-server - SQL Server 2019 安装microsoft odbc driver 17 下载错误
- python - 如何通过 SSH 使用 Python Selenium
- haskell - Haskell中自定义列表中的元素相乘
- javascript - 在useEffect()中状态更改后酶mount()重新渲染?
- sql - 基于多个字段的重复相似数据
- r - 如何在R中的矩阵中找到互补行
- selenium - Selenium 没有点击 Instagram 上的 LIKE 按钮