首页 > 解决方案 > 获取火车数据生成器并保存它们

问题描述

我正在开发一种更快的 RCNN 算法。我有一个生成 ground_truth 锚点的函数 get_anchor_gt。我在以下代码中调用该函数

# Get train data generator which generate X, Y, image_data
data_gen_train = get_anchor_gt(train_imgs, C, get_img_output_length, mode='train')

当我执行以下代码时

X, Y, image_data, debug_img, debug_num_pos = next(data_gen_train)

它开始训练并保存在 jupyter 笔记本中,大小增加到 200 mbs 和笔记本粉碎。有没有办法可以保存在硬盘中的某个地方,然后再加载它们?

这是 get_anchor_gt

#Generate the ground_truth anchors

def get_anchor_gt(all_img_data, C, img_length_calc_function, mode='train'):
""" Yield the ground-truth anchors as Y (labels)

Args:
    all_img_data: list(filepath, width, height, list(bboxes))
    C: config
    img_length_calc_function: function to calculate final layer's feature map (of base model) size according to input image size
    mode: 'train' or 'test'; 'train' mode need augmentation

Returns:
    x_img: image data after resized and scaling (smallest size = 300px)
    Y: [y_rpn_cls, y_rpn_regr]
    img_data_aug: augmented image data (original image with augmentation)
    debug_img: show image for debug
    num_pos: show number of positive anchors for debug
"""
while True:

    for img_data in all_img_data:
        try:

            # read in image, and optionally add augmentation

            if mode == 'train':
                img_data_aug, x_img = augment(img_data, C, augment=True)
            else:
                img_data_aug, x_img = augment(img_data, C, augment=False)

            (width, height) = (img_data_aug['width'], img_data_aug['height'])
            (rows, cols, _) = x_img.shape

            assert cols == width
            assert rows == height

            # get image dimensions for resizing
            (resized_width, resized_height) = get_new_img_size(width, height, C.im_size)

            # resize the image so that smalles side is length = 300px
            x_img = cv2.resize(x_img, (resized_width, resized_height), interpolation=cv2.INTER_CUBIC)
            debug_img = x_img.copy()

            try:
                y_rpn_cls, y_rpn_regr, num_pos = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function)
            except:
                continue

            # Zero-center by mean pixel, and preprocess image

            x_img = x_img[:,:, (2, 1, 0)]  # BGR -> RGB
            x_img = x_img.astype(np.float32)
            x_img[:, :, 0] -= C.img_channel_mean[0]
            x_img[:, :, 1] -= C.img_channel_mean[1]
            x_img[:, :, 2] -= C.img_channel_mean[2]
            x_img /= C.img_scaling_factor

            x_img = np.transpose(x_img, (2, 0, 1))
            x_img = np.expand_dims(x_img, axis=0)

            y_rpn_regr[:, y_rpn_regr.shape[1]//2:, :, :] *= C.std_scaling

            x_img = np.transpose(x_img, (0, 2, 3, 1))
            y_rpn_cls = np.transpose(y_rpn_cls, (0, 2, 3, 1))
            y_rpn_regr = np.transpose(y_rpn_regr, (0, 2, 3, 1))

            yield np.copy(x_img), [np.copy(y_rpn_cls), np.copy(y_rpn_regr)], img_data_aug, debug_img, num_pos

        except Exception as e:
            print(e)
            continue

标签: pythonsaveconv-neural-networkgenerator

解决方案


查看您的代码,我认为它默默地失败了:

try:
    y_rpn_cls, y_rpn_regr, num_pos = calc_rpn(C, img_data_aug, width, height, resized_width, resized_height, img_length_calc_function)
except:
    continue

您没有提供任何有关异常的通知,因此如果此行失败(无论出于何种原因,甚至是拼写错误或参数数量错误),您将永远看不到发生了什么,脚本将继续运行 for 循环。

首先,您可以删除try/except并明确查看问题所在,使用该信息您可以继续调试。

更多提示:

  1. 不要使用try/except,而是:
try:
    # ..code
except Exception as exc:
    # something went wrong, at least print the exception
    print(exc)
  1. python有一个调试器,pdb你可以开始逐行调试程序并检查变量值。开始python -m pdb yourprogram.py并继续调试。为了做到这一点,您应该学习一些pdb命令,但这现在很有用,请查看TutorialOfficial Docs

推荐阅读