python - 增加 np.array 的大小
问题描述
我在形状为 (2000, 20, 28) 的 X 矩阵上运行 conv1D,批量大小为 2000、20 个时间步长和 28 个特征。我想继续使用 conv2D CNN 并将我的矩阵的维数增加到 (2000, 20, 28, 10) 有 10 个元素,我可以为这些元素构建一个 (2000, 20, 28) X 矩阵。类似地,我想得到一个大小为 (2000, 10) 的数组,即 5 倍于我用于 LSTM 和 Conv1D 网络的大小为 (2000, ) 的 y 数组。
我用来从输入 dataX、dataY 创建 20 个时间步的代码是
def LSTM_create_dataset(dataX, dataY, seq_length, step):
Xs, ys = [], []
for i in range(0, len(dataX) - seq_length, step):
v = dataX.iloc[i:(i + seq_length)].values
Xs.append(v)
ys.append(dataY.iloc[i + seq_length])
return np.array(Xs), np.array(ys)
我在准备创建 conv2D NN 数据的循环中使用此函数:
for ric in rics:
dataX, dataY = get_model_data(dbInput, dbList, ric, horiz, drop_rows, triggerUp1, triggerLoss, triggerUp2 = 0)
dataX = get_model_cleanXset(dataX, trigger) # Clean X matrix for insufficient data
Xs, ys = LSTM_create_dataset(dataX, dataY, seq_length, step) # slide over seq_length for a 3D matrix
Xconv.append(Xs)
yconv.append(ys)
Xconv.append(Xs)
yconv.append(ys)
我得到一个 (10, 2000, 20, 28) Xconv 矩阵而不是 (2000, 20, 28, 10) 目标输出矩阵 X 和一个 (10, 2000) 矩阵 y 而不是目标 (2000, 10)。我知道我可以很容易地用yconv = np.reshape(yconv, (2000, 5))
. 但是 Xconv 的重塑功能Xconv = np.reshape(Xconv, (2000, 20, 28, 10))
似乎很危险,因为我无法将输出可视化,甚至是错误的。我怎样才能安全地做到这一点(或者你能确认我的第一次尝试吗?提前非常感谢。
解决方案
如果您的 y 矩阵具有形状(10, 2000) ,那么您将无法将其塑造成您想要的(2000,5)。我在下面演示了这一点。
# create array of same shape as your original y
arr_1 = np.arange(0,2000*10).reshape(10,2000)
print(arr_1.shape) # returns (10,2000)
arr_1 = arr_1.reshape(2000,5)
这将返回以下错误消息,因为前后形状的尺寸必须匹配非常重要。
ValueError: cannot reshape array of size 20000 into shape (2000,5)
我不完全理解您无法可视化输出的声明 -reshape
如果您愿意,您可以手动检查该函数是否正确执行,用于您的数据集(或其中的一小部分以确认该函数有效工作)使用 print陈述,如下 - 通过将输出与您的原始数据以及您期望数据之后的样子进行比较。
import numpy as np
arr = np.arange(0,2000)
arr = arr.reshape(20,10,10,1) # reshape array to shape (20, 10, 10, 1)
# these statements let you examine the array contents at varying depths
print(arr[0][0][0])
print(arr[0][0])
推荐阅读
- python-3.x - 在 Window 7 上运行“从 skimage 导入数据”代码时发现“ModuleNotFoundError”
- dc.js - 为什么在折线图的上方和下方都呈现黑色区域?
- c - Xcode 中的 Git 忽略
- r - 如何在 R 的向量中包含变量?
- java - 如何在代码中的确切位置将数据从 firebase 提取到 android studio 中?
- python - 图聚类的 Louvain 算法在 Spark/Scala 和 Python 中运行时给出完全不同的结果,为什么会发生这种情况?
- python - 熊猫如何避免在 groupby nlargest n 中应用
- r - 有没有办法让 RMarkdown 选项卡即使在添加 Shiny 运行时也能显示?
- python - 获取绝对路径时出现问题?
- jquery - 使用Jquery离开ID div时如何清除mousemove