python - 首次启动时 Catalyst SupervisedRunner 断言错误
问题描述
我试图用 `catalyst` 来训练我创建的自定义 `pytorch` 神经网络。但是,当我第一次在 Jupyter 中运行代码时,它总是抛出一个没有任何说明信息的 `AssertionError`。当我第二次运行该单元格时,它似乎工作正常。为什么会发生这种情况,该如何解决?
from catalyst import dl

# Supervised runner configured with the batch keys that the callbacks read.
runner = dl.SupervisedRunner(
    input_key="features",
    output_key="logits",
    target_key="targets",
    loss_key="loss",
)

# Data loaders, keyed by the names the runner iterates over each epoch.
# NOTE(review): the traceback quoted after this snippet fails in
# `on_loader_start`, which asserts that the loader key is a train, valid
# (`startswith("valid")`) or infer (`startswith("infer")`) loader. The key
# "val" matches neither of those two prefixes -- likely the trigger of the
# AssertionError; TODO confirm against the Catalyst version in use.
train_loaders = {"training": train, "val": val}

# Metrics computed from the model logits against the targets.
metric_callbacks = [
    dl.AccuracyCallback(
        input_key="logits",
        target_key="targets",
        topk_args=(1, 3, 5),
    ),
    dl.ConfusionMatrixCallback(
        input_key="logits",
        target_key="targets",
        num_classes=6,
    ),
]

# Train for 5 epochs, selecting on the "loss" metric of the "val" loader
# (lower is better) and reloading the best checkpoint when training ends.
runner.train(
    model=net,
    criterion=criterion,
    optimizer=optimizer,
    loaders=train_loaders,
    num_epochs=5,
    callbacks=metric_callbacks,
    logdir="./logs",
    valid_loader="val",
    valid_metric="loss",
    minimize_valid_metric=True,
    verbose=True,
    load_best_on_end=True,
    seed=42,
)
错误的完整追溯:
AssertionError Traceback (most recent call last)
<ipython-input-15-4eae1ab61f41> in <module>
2
3 runner = dl.SupervisedRunner(input_key="features", output_key="logits", target_key="targets", loss_key="loss")
----> 4 runner.train(
5 model= net,
6 criterion= criterion,
~/venv/lib/python3.9/site-packages/catalyst/runners/runner.py in train(self, loaders, model, engine, trial, criterion, optimizer, scheduler, callbacks, loggers, seed, hparams, num_epochs, logdir, valid_loader, valid_metric, minimize_valid_metric, verbose, timeit, check, overfit, load_best_on_end, fp16, amp, apex, ddp)
513 self._load_best_on_end = load_best_on_end
514 # run
--> 515 self.run()
516
517 @torch.no_grad()
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
852 self.exception = ex
853 self._run_event("on_experiment_end")
--> 854 self._run_event("on_exception")
855 return self
856
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
786 getattr(callback, event)(self)
787 if _has_str_intersections(event, ("_end", "_exception")):
--> 788 getattr(self, event)(self)
789
790 @abstractmethod
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_exception(self, runner)
778 def on_exception(self, runner: "IRunner"):
779 """Event handler."""
--> 780 raise self.exception
781
782 def _run_event(self, event: str) -> None:
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in run(self)
848 """
849 try:
--> 850 self._run_experiment()
851 except (Exception, KeyboardInterrupt) as ex:
852 self.exception = ex
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_experiment(self)
838 self._run_event("on_experiment_start")
839 for self.stage_key in self.stages:
--> 840 self.engine.spawn(self._run_stage)
841 self._run_event("on_experiment_end")
842
~/venv/lib/python3.9/site-packages/catalyst/core/engine.py in spawn(self, fn, *args, **kwargs)
136 wrapped function (if needed).
137 """
--> 138 return fn(*args, **kwargs)
139
140 def setup_process(self, rank: int = -1, world_size: int = 1):
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_stage(self, rank, world_size)
829 self._run_event("on_stage_start")
830 while self.stage_epoch_step < self.stage_epoch_len:
--> 831 self._run_epoch()
832 if self.need_early_stop:
833 self.need_early_stop = False
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_epoch(self)
822 self._run_event("on_epoch_start")
823 for self.loader_key, self.loader in self.loaders.items():
--> 824 self._run_loader()
825 self._run_event("on_epoch_end")
826
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_loader(self)
809 # as it was noted in docs:
810 # https://pytorch.org/docs/stable/notes/amp_examples.html#typical-mixed-precision-training
--> 811 self._run_event("on_loader_start")
812 with torch.set_grad_enabled(self.is_train_loader):
813 for self.loader_batch_step, self.batch in enumerate(self.loader):
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in _run_event(self, event)
782 def _run_event(self, event: str) -> None:
783 if _has_str_intersections(event, ("_start",)):
--> 784 getattr(self, event)(self)
785 for callback in self.callbacks.values():
786 getattr(callback, event)(self)
~/venv/lib/python3.9/site-packages/catalyst/core/runner.py in on_loader_start(self, runner)
707 self.is_valid_loader: bool = self.loader_key.startswith("valid")
708 self.is_infer_loader: bool = self.loader_key.startswith("infer")
--> 709 assert self.is_train_loader or self.is_valid_loader or self.is_infer_loader
710 self.loader_batch_size: int = _get_batch_size(self.loader)
711 self.loader_batch_len: int = len(self.loader)
AssertionError:
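解决方案
从上面的回溯可以看出,断言发生在 `on_loader_start` 中:Catalyst 要求每个 loader 的键名必须被识别为训练、验证或推理 loader 之一,其中验证 loader 通过 `loader_key.startswith("valid")` 判断,推理 loader 通过 `loader_key.startswith("infer")` 判断。代码中的键 `"val"` 不满足这两个前缀,因此很可能是它触发了 `AssertionError`;可以尝试把该键改名为 `"valid"`(同时把 `valid_loader="val"` 改为 `valid_loader="valid"`)。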
推荐阅读
- html - 将表单添加到按钮
- vue.js - Vue - 带有模板 vslots 的错误或设计
- java - 我在 JPA 查询中添加的每个谓词都嵌套在括号“()”中
- python - 在创建对象类 Python 时,如何确保用户输入正确的类型(在我的例子中是列表)
- pytorch - 在 ImageNet 中使用类的子集
- python - AttributeError:“成员”对象没有属性“客户端”
- javascript - Javascript - 如何将base64 pdf数据转换和下载为png图像?
- c# - 如何使用 IServiceCollection 和 NLog 将命名记录器通用 ILogger<T> 作为 ILogger 注入构造函数
- sql - Spring Data:JPA 存储库返回 Map 而不是 List?
- swift - 无论接受或拒绝媒体库权限,MPMusicController 播放/暂停 API 功能