首页 > 解决方案 > Keras:带有自定义生成器的 badalloc

问题描述

我在 Ubuntu 17.04 VM 上使用带有 tensorflow-gpu 后端的 keras。我创建了一个自定义生成器来从 pickle 文件中读取输入和类,但似乎出现以下错误:

在抛出 'std::ba d_alloc' what() 的实例后调用终止:std::bad_alloc

加载数据的代码可以在这里看到:

    def data_gen(self, pklPaths, batch_size=16):
        while True:
            data = []
            labels = []
            for i, pklPath in enumerate(pklPaths):
                # print(pklPath)
                image = pickle.load(open(pklPath, 'rb'))
                for i in range(batch_size):
                    # Set a label
                    data.append(image[0][0])
                    labels.append(image[1][1])
                yield np.array(data), np.array(labels)

然后在火车部分我使用了一个合适的生成器:

vm_model.fit_generator(vm.data_gen(pkl_train), validation_data=vm.data_gen(pkl_validate), epochs=15, verbose=2,
                       steps_per_epoch=(5000/16), validation_steps=(1000/16), callbacks=[tb])

生成器应该比加载所有内容具有更好的内存管理,但似乎并非如此!有任何想法吗?

标签: python-3.xtensorflowkeras

解决方案


好的,所以我发现了问题,所以我正在回答我自己的问题。基本上,前一个有一个不必要的循环,并且还不断增加数据和标签的大小,基本上将整个数据加载到内存中:

    def data_gen(self, pklPaths, batch_size=16):
    while True:
        data = []
        labels = []
        for i, pklPath in enumerate(pklPaths):
            # load pickle
            image = pickle.load(open(pklPath, 'rb'))
            # append
            data.append(image[0][0])
            labels.append(image[1])
            # if batch is complete yield data and labels and reset
            if i % batch_size == 0 and i != 0:
                yield np.array(data), np.array(labels)
                data.clear()
                labels.clear()

推荐阅读