首页 > 解决方案 > Tensorflow Keras 模型在使用 model.predict() 时出现 OOM,尽管使用 model.fit() 进行训练时运行没有问题


我正在使用 tf.keras 训练一个图像分割模型,并用自定义数据生成器读取和增强图像。训练模型工作正常(即没有内存问题),但在测试集上进行预测时,我的 GPU(8GB,见下方 nvidia-smi 输出)显存不足。无论是训练结束后直接预测,还是重启内核后用 model.load_weights() 加载模型再调用 model.predict();也无论使用与训练时相同的批量大小(4,训练期间约占用 6GB 显存)还是批量大小 1,都会出现这种情况:两种批量大小下都试图分配超过 8GB 的显存。

在训练期间,显存使用量稳定在 6GB 左右;但使用 model.predict() 时,显存占用从 6GB 左右开始,大约 10 秒后跳升到 8GB,随后抛出 ResourceExhaustedError(见后面的堆栈跟踪)。这对我来说非常违反直觉。我通过谷歌找到的提示(例如重启 Python、从权重文件加载模型后再预测,以释放之前占用的显存)都没有奏效,所以任何帮助都会很棒。



Mon Aug  9 14:27:29 2021       
| NVIDIA-SMI 471.11       Driver Version: 471.11       CUDA Version: 11.4     |
| GPU  Name            TCC/WDDM | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA GeForce ... WDDM  | 00000000:0D:00.0  On |                  N/A |
| 56%   50C    P8    24W / 220W |   8057MiB /  8192MiB |      4%      Default |
|                               |                      |                  N/A |
| Processes:                                                                  |
|  GPU   GI   CI        PID   Type   Process name                  GPU Memory |
|        ID   ID                                                   Usage      |
|    0   N/A  N/A      1576    C+G   Insufficient Permissions        N/A      |
|    0   N/A  N/A      2292    C+G   ...kyb3d8bbwe\Calculator.exe    N/A      |
|    0   N/A  N/A      8316    C+G   C:\Windows\explorer.exe         N/A      |
|    0   N/A  N/A      8736    C+G   ...lPanel\SystemSettings.exe    N/A      |
|    0   N/A  N/A     11220    C+G   ...bbwe\Microsoft.Photos.exe    N/A      |
|    0   N/A  N/A     11740    C+G   ...5n1h2txyewy\SearchApp.exe    N/A      |
|    0   N/A  N/A     12280    C+G   ...ekyb3d8bbwe\YourPhone.exe    N/A      |
|    0   N/A  N/A     12820    C+G   ...8wekyb3d8bbwe\GameBar.exe    N/A      |
|    0   N/A  N/A     13820    C+G   ...perience\NVIDIA Share.exe    N/A      |
|    0   N/A  N/A     14552    C+G   ...nputApp\TextInputHost.exe    N/A      |
|    0   N/A  N/A     14848    C+G   ...y\ShellExperienceHost.exe    N/A      |
|    0   N/A  N/A     14976    C+G   ...zilla Firefox\firefox.exe    N/A      |
|    0   N/A  N/A     15688    C+G   ...udibleRT.WindowsPhone.exe    N/A      |
|    0   N/A  N/A     16628      C   ...Data\Anaconda3\python.exe    N/A      |
|    0   N/A  N/A     23648    C+G   ...aming\Spotify\Spotify.exe    N/A      |


class DataGenerator(tf.keras.utils.Sequence):
    """Keras ``Sequence`` that loads (and optionally augments) images for segmentation.

    In ``"train"`` mode each item is an ``(x, tgt)`` tuple of an image batch and
    its mask batch; in ``"predict"`` mode each item is only the image batch ``x``.
    """

    def __init__(self, df, batch_size, mode="train", shuffle=True, augment=False, p_augment=0,
                 union=False, greyscale=False, normalize=True, dims=(256, 1600)):
        """DataGenerator usable for train/val/test splits.

        Args:
            df: DataFrame with an ``img_id`` column holding image paths and the
                ``c1`` .. ``c_all`` columns that are passed to ``build_masks``.
            batch_size: number of samples per batch.
            mode: ``"train"`` (yield images and masks) or ``"predict"`` (images only).
            shuffle: shuffle the sample order before the first epoch and after each epoch.
            augment: randomly flip images together with their masks (train mode only).
            p_augment: probability of each flip when ``augment`` is on; must be in (0, 1].
            union: build one union mask (1 class) instead of 4 per-class masks.
            greyscale: load images as single-channel greyscale instead of RGB.
            normalize: scale pixel values from [0, 255] to [0, 1].
            dims: (height, width) of the images.
        """
        self.df = df
        self.length = len(df)
        self.BATCH_SIZE = batch_size
        self.mode = mode
        self.shuffle = shuffle
        self.augment = augment
        self.p_augment = p_augment
        self.union = union
        self.greyscale = greyscale
        self.normalize = normalize
        self.dims = dims
        self.num_channels = 1 if greyscale else 3
        self.num_classes = 1 if union else 4
        assert mode in ["train", "predict"], "DataGenerator mode is unsupported. Set it to \"train\" or \"predict\"."
        if augment:
            assert p_augment > 0 and p_augment <= 1, "Augmentation is turned on, but probability is zero or larger than one."
        # Keras only calls on_epoch_end() *after* each epoch, so initialise (and,
        # if requested, shuffle) the sample order here for the first epoch too.
        self.on_epoch_end()

    def __len__(self):
        """Return the number of batches per epoch.

        In train mode an incomplete final batch is dropped; in predict mode it is
        kept so that every row of ``df`` receives a prediction.
        """
        if self.mode == "predict":
            return (self.length + self.BATCH_SIZE - 1) // self.BATCH_SIZE
        return self.length // self.BATCH_SIZE

    def on_epoch_end(self):
        """Reset the sample order to ``df``'s index, shuffled if ``shuffle`` is set."""
        # called by Keras at the end of every epoch (and once from __init__)
        self.indices = self.df.index.values.tolist()
        if self.shuffle:
            random.shuffle(self.indices)

    def _load_img(self, img_path):
        """Load an image as float32 RGB (or greyscale), optionally scaled to [0, 1].

        Raises:
            FileNotFoundError: if OpenCV cannot read ``img_path``.
        """
        if self.greyscale:
            img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        else:
            img = cv2.imread(img_path)
        if img is None:
            # cv2.imread signals failure by returning None instead of raising
            raise FileNotFoundError(f"could not read image: {img_path}")
        if not self.greyscale:
            # OpenCV loads colour images in BGR channel order
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img = img.astype(np.float32)
        if self.normalize:
            img /= 255.
        return img

    def _gen_x(self, idx_list):
        """Stack the images for the given ``df`` indices into a float32 batch array."""
        # sized by idx_list so a short final batch (predict mode) has no garbage rows
        res = np.empty(shape=(len(idx_list), *self.dims, self.num_channels), dtype=np.float32)
        for i, df_idx in enumerate(idx_list):
            img = self._load_img(self.df.loc[df_idx]["img_id"])
            if self.greyscale:
                # greyscale images are loaded as (H, W); add the channel axis
                img = np.expand_dims(img, axis=-1)
            res[i] = img
        return res

    def _gen_tgt(self, idx_list):
        """Stack the masks for the given ``df`` indices into a float32 batch array."""
        res = np.empty(shape=(len(idx_list), *self.dims, self.num_classes), dtype=np.float32)
        for i, df_idx in enumerate(idx_list):
            rles = self.df.loc[df_idx]["c1":"c_all"]
            if self.union:
                # return mask of all defect pixels (no diff between defect class)
                masks = build_masks(rles, union_only=True)
            else:
                masks = build_masks(rles)
            res[i] = masks
        return res

    def __getitem__(self, idx):
        """Create batch ``idx``: ``x`` in predict mode, ``(x, tgt)`` in train mode."""
        # get indices of batch (self.indices is the (possibly shuffled) list of df indices)
        idxs = self.indices[idx*self.BATCH_SIZE:(idx+1)*self.BATCH_SIZE]
        x = self._gen_x(idxs)
        if self.mode == "predict":
            return x
        # mode is train -> get target data and possibly augment
        tgt = self._gen_tgt(idxs)
        if self.augment:
            x, tgt = self._augment_batch(x, tgt)
        return x, tgt

    def _augment_batch(self, _x, _tgt):
        """Randomly flip every image together with its masks, in place.

        Each sample is flipped up-down with probability ``p_augment`` and,
        independently, left-right with probability ``p_augment``.
        """
        for i in range(len(_x)):
            # flip up-down (reverse the row axis)
            if random.random() < self.p_augment:
                # np.flip returns a view of the same memory, so copy before writing back;
                # unlike cv2.flip it also keeps a trailing channel axis of size 1
                _x[i] = np.flip(_x[i], axis=0).copy()
                _tgt[i] = np.flip(_tgt[i], axis=0).copy()
            # flip left-right (reverse the column axis)
            if random.random() < self.p_augment:
                _x[i] = np.flip(_x[i], axis=1).copy()
                _tgt[i] = np.flip(_tgt[i], axis=1).copy()
        return _x, _tgt


from copy import deepcopy

# configs for train/val datagens
train_config = {"mode": "train",
               "batch_size": 4,
               "p_augment": 0.5,
               "union": False,
               "greyscale": False,
               "normalize": True,
               "dims": (256,1600)}
# NOTE(review): "augment" and "shuffle" are not set, so DataGenerator's defaults
# (augment=False, shuffle=True) apply and "p_augment" has no effect — confirm intended.

val_config = deepcopy(train_config)
val_config["shuffle"] = False
val_config["augment"] = False

train_datagen = DataGenerator(df_train, **train_config)
val_datagen = DataGenerator(df_val, **val_config)

# returns model with correct image dims and number of classes
model = get_model_from_generator(train_datagen)
model.compile(optimizer=Adam(learning_rate=1e-4), loss=bce_dice_loss,
              metrics=["binary_crossentropy", dice_coef])

# stop after 10 epochs without val_loss improvement; save only improving weights
cb_es = tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=10)
cb_best = tf.keras.callbacks.ModelCheckpoint("models/fcn_rgb/cp_{epoch:02d}_{val_loss:.3f}.ckpt", monitor="val_loss",
                                             save_weights_only=True, save_best_only=True)
# val_datagen provides the "val_loss" that both callbacks monitor
history = model.fit(x=train_datagen, validation_data=val_datagen,
                    callbacks=[cb_es, cb_best], epochs=100)


Epoch 1/100
2513/2513 [==============================] - 541s 207ms/step - loss: 0.2413 - binary_crossentropy: 0.5235 - dice_coef: 0.0205 - val_loss: -0.0034 - val_binary_crossentropy: 0.1481 - val_dice_coef: 0.0775
Epoch 2/100
2513/2513 [==============================] - 518s 206ms/step - loss: -0.1231 - binary_crossentropy: 0.0864 - dice_coef: 0.1663 - val_loss: -0.2862 - val_binary_crossentropy: 0.0627 - val_dice_coef: 0.3175



test_config = {"mode": "predict",
               "batch_size": 1,
               "p_augment": 0,
               "union": False,
               "greyscale": False,
               "normalize": True,
               "dims": (256,1600)}

test_datagen = DataGenerator(df_test, **test_config)
model = get_model_from_generator(train_datagen)
# NOTE(review): a freshly built model has untrained weights; load the trained
# checkpoint here, e.g. model.load_weights("models/fcn_rgb/cp_XX_X.XXX.ckpt").

# model.predict() (TF 2.4) keeps every batch's output tensor on the GPU until the
# whole dataset has been processed and concatenates them there, so GPU memory grows
# with the size of the test set (full-resolution masks per image). predict_on_batch()
# returns a NumPy array, moving each batch's result to host memory immediately.
preds = np.concatenate(
    [model.predict_on_batch(test_datagen[i]) for i in range(len(test_datagen))],
    axis=0,
)


ResourceExhaustedError                    Traceback (most recent call last)
<ipython-input-16-d0a77a4d2cd0> in <module>
----> 1 preds = model.predict(test_datagen)

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\keras\engine\training.py in predict(self, x, batch_size, verbose, steps, callbacks, max_queue_size, workers, use_multiprocessing)
   1627           for step in data_handler.steps():
   1628             callbacks.on_predict_batch_begin(step)
-> 1629             tmp_batch_outputs = self.predict_function(iterator)
   1630             if data_handler.should_sync:
   1631               context.async_wait()

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py in __call__(self, *args, **kwds)
    826     tracing_count = self.experimental_get_tracing_count()
    827     with trace.Trace(self._name) as tm:
--> 828       result = self._call(*args, **kwds)
    829       compiler = "xla" if self._experimental_compile else "nonXla"
    830       new_tracing_count = self.experimental_get_tracing_count()

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\def_function.py in _call(self, *args, **kwds)
    860       # In this case we have not created variables on the first call. So we can
    861       # run the first trace but we should fail if variables are created.
--> 862       results = self._stateful_fn(*args, **kwds)
    863       if self._created_variables:
    864         raise ValueError("Creating variables on a non-first call to a function"

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in __call__(self, *args, **kwargs)
   2940       (graph_function,
   2941        filtered_flat_args) = self._maybe_define_function(args, kwargs)
-> 2942     return graph_function._call_flat(
   2943         filtered_flat_args, captured_inputs=graph_function.captured_inputs)  # pylint: disable=protected-access

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in _call_flat(self, args, captured_inputs, cancellation_manager)
   1916         and executing_eagerly):
   1917       # No tape is watching; skip to running the function.
-> 1918       return self._build_call_outputs(self._inference_function.call(
   1919           ctx, args, cancellation_manager=cancellation_manager))
   1920     forward_backward = self._select_forward_and_backward_functions(

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\function.py in call(self, ctx, args, cancellation_manager)
    553       with _InterpolateFunctionError(self):
    554         if cancellation_manager is None:
--> 555           outputs = execute.execute(
    556               str(self.signature.name),
    557               num_outputs=self._num_outputs,

~\AppData\Roaming\Python\Python38\site-packages\tensorflow\python\eager\execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
     57   try:
     58     ctx.ensure_initialized()
---> 59     tensors = pywrap_tfe.TFE_Py_Execute(ctx._handle, device_name, op_name,
     60                                         inputs, attrs, num_outputs)
     61   except core._NotOkStatusException as e:

ResourceExhaustedError:  OOM when allocating tensor with shape[1,96,128,800] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc
     [[node model_1/batch_normalization_57/FusedBatchNormV3 (defined at <ipython-input-16-d0a77a4d2cd0>:1) ]]
Hint: If you want to see a list of allocated tensors when OOM happens, add report_tensor_allocations_upon_oom to RunOptions for current allocation info.

Function call stack:

我正在使用 tensorflow 2.4.1 版。

编辑:我忘了提一下,在训练和预测之前,我还尝试过用以下代码启用 TensorFlow 的动态显存分配(memory growth),但仍然出现同样的错误。

# dynamic memory allocation: let TensorFlow grow its GPU memory use on demand
# instead of reserving (almost) the whole GPU up front
gpus = tf.config.list_physical_devices("GPU")
if gpus:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(gpu, True)
        logical_gpus = tf.config.experimental.list_logical_devices("GPU")
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        # must be set before GPUs have been initialized
        print(e)

标签: python tensorflow machine-learning keras out-of-memory

