首页 > 解决方案 > 增加 np.array 的大小

问题描述

我在形状为 (2000, 20, 28) 的 X 矩阵上运行 conv1D,批量大小为 2000、20 个时间步长和 28 个特征。我想继续使用 conv2D CNN 并将我的矩阵的维数增加到 (2000, 20, 28, 10) 有 10 个元素,我可以为这些元素构建一个 (2000, 20, 28) X 矩阵。类似地,我想得到一个大小为 (2000, 10) 的数组,即 5 倍于我用于 LSTM 和 Conv1D 网络的大小为 (2000, ) 的 y 数组。

我用来从输入 dataX、dataY 创建 20 个时间步的代码是

def LSTM_create_dataset(dataX, dataY, seq_length, step):
    Xs, ys = [], []
    for i in range(0, len(dataX) - seq_length, step):
        v = dataX.iloc[i:(i + seq_length)].values
        Xs.append(v)
        ys.append(dataY.iloc[i + seq_length])
    return np.array(Xs), np.array(ys)

我在准备创建 conv2D NN 数据的循环中使用此函数:

for ric in rics:
    dataX, dataY = get_model_data(dbInput, dbList, ric, horiz, drop_rows, triggerUp1, triggerLoss, triggerUp2 = 0)
    dataX = get_model_cleanXset(dataX, trigger)                             # Clean X matrix for insufficient data
    Xs, ys = LSTM_create_dataset(dataX, dataY, seq_length, step)        # slide over seq_length for a 3D matrix
    Xconv.append(Xs)
    yconv.append(ys)
    Xconv.append(Xs)
    yconv.append(ys)

我得到一个 (10, 2000, 20, 28) Xconv 矩阵而不是 (2000, 20, 28, 10) 目标输出矩阵 X 和一个 (10, 2000) 矩阵 y 而不是目标 (2000, 10)。我知道我可以很容易地用yconv = np.reshape(yconv, (2000, 5)). 但是 Xconv 的重塑功能Xconv = np.reshape(Xconv, (2000, 20, 28, 10))似乎很危险,因为我无法将输出可视化,甚至是错误的。我怎样才能安全地做到这一点(或者你能确认我的第一次尝试吗?提前非常感谢。

标签: pythonnumpytensorflowreshape

解决方案


如果您的 y 矩阵具有形状(10, 2000) ,那么您将无法将其塑造成您想要的(2000,5)。我在下面演示了这一点。

# create array of same shape as your original y
arr_1 = np.arange(0,2000*10).reshape(10,2000) 
print(arr_1.shape) # returns (10,2000)
arr_1 = arr_1.reshape(2000,5)

这将返回以下错误消息,因为前后形状的尺寸必须匹配非常重要。

ValueError: cannot reshape array of size 20000 into shape (2000,5)  

我不完全理解您无法可视化输出的声明 -reshape如果您愿意,您可以手动检查该函数是否正确执行,用于您的数据集(或其中的一小部分以确认该函数有效工作)使用 print陈述,如下 - 通过将输出与您的原始数据以及您期望数据之后的样子进行比较。

import numpy as np

arr = np.arange(0,2000) 
arr = arr.reshape(20,10,10,1) # reshape array to shape (20, 10, 10, 1)

# these statements let you examine the array contents at varying depths
print(arr[0][0][0])
print(arr[0][0])

推荐阅读