From aeff47e699be436715894fccc0f54444134607a0 Mon Sep 17 00:00:00 2001
From: Yufan He <59374597+heyufan1995@users.noreply.github.com>
Date: Wed, 11 Sep 2024 13:46:07 -0500
Subject: [PATCH] Fix model weight load bug with multigpu (#40)

Fixes # .

### Description
Fix the bug where model weights are not loaded when multigpu is used.

### Types of changes
- [x] Non-breaking change (fix or new feature that would not break existing functionality).
- [ ] Breaking change (fix or new feature that would cause existing functionality to change).
- [ ] New tests added to cover the changes.
- [ ] In-line docstrings updated.

---------

Signed-off-by: heyufan1995
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
---
 vista3d/scripts/train.py          | 14 +++++---------
 vista3d/scripts/train_finetune.py | 13 +++++--------
 2 files changed, 10 insertions(+), 17 deletions(-)

diff --git a/vista3d/scripts/train.py b/vista3d/scripts/train.py
index e9beb8e..0c89d16 100644
--- a/vista3d/scripts/train.py
+++ b/vista3d/scripts/train.py
@@ -40,7 +40,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler, DistributedWeightedRandomSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import optional_import, set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
@@ -216,10 +215,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -229,13 +224,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
-
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")

diff --git a/vista3d/scripts/train_finetune.py b/vista3d/scripts/train_finetune.py
index bb945ee..02a5c15 100644
--- a/vista3d/scripts/train_finetune.py
+++ b/vista3d/scripts/train_finetune.py
@@ -35,7 +35,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -149,10 +148,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -162,13 +157,15 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )

     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
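
Note (illustrative, not part of the patch): the reordering matters because `DistributedDataParallel` wraps the model and prefixes every `state_dict` key with `module.`, so a plain checkpoint only matches the bare model. A minimal sketch of the load-then-wrap pattern, with `build_and_load` and `ckpt_path` as hypothetical placeholder names:

```python
# Sketch only: placeholder function illustrating why weights are loaded
# before DDP wrapping; build_and_load and ckpt_path are not from the patch.
import torch
from torch.nn.parallel import DistributedDataParallel


def build_and_load(model, ckpt_path, device, world_size):
    """Load pretrained weights into the bare model, then wrap it for DDP."""
    if ckpt_path is not None:
        # Checkpoint keys match the bare model; after DDP wrapping they
        # would need a "module." prefix and would no longer line up.
        pretrained_ckpt = torch.load(ckpt_path, map_location=device)
        model.load_state_dict(pretrained_ckpt)
        del pretrained_ckpt
    if world_size > 1:
        # Wrap only after the weights are in place.
        model = DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )
    return model
```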