首页 > 解决方案 > 如何有效地组合许多 numpy 数组?

问题描述

我在尝试加载 18k 的训练数据以使用 tensorflow 进行训练时遇到了困难。这些文件是 npy 文件,命名如下:0.npy、1.npy...18000.npy。

我在网上四处寻找,想出了一个简单的代码,首先以正确的顺序读取文件,并尝试将训练数据连接在一起,但这需要很长时间。

import numpy as np
import glob
import re
import tensorflow as tf

print("TensorFlow version: {}".format(tf.__version__))

files = glob.glob('D:/project/train/*.npy')
files.sort(key=lambda var:[int(x) if x.isdigit() else x for x in 
           re.findall(r'[^0-9]|[0-9]+', var)])
# print(files)

final_dataset = []
i = 0
for file in files:    
    dataset = np.load(file, mmap_mode='r')
    print(i)
    #print("Size of dataset: {} ".format(dataset.shape))
    if (i==0):
      final_dataset = dataset
    else: 
      final_dataset = np.concatenate((final_dataset, dataset), axis = 0)
    i = i + 1

print("Size of final_dataset: {} ".format(final_dataset.shape))
np.save('combined_train.npy', final_dataset)

标签: pythontensorflow

解决方案


以任何方式“组合”数组都涉及(1),创建一个具有两个数组总大小的数组;(2)、将它们的内容复制到数组中。如果每次加载数组时都这样做,它会重复 18000 次 - 每次迭代的时间都会随着每次迭代而增长(由于更大的final_dataset)。

一个简单的解决方法是将数组附加到列表中 - 然后在最后将它们全部组合一次

dataset = []
for file in files:
    data = np.load(file, mmap_mode='r')
    dataset.append(data)

final_dataset = np.concatenate(dataset, axis=0)

要注意:一定要final_dataset真正适合你的 RAM,否则程序会崩溃。您可以通过ram_required = size_per_file * number_of_files. 相关SO。(为了进一步加快速度,您可以研究多处理- 但工作起来并不简单)


推荐阅读