首页 > 解决方案 > 重复最后 X 值的批处理数组

问题描述

我有一个很长的数组,我想申请批处理。但此外,我想将最后一个X值引入新批次。

假设我想要 10 个值的批次,并且我想重复最后 2 个值。

import numpy as np

np.random.seed(9)
vals = np.random.randint(0, 9, 55)

出去:[5 6 8 6 1 6 4 8 1 8 5 1 0 8 8 8 2 6 8 1 8 3 5 3 6 7 0 8 1 8 1 6 6 2 8 4 5 3 4 0 8 0 4 5 4 8 3 8 4 8 0 1 2 3 7]

那么我的目标是:

[5 6 8 6 1 6 4 8 1 8]
[1 8 5 1 0 8 8 8 2 6]
[2 6 8 1 8 3 5 3 6 7]
[6 7 0 8 1 8 1 6 6 2]
[6 2 8 4 5 3 4 0 8 0]
[8 0 4 5 4 8 3 8 4 8]
[4 8 0 1 2 3 7]

如您所见,一个数组中的最后两个值是下一个数组的前两个。

我试图找到这个的逻辑,我找到了下一个:

bs, ct = 10, 2 # ct = X in my question

print(vals[bs*0-0:bs*1-0])
print(vals[bs*1-2:bs*2-2])
print(vals[bs*2-4:bs*3-4])
print(vals[bs*3-6:bs*4-6])
print(vals[bs*4-8:bs*5-8])
print(vals[bs*5-10:bs*6-10])
print(vals[bs*6-12:bs*7-12])

因此,我试图创建循环但没有工作,我相信它必须更容易。

print(vals[0 : bs])
for i in range(1, math.ceil(len(vals)/bs)):
    print(vals[bs*i-2**i : bs*(i+1)-2**i])

我尝试了以下操作:

# Without repeating values works ok but not my goal
for i in range(0, len(vals), bs):
    print(vals[i:i+bs])
for i in range(0, len(vals), bs):
    print(vals[i-ct:i+bs])

我正在尝试很多组合ctbt但总是遇到一些麻烦。有人可以帮我吗?我知道它必须更容易,但我找不到逻辑......

在不使用 for 循环的情况下更直接地存在一些其他选项?也许numpy?我发现np.split并且我认为也许np.reshape可以工作,但问题是重复 X 值。

谢谢!

标签: pythonarraysnumpy

解决方案


这可以使用切片来完成:

>>> bs, ct = 10, 2
>>> result = [vals[i: i+bs] for i in range(0, vals.size, (bs - ct))]
>>> result
[array([5, 6, 8, 6, 1, 6, 4, 8, 1, 8]),
 array([1, 8, 5, 1, 0, 8, 8, 8, 2, 6]),
 array([2, 6, 8, 1, 8, 3, 5, 3, 6, 7]),
 array([6, 7, 0, 8, 1, 8, 1, 6, 6, 2]),
 array([6, 2, 8, 4, 5, 3, 4, 0, 8, 0]),
 array([8, 0, 4, 5, 4, 8, 3, 8, 4, 8]),
 array([4, 8, 0, 1, 2, 3, 7])]

这样的事情也可以通过numpy.lib.stride_tricks.as_strided

>>> from numpy.lib.stride_tricks import as_strided
>>> steps = vals.itemsize
>>> as_strided(
        vals, 
        shape=(math.ceil(vals.size/(bs -ct)), 10), 
        strides=(steps*(bs - ct), steps)
)
array([[  5,   6,   8,   6,   1,   6,   4,   8,   1,   8],
       [  1,   8,   5,   1,   0,   8,   8,   8,   2,   6],
       [  2,   6,   8,   1,   8,   3,   5,   3,   6,   7],
       [  6,   7,   0,   8,   1,   8,   1,   6,   6,   2],
       [  6,   2,   8,   4,   5,   3,   4,   0,   8,   0],
       [  8,   0,   4,   5,   4,   8,   3,   8,   4,   8],
       [  4,   8,   0,   1,   2,   3,   7, 121, 274,  34]])

由于最后几个值不存在,因此会有一些垃圾值。尽管如此,as_strided除非您知道自己在做什么并且绝对需要它,否则应该避免使用它。请参阅文档

引入了一种更安全的替代方案numpy.lib.stride_tricks.sliding_window_view


推荐阅读