pandas - 如何在 tensorflow / keras 中移动像 pandas.shift 这样的张量?(无需将最后一行移到第一行,如 tf.roll)
问题描述
我想在给定轴上移动张量。在 pandas 或 numpy 中很容易做到这一点。像这样:
import numpy as np
import pandas as pd
data = np.arange(0, 6).reshape(-1, 2)
pd.DataFrame(data).shift(1).fillna(0).values
输出是:
数组([[0., 0.],
[0., 1.],
[2., 3.]])
但在 tensorflow 中,我找到的最接近的解决方案是tf.roll
. 但它将最后一行移到第一行。(我不想要那个)。所以我必须使用类似的东西
tf.roll + tf.slice(remove the last row) + tf.concat(add tf.zeros to the first row)
.
真的很难看。
有没有更好的方法来处理shift
tensorflow 或 keras?
谢谢。
解决方案
将公认的答案推广到任意张量形状、所需位移和要位移的轴:
import tensorflow as tf
def tf_shift(tensor, shift=1, axis=0):
dim = len(tensor.shape)
if axis > dim:
raise ValueError(
f'Value of axis ({axis}) must be <= number of tensor axes ({dim})'
)
mask_dim = dim - axis
mask_shape = tensor.shape[-mask_dim:]
zero_dim = min(shift, mask_shape[0])
mask = tf.concat(
[tf.zeros(tf.TensorShape(zero_dim) + mask_shape[1:]),
tf.ones(tf.TensorShape(mask_shape[0] - zero_dim) + mask_shape[1:])],
axis=0
)
for i in range(dim - mask_dim):
mask = tf.expand_dims(mask, axis=0)
return tf.multiply(
tf.roll(tensor, shift, axis),
mask
)
推荐阅读
- javascript - 将 settings.json 文件与 Node.js 和 pkg 一起使用
- tmux - 有没有办法自动化在 tmux 中创建会话和窗格的过程
- python - Python 中的异常中断
- jenkins - 无法在 Jenkins 从节点上运行赛普拉斯测试?
- javascript - 在 Javascript 中使用 Maps 实现二维数组
- sql - 过滤相同的查询 3 次不同的时间。表现?
- python - signal.medfilt2d 得到 ValueError: object too deep for desired array
- machine-learning - 如何在 Imblearn 中获取已创建样本的索引
- reactjs - 如何使用 SVG 代码作为元素的来源?
- java - JPA 查询:java.lang.IllegalArgumentException:提供了至少 1 个参数,但查询中仅存在 0 个参数