python - keras中的fit_generator,将所有内容加载到内存中
问题描述
我有原始数据(X_train,y_train),我正在将这些数据修改为其他数据。原始数据只是带有标签的图像。修改后的数据应该是 Siamese 网络的图像对,数量较多,内存约为 30 GB。所以不能运行这个函数来在整个原始数据上创建对。所以,我使用 keras fit_generator 认为它只会加载那个特定的批次。
我在样本对上同时运行了 model.fit 和 model.fit_generator,但我观察到两者都使用相同数量的内存。所以,我想我的代码在使用 fit_generator 时存在一些问题。下面是相关代码。你们能帮我解决这个问题吗?
下面的代码:
def create_pairs(X_train, y_train):
tr_pairs = []
tr_y = []
y_train = np.array(y_train)
digit_indices = [np.where(y_train == i)[0] for i in list(set(y_train))]
for i in range(len(digit_indices)):
n = len(digit_indices[i])
for j in range(n):
random_index = digit_indices[i][j]
anchor_image = X_train[random_index]
anchor_label = y_train[random_index]
anchor_indices = [i for i, x in enumerate(y_train) if x == anchor_label]
negate_indices = list(set(list(range(0,len(X_train)))) - set(anchor_indices))
for k in range(j+1,n):
support_index = digit_indices[i][k]
support_image = X_train[support_index]
tr_pairs += [[anchor_image,support_image]]
negate_index = random.choice(negate_indices)
negate_image = X_train[negate_index]
tr_pairs += [[anchor_image,negate_image]]
tr_y += [1,0]
return np.array(tr_pairs),np.array(tr_y)
def myGenerator():
tr_pairs, tr_y = create_pairs(X_train, y_train)
while 1:
for i in range(110): # 1875 * 32 = 60000 -> # of training samples
if i%125==0:
print("i = " + str(i))
yield [tr_pairs[i*32:(i+1)*32][:, 0], tr_pairs[i*32:(i+1)*32][:, 1]], tr_y[i*32:(i+1)*32]
model.fit_generator(myGenerator(), steps_per_epoch=110, epochs=2,
verbose=1, callbacks=None, validation_data=([te_pairs[:, 0], te_pairs[:, 1]], te_y), validation_steps=None, class_weight=None,
max_queue_size=10, workers=1, use_multiprocessing=False, shuffle=True, initial_epoch=0)
解决方案
myGenerator
返回一个生成器。
但是,您应该注意到create_pairs
正在将完整数据集加载到内存中。当你调用tr_pairs, tr_y = create_pairs(X_train, y_train)
数据集时加载,所以内存资源正在被使用。
myGenerator
简单地遍历一个已经在内存中的结构。
解决方案是自己制造create_pairs
一个发电机。
如果数据是一个 numpy 数组,我可以建议使用h5
文件从磁盘读取大块数据。
http://docs.h5py.org/en/latest/high/dataset.html#chunked-storage
推荐阅读
- flutter - 在 Flutter 中使用 isAntiAlias = false 在画布中渲染文本
- variables - 在 ZSH 中,如何将变量名用作另一个变量名的一部分?
- r - 当您在 R 中有超过 1 个唯一 id 时,是否有用于从宽转换为长的 R 代码?
- c++ - std::set_intersection 和迭代器
- javascript - 使用 Youtube 数据 api -v3 playlistItems 时如何获取观看次数和评论数
- reactjs - 使用 multer 进行图像上传时上传表单的问题。错误是Cast to string failed for value and internal server error
- c++ - pragma optimize vs pragma target 有什么区别
- python - 如何使用 python asyncio mongo 客户端电机列出数据库?
- html - 在一个表中使用两个 Flexbox
- python - 为什么 ico 图像不显示在使用 Tkinter 的窗口上?