python - 读取 .h5 文件非常慢
问题描述
我的数据以 .h5 格式存储。我使用数据生成器来拟合模型,它非常慢。下面提供了我的代码片段。
def open_data_file(filename, readwrite="r"):
return tables.open_file(filename, readwrite)
data_file_opened = open_data_file(os.path.abspath("../data/data.h5"))
train_generator, validation_generator, n_train_steps, n_validation_steps = get_training_and_validation_generators(
data_file_opened,
......)
在哪里:
def get_training_and_validation_generators(data_file, batch_size, ...):
training_generator = data_generator(data_file, training_list,....)
data_generator 函数如下:
def data_generator(data_file, index_list,....):
orig_index_list = index_list
while True:
x_list = list()
y_list = list()
if patch_shape:
index_list = create_patch_index_list(orig_index_list, data_file, patch_shape,
patch_overlap, patch_start_offset,pred_specific=pred_specific)
else:
index_list = copy.copy(orig_index_list)
while len(index_list) > 0:
index = index_list.pop()
add_data(x_list, y_list, data_file, index, augment=augment, augment_flip=augment_flip,
augment_distortion_factor=augment_distortion_factor, patch_shape=patch_shape,
skip_blank=skip_blank, permute=permute)
if len(x_list) == batch_size or (len(index_list) == 0 and len(x_list) > 0):
yield convert_data(x_list, y_list, n_labels=n_labels, labels=labels, num_model=num_model,overlap_label=overlap_label)
x_list = list()
y_list = list()
add_data() 如下:
def add_data(x_list, y_list, data_file, index, augment=False, augment_flip=False, augment_distortion_factor=0.25,
patch_shape=False, skip_blank=True, permute=False):
'''
add qualified x,y to the generator list
'''
# pdb.set_trace()
data, truth = get_data_from_file(data_file, index, patch_shape=patch_shape)
if np.sum(truth) == 0:
return
if augment:
affine = np.load('affine.npy')
data, truth = augment_data(data, truth, affine, flip=augment_flip, scale_deviation=augment_distortion_factor)
if permute:
if data.shape[-3] != data.shape[-2] or data.shape[-2] != data.shape[-1]:
raise ValueError("To utilize permutations, data array must be in 3D cube shape with all dimensions having "
"the same length.")
data, truth = random_permutation_x_y(data, truth[np.newaxis])
else:
truth = truth[np.newaxis]
if not skip_blank or np.any(truth != 0):
x_list.append(data)
y_list.append(truth)
模型训练:
def train_model(model, model_file,....):
model.fit(training_generator,
steps_per_epoch=steps_per_epoch,
epochs=n_epochs,
verbose = 2,
validation_data=validation_generator,
validation_steps=validation_steps)
我的数据集很大:data.h5 为 55GB。完成一个纪元大约需要 7000 秒。并且在 6 个 epoch 之后出现分段错误错误。批量大小设置为 1,否则会出现资源耗尽错误。是否有一种有效的方法来读取生成器中的 data.h5 以便训练更快并且不会导致内存不足错误?
解决方案
这是我回答的开始。我查看了您的代码,您有很多调用来读取 .h5 数据。training_list
据我统计,生成器对和 上的每个循环进行 6 次读取调用validation_list
。所以,在一个训练循环上,这几乎是 20k 次调用。(对我来说)不清楚是否在每个训练循环上都调用了生成器。如果是,则乘以 2268 个循环。
HDF5 文件读取的效率取决于读取数据的调用次数(而不仅仅是数据量)。换句话说,一次调用读取 1GB 的数据比一次读取 1000 次调用 x 1MB 的数据要快。因此,我们需要确定的第一件事是从 HDF5 文件中读取数据所花费的时间(与您的 7000 相比)。
我隔离了读取数据文件的 PyTables 调用。由此,我构建了一个简单的程序来模仿生成器函数的行为。目前,它在整个样本列表上进行单个训练循环。如果要运行更长的测试,请增加n_train
和值。n_epoch
(注意:代码语法是正确的。但是没有文件,所以无法验证逻辑。我认为它是正确的,但您可能需要修复一些小错误。)
请参阅下面的代码。它应该独立运行(所有依赖项都已导入)。它打印基本的计时数据。运行它来对您的生成器进行基准测试。
import tables as tb
import numpy as np
from random import shuffle
import time
with tb.open_file('../data/data.h5', 'r') as data_file:
n_train = 1
n_epochs = 1
loops = n_train*n_epochs
for e_cnt in range(loops):
nb_samples = data_file.root.truth.shape[0]
sample_list = list(range(nb_samples))
shuffle(sample_list)
split = 0.80
n_training = int(len(sample_list) * split)
training_list = sample_list[:n_training]
validation_list = sample_list[n_training:]
start = time.time()
for index_list in [ training_list, validation_list ]:
shuffle(index_list)
x_list = list()
y_list = list()
while len(index_list) > 0:
index = index_list.pop()
brain_width = data_file.root.brain_width[index]
x = np.array([modality_img[index,0,
brain_width[0,0]:brain_width[1,0]+1,
brain_width[0,1]:brain_width[1,1]+1,
brain_width[0,2]:brain_width[1,2]+1]
for modality_img in [data_file.root.t1,
data_file.root.t1ce,
data_file.root.flair,
data_file.root.t2]])
y = data_file.root.truth[index, 0,
brain_width[0,0]:brain_width[1,0]+1,
brain_width[0,1]:brain_width[1,1]+1,
brain_width[0,2]:brain_width[1,2]+1]
x_list.append(data)
y_list.append(truth)
print(f'For loop:{e_cnt}')
print(f'Time to read all data={time.time()-start:.2f}')
推荐阅读
- common-lisp - 是否可以将 Common Lisp 的 `getf` 函数默认值更改为 NIL 以外的其他值?
- regex - notepad ++如何保持所有行以#开头并删除该行的其余部分
- mobile - 无法在 UI Automator 查看器中的 DateWidget 上选择数字
- mysql - Sequelize 查询返回错误响应
- validation - JSF2.2 FacesMessages 未显示(复合组件)
- android - 如何在 Android API 30+ 中在屏幕顶部创建 Toast
- python - 带有 IndexError 消息的并行列表
- browser - 'touchmove' UIEvent 没有前面的'touchstart'?
- ios - 我在 main.storyboard -> UIImageView -> Xcode 上的图像范围上看不到应用程序图标图像
- extjs - ExtJS 7.2 图表缺少标记