首页 > 解决方案 > 首次启动时 Catalyst SupervisedRunner 断言错误

问题描述

我试图用 catalyst 来训练我用 pytorch 创建的自定义神经网络。但是，当我第一次在 Jupyter 中运行代码时，它总是抛出一个没有明确解释的 AssertionError；当我第二次运行该单元时，它似乎又能正常工作。为什么会发生这种情况，以及如何解决？

from catalyst import dl

runner = dl.SupervisedRunner(
    input_key="features", output_key="logits", target_key="targets", loss_key="loss"
)

# NOTE: Catalyst classifies each loader by its key prefix in on_loader_start
# (see traceback lines 707-709: startswith("train"/"valid"/"infer") and the
# assertion right after). The original code used {"training": ..., "val": ...};
# "val" does not start with "valid", so the assertion failed. The loader keys
# below are renamed to "train"/"valid", and valid_loader is updated to match.
runner.train(
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    loaders={"train": train, "valid": val},  # keys must start with train/valid/infer
    num_epochs=5,
    callbacks=[
        # top-k accuracy on the logits vs. targets
        dl.AccuracyCallback(
            input_key="logits", target_key="targets", topk_args=(1, 3, 5)
        ),
        # confusion matrix for the 6-class problem
        dl.ConfusionMatrixCallback(
            input_key="logits", target_key="targets", num_classes=6
        ),
    ],
    logdir="./logs",
    valid_loader="valid",  # must match the (renamed) validation loader key
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
    seed=42,
)

错误的完整追溯:

AssertionError                            Traceback (most recent call last)
<ipython-input-15-4eae1ab61f41> in <module>
      2 
      3 runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss")
----> 4 runner.train(
      5      model= net,
      6      criterion= criterion,

~/venv/lib/python3.9/site-packages/catalyst/runners/runner.py in train(self, loaders, model, engine, trial, criterion, optimizer, scheduler, callbacks, loggers, seed, hparams, num_epochs, logdir, valid_loader, valid_metric, minimize_valid_metric, verbose, timeit, check, overfit, load_best_on_end, fp16, amp, apex, ddp)
    513         self._load_best_on_end = load_best_on_end
    514         # run
--> 515         self.run()
    516 
    517     @torch.no_grad()

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
    852             self.exception = ex
    853             self._run_event("on_experiment_end")
--> 854             self._run_event("on_exception")
    855         return self
    856 

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
    786             getattr(callback, event)(self)
    787         if _has_str_intersections(event, ("_end", "_exception")):
--> 788             getattr(self, event)(self)
    789 
    790     @abstractmethod

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_exception(self, runner)
    778     def on_exception(self, runner: "IRunner"):
    779         """Event handler."""
--> 780         raise self.exception
    781 
    782     def _run_event(self, event: str) -> None:

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
    848         """
    849         try:
--> 850             self._run_experiment()
    851         except (Exception, KeyboardInterrupt) as ex:
    852             self.exception = ex

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_experiment(self)
    838         self._run_event("on_experiment_start")
    839         for self.stage_key in self.stages:
--> 840             self.engine.spawn(self._run_stage)
    841         self._run_event("on_experiment_end")
    842 

~/venv/lib/python3.9/site-packages/catalyst/core/engine.py in spawn(self, fn, *args, **kwargs)
    136             wrapped function (if needed).
    137         """
--> 138         return fn(*args, **kwargs)
    139 
    140     def setup_process(self, rank: int = -1, world_size: int = 1):

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_stage(self, rank, world_size)
    829         self._run_event("on_stage_start")
    830         while self.stage_epoch_step < self.stage_epoch_len:
--> 831             self._run_epoch()
    832             if self.need_early_stop:
    833                 self.need_early_stop = False

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_epoch(self)
    822         self._run_event("on_epoch_start")
    823         for self.loader_key, self.loader in self.loaders.items():
--> 824             self._run_loader()
    825         self._run_event("on_epoch_end")
    826 

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_loader(self)
    809         # as it was noted in docs:
    810         # https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training
--> 811         self._run_event("on_loader_start")
    812         with torch.set_grad_enabled(self.is_train_loader):
    813             for self.loader_batch_step, self.batch in enumerate(self.loader):

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
    782     def _run_event(self, event: str) -> None:
    783         if _has_str_intersections(event, ("_start",)):
--> 784             getattr(self, event)(self)
    785         for callback in self.callbacks.values():
    786             getattr(callback, event)(self)

~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_loader_start(self, runner)
    707         self.is_valid_loader: bool = self.loader_key.startswith("valid")
    708         self.is_infer_loader: bool = self.loader_key.startswith("infer")
--> 709         assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
    710         self.loader_batch_size: int = _get_batch_size(self.loader)
    711         self.loader_batch_len: int = len(self.loader)

AssertionError: 

标签: python, python-3.x, machine-learning, pytorch, pytorch-lightning

解决方案

从追溯的最后几行（runner.py 第 707–709 行）可以看出，Catalyst 在 on_loader_start 中通过键名前缀来判断每个 loader 的类型：键必须以 "train"、"valid" 或 "infer" 开头，否则断言失败。示例代码中 loaders={"training": train, "val": val} 里的 "val" 不满足 startswith("valid")，因此抛出没有附带消息的 AssertionError。将该键改为 "valid"（并把 valid_loader 参数同步改为 "valid"）即可解决此问题。

推荐阅读