diff --git a/vista3d/scripts/train.py b/vista3d/scripts/train.py
index e9beb8e..0c89d16 100644
--- a/vista3d/scripts/train.py
+++ b/vista3d/scripts/train.py
@@ -40,7 +40,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler, DistributedWeightedRandomSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import optional_import, set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.data.sampler import RandomSampler, WeightedRandomSampler
@@ -216,10 +215,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -229,13 +224,14 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
-
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
diff --git a/vista3d/scripts/train_finetune.py b/vista3d/scripts/train_finetune.py
index bb945ee..02a5c15 100644
--- a/vista3d/scripts/train_finetune.py
+++ b/vista3d/scripts/train_finetune.py
@@ -35,7 +35,6 @@
 from monai.bundle.scripts import _pop_args, _update_args
 from monai.data import DataLoader, DistributedSampler
 from monai.metrics import compute_dice
-from monai.networks.utils import copy_model_state
 from monai.utils import set_determinism
 from torch.nn.parallel import DistributedDataParallel
 from torch.utils.tensorboard import SummaryWriter
@@ -149,10 +148,6 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
     optimizer = optimizer_part.instantiate(params=model.parameters())
     lr_scheduler_part = parser.get_parsed_content("lr_scheduler", instantiate=False)
     lr_scheduler = lr_scheduler_part.instantiate(optimizer=optimizer)
-    if world_size > 1:
-        model = DistributedDataParallel(
-            model, device_ids=[device], find_unused_parameters=True
-        )
     if finetune["activate"] and os.path.isfile(finetune["pretrained_ckpt_name"]):
         logger.debug(
             "Fine-tuning pre-trained checkpoint {:s}".format(
@@ -162,13 +157,15 @@ def run(config_file: Optional[Union[str, Sequence[str]]] = None, **override):
         pretrained_ckpt = torch.load(
             finetune["pretrained_ckpt_name"], map_location=device
         )
-        copy_model_state(
-            model, pretrained_ckpt, exclude_vars=finetune.get("exclude_vars")
-        )
+        model.load_state_dict(pretrained_ckpt)
         del pretrained_ckpt
     else:
         logger.debug("Training from scratch")
 
+    if world_size > 1:
+        model = DistributedDataParallel(
+            model, device_ids=[device], find_unused_parameters=True
+        )
     # training hyperparameters - sample
     num_images_per_batch = parser.get_parsed_content("num_images_per_batch")
     num_patches_per_iter = parser.get_parsed_content("num_patches_per_iter")
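
Note on the reordering: both hunks move the `DistributedDataParallel` wrapping to *after* the checkpoint is loaded. A checkpoint saved from an unwrapped model has keys like `encoder.weight`, while a DDP-wrapped model expects `module.encoder.weight`, so calling `load_state_dict` after wrapping would fail strict key matching. The sketch below illustrates the load-then-wrap ordering under that assumption; `build_model` and `ckpt_path` are placeholders, not names from this repository.

```python
import torch
from torch.nn.parallel import DistributedDataParallel


def build_model(device: torch.device) -> torch.nn.Module:
    # Stand-in for the VISTA3D model construction done via the bundle parser.
    return torch.nn.Linear(4, 2).to(device)


def prepare_model(device: torch.device, world_size: int, ckpt_path: str) -> torch.nn.Module:
    model = build_model(device)

    # Load weights into the *unwrapped* module first, mirroring the reordered code.
    state_dict = torch.load(ckpt_path, map_location=device)
    model.load_state_dict(state_dict)  # strict=True by default

    # Only now wrap for multi-GPU training.
    if world_size > 1:
        model = DistributedDataParallel(
            model, device_ids=[device], find_unused_parameters=True
        )
    return model
```

Also note the behavioral change: `copy_model_state` performed a filtered, non-strict copy that honored `finetune["exclude_vars"]`, whereas `load_state_dict` is strict by default and raises on any missing or unexpected key, so the pretrained checkpoint must now match the model's state dict exactly (or `strict=False` would have to be passed).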