
PyTorchProfiler crashes when resuming training from checkpoint

Author: Timerunning 2022-5-12 16:14:44


Bug

PyTorchProfiler crashes when the resume epoch equals the trainer's maximum training epoch.

To Reproduce

Colab bug report with the BoringModel

Relevant code:

)">
trainer = Trainer(
  ***,
  max_epochs=1,  # !!! example: max_epochs set to 1
  profiler='pytorch'
)
# resume training from some example epoch 1 checkpoint
trainer.fit(model, datamodule, ckpt_path=<some-epoch1.ckpt>)

Error message:

     19     trainer.fit(
     20         model, train_dataloaders=train_data, val_dataloaders=val_data,
---> 21         ckpt_path='lightning_logs/version_0/checkpoints/last.ckpt')
     22     # !!! add load path

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in fit(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    767         self.strategy.model = model
    768         self._call_and_handle_interrupt(
--> 769             self._fit_impl, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path
    770         )
    771 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_and_handle_interrupt(self, trainer_fn, *args, **kwargs)
    719                 return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
    720             else:
--> 721                 return trainer_fn(*args, **kwargs)
    722         # TODO: treat KeyboardInterrupt as BaseException (delete the code below) in v1.7
    723         except KeyboardInterrupt as exception:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _fit_impl(self, model, train_dataloaders, val_dataloaders, datamodule, ckpt_path)
    807             ckpt_path, model_provided=True, model_connected=self.lightning_module is not None
    808         )
--> 809         results = self._run(model, ckpt_path=self.ckpt_path)
    810 
    811         assert self.state.stopped

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _run(self, model, ckpt_path)
   1246 
   1247         log.detail(f"{self.__class__.__name__}: calling teardown hooks")
-> 1248         self._call_teardown_hook()
   1249 
   1250         self.state.status = TrainerStatus.FINISHED

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py in _call_teardown_hook(self)
   1530 
   1531         # summarize profile results
-> 1532         self.profiler.describe()
   1533 
   1534     def call_hook(

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/profiler/profiler.py in describe(self)
    131         # manually instead of letting the `Trainer` do it through `setup` and `teardown`
    132         self._prepare_streams()
--> 133         summary = self.summary()
    134         if summary and self._write_stream is not None:
    135             self._write_stream(summary)

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/profiler/pytorch.py in summary(self)
    453             return ""
    454 
--> 455         self._delete_profilers()
    456 
    457         if not self.function_events:

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/profiler/pytorch.py in _delete_profilers(self)
    492         if self.profiler is not None:
    493             self.profiler.__exit__(None, None, None)
--> 494             self._cache_functions_events()
    495             self.profiler = None
    496 

/usr/local/lib/python3.7/dist-packages/pytorch_lightning/profiler/pytorch.py in _cache_functions_events(self)
    487         if self._emit_nvtx:
    488             return
--> 489         self.function_events = self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
    490 
    491     def _delete_profilers(self) -> None:

/usr/local/lib/python3.7/dist-packages/torch/profiler/profiler.py in events(self)
    154         to be used in the trace callback or after the profiling is finished
    155         """
--> 156         assert self.profiler
    157         return self.profiler.function_events
    158 

AssertionError: 
Expected behavior

When the resume epoch equals the max epoch, the fit stage is skipped and the loaded checkpoint is instead used for testing or prediction. The PyTorchProfiler is expected not to crash even when the fit stage is skipped.
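Until a fix lands, one possible workaround (a sketch, not an official API, assuming the `model` and `datamodule` from the snippet above) is to read the checkpoint's epoch counter first and only attach the profiler when the fit loop will actually run:

import torch
from pytorch_lightning import Trainer

# Lightning checkpoints store the epoch counter under the top-level "epoch" key.
ckpt_path = "lightning_logs/version_0/checkpoints/last.ckpt"
saved_epoch = torch.load(ckpt_path, map_location="cpu")["epoch"]

max_epochs = 1
# Only enable the profiler when training will actually run; if fit is skipped,
# the torch profiler never starts and its teardown hits the assertion above.
trainer = Trainer(
    max_epochs=max_epochs,
    profiler="pytorch" if saved_epoch < max_epochs else None,
)
trainer.fit(model, datamodule, ckpt_path=ckpt_path)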

Environment

- PyTorch Lightning Version: 1.6.0
- PyTorch Version: 1.11.0
- Python version: 3.9.0
- OS: Linux & macOS
- CUDA/cuDNN version: irrelevant
- GPU models and configuration: irrelevant
- How you installed PyTorch (conda, pip, source): pip
- If compiling from source, the output of torch.__config__.show(): None
- Any other relevant information: -

Additional context

None

cc @otjay @unavarchy @ningithekloud @rohit:

This issue comes from PyTorchLightning/pytorch-lightning/issues/13034; see that project for more issues.

Answers

Craig Scott 2022-5-12 16:47:49

Thanks for the report! Here is a smaller reproduction:

import os

import torch
from torch.utils.data import DataLoader, Dataset

from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.callbacks import ModelCheckpoint


class RandomDataset(Dataset):
    def __init__(self, size, num_samples):
        self.len = num_samples
        self.data = torch.randn(num_samples, size)

    def __getitem__(self, index):
        return self.data[index]

    def __len__(self):
        return self.len


class BoringModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def forward(self, x):
        return self.layer(x)

    def training_step(self, batch, batch_idx):
        return self(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.layer.parameters(), lr=0.1)


train_data = DataLoader(RandomDataset(32, 64), batch_size=2)
ckpt = ModelCheckpoint(save_last=True)
model = BoringModel()

trainer = Trainer(default_root_dir=os.getcwd(), limit_train_batches=1, max_epochs=1, callbacks=ckpt)
trainer.fit(model, train_dataloaders=train_data)

trainer = Trainer(
    default_root_dir=os.getcwd(),
    limit_train_batches=1,
    max_epochs=1,  # !!! max epochs equal to loaded epochs
    profiler="pytorch",  # !!! add pytorch profiler
)
trainer.fit(model, train_dataloaders=train_data, ckpt_path=ckpt.last_model_path)

Note that the second fit call is essentially a no-op: the first call has already trained for max_epochs=1, so the second one never trains. This uncovers the bug in our logic that makes that assertion fail.
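For reference, a defensive guard along these lines would avoid the assertion (a hypothetical sketch of PyTorchProfiler._cache_functions_events in pytorch_lightning/profiler/pytorch.py, not the merged fix):

def _cache_functions_events(self) -> None:
    if self._emit_nvtx:
        return
    # When the fit loop is skipped entirely, the torch.profiler.profile
    # context is never entered, so its inner `.profiler` stays None and
    # calling `.events()` trips the `assert self.profiler` seen above.
    if self.profiler is None or self.profiler.profiler is None:
        self.function_events = None  # nothing was recorded; nothing to cache
        return
    self.function_events = (
        self.profiler.events() if _KINETO_AVAILABLE else self.profiler.function_events
    )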

sitifensys 2022-5-12 17:14:32

@guotuofeng I noticed that your open PR will resolve this issue. Could you consider splitting out the fix so that we can include it in a bug-fix release?
