Skip to content
This repository has been archived by the owner on Sep 18, 2024. It is now read-only.

adapt pytorch lighting 2.0 AKA lightning #5606

Open
wants to merge 9 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 19 additions & 6 deletions nni/compression/pytorch/utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,16 @@
from torch.utils.hooks import RemovableHandle

try:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
import lightning as pl
from lightning.callbacks import Callback
except ImportError:
LIGHTNING_INSTALLED = False
try:
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback
except ImportError:
LIGHTNING_INSTALLED = False
else:
LIGHTNING_INSTALLED = True
else:
LIGHTNING_INSTALLED = True

Expand Down Expand Up @@ -315,10 +321,17 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
err_msg_p = (
'Only support traced {}, please use nni.trace({}) to initialize the trainer. '
'for pytorch_lightning version > 2.0, please using {}'
)
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer', 'lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
err_msg = err_msg_p.format(
'pytorch_lightning.LightningDataModule',
'pytorch_lightning.LightningDataModule',
'lightning.LightningDataModule',
)
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
Expand Down
22 changes: 17 additions & 5 deletions nni/contrib/compression/utils/evaluator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,14 @@
from torch.utils.hooks import RemovableHandle

try:
import pytorch_lightning as pl
import lightning as pl
except ImportError:
LIGHTNING_INSTALLED = False
try:
import pytorch_lightning as pl
except ImportError:
LIGHTNING_INSTALLED = False
else:
LIGHTNING_INSTALLED = True
else:
LIGHTNING_INSTALLED = True

Expand Down Expand Up @@ -370,10 +375,17 @@ class LightningEvaluator(Evaluator):
def __init__(self, trainer: pl.Trainer, data_module: pl.LightningDataModule,
dummy_input: Any | None = None):
assert LIGHTNING_INSTALLED, 'pytorch_lightning is not installed.'
err_msg_p = 'Only support traced {}, please use nni.trace({}) to initialize the trainer.'
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer')
err_msg_p = (
'Only support traced {}, please use nni.trace({}) to initialize the trainer. '
'for pytorch_lightning version > 2.0, please using {}'
)
err_msg = err_msg_p.format('pytorch_lightning.Trainer', 'pytorch_lightning.Trainer', 'lightning.Trainer')
assert isinstance(trainer, pl.Trainer) and is_traceable(trainer), err_msg
err_msg = err_msg_p.format('pytorch_lightning.LightningDataModule', 'pytorch_lightning.LightningDataModule')
err_msg = err_msg_p.format(
'pytorch_lightning.LightningDataModule',
'pytorch_lightning.LightningDataModule',
'lightning.LightningDataModule',
)
assert isinstance(data_module, pl.LightningDataModule) and is_traceable(data_module), err_msg
self.trainer = trainer
self.data_module = data_module
Expand Down