首页 > 解决方案 > 张量流中部分输入的排列

问题描述

我有一个 NN inputs = [A, B],其中AB都是 Nd 数组,其形状(N, ...)- 即大小的第一维N(= 训练事件的数量)是对齐的。

现在我想用[A, B]targety = np.ones(N)[A, permute(B)]on 上的输入来训练我的 NN y = np.zeros(N)。我可以通过构建我的输入来实现这一点,例如:

inputs = [np.vstack([A, A]), np.vstack(B, np.random.permutation(B)]
y = np.concatenate([np.ones(N), np.zeros(N)])

但是,这意味着要对设备进行大量复制。有没有办法直接在设备上通过 tensorflow 实现?我知道tf.data它的洗牌能力,但这并不符合我的意图。培训仍应在整体上洗牌的输入+目标上进行。

标签: pythonnumpytensorflow

解决方案


事实证明,在 a 中可以有多个嵌套的改组tf.data.Dataset,从而解决了这个问题。

inputs = tf.data.Dataset.zip(
    tf.data.Dataset.concatenate(A, A), 
    tf.data.Dataset.concatenate(B, B.shuffle(N)),
    tf.data.Dataset.concatenate(ones, zeros)
    )

然后像这样一起洗牌:

inputs.shuffle(2*N)

推荐阅读