首页 > 解决方案 > 如何在 tensorflow / keras 中移动像 pandas.shift 这样的张量?(无需将最后一行移到第一行,如 tf.roll)

问题描述

我想在给定轴上移动张量。在 pandas 或 numpy 中很容易做到这一点。像这样:

import numpy as np
import pandas as pd

data = np.arange(0, 6).reshape(-1, 2)
pd.DataFrame(data).shift(1).fillna(0).values

输出是:

数组([[0., 0.],
[0., 1.],
[2., 3.]])

但在 tensorflow 中,我找到的最接近的解决方案是tf.roll. 但它将最后一行移到第一行。(我不想要那个)。所以我必须使用类似的东西

tf.roll + tf.slice(remove the last row) + tf.concat(add tf.zeros to the first row).

真的很难看

有没有更好的方法来处理shifttensorflow 或 keras?

谢谢。

标签: pandasnumpytensorflowkerasshift

解决方案


将公认的答案推广到任意张量形状、所需位移和要位移的轴:

import tensorflow as tf

def tf_shift(tensor, shift=1, axis=0):
    dim = len(tensor.shape)

    if axis > dim:
        raise ValueError(
            f'Value of axis ({axis}) must be <= number of tensor axes ({dim})'
        )

    mask_dim = dim - axis
    mask_shape = tensor.shape[-mask_dim:]
    zero_dim = min(shift, mask_shape[0])

    mask = tf.concat(
        [tf.zeros(tf.TensorShape(zero_dim) + mask_shape[1:]),
         tf.ones(tf.TensorShape(mask_shape[0] - zero_dim) + mask_shape[1:])],
        axis=0
    )

    for i in range(dim - mask_dim):
        mask = tf.expand_dims(mask, axis=0)

    return tf.multiply(
        tf.roll(tensor, shift, axis),
        mask
    )

推荐阅读