From aee808c2dda0d5dec75ddb47718c52e14c21f42b Mon Sep 17 00:00:00 2001 From: Eugenia Iofinova Date: Tue, 28 May 2024 08:42:09 +0200 Subject: [PATCH 1/4] revise dataloaders to accommodate different datasets and transforms --- src/sparseml/pytorch/torchvision/presets.py | 11 +- src/sparseml/pytorch/torchvision/train.py | 122 ++++++++++++++++---- 2 files changed, 106 insertions(+), 27 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/presets.py b/src/sparseml/pytorch/torchvision/presets.py index e5001679f01..decd259fc65 100644 --- a/src/sparseml/pytorch/torchvision/presets.py +++ b/src/sparseml/pytorch/torchvision/presets.py @@ -11,6 +11,7 @@ def __init__( self, *, crop_size, + resize_size=None, mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225), interpolation=InterpolationMode.BILINEAR, @@ -20,7 +21,14 @@ def __init__( augmix_severity=3, random_erase_prob=0.0, ): - trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)] + if resize_size is not None: + trans = [ + transforms.Resize(resize_size, interpolation=interpolation), + ] + else: + trans = [ + transforms.RandomResizedCrop(crop_size, interpolation=interpolation) + ] if hflip_prob > 0: trans.append(transforms.RandomHorizontalFlip(hflip_prob)) if auto_augment_policy is not None: @@ -73,7 +81,6 @@ def __init__( std=(0.229, 0.224, 0.225), interpolation=InterpolationMode.BILINEAR, ): - self.transforms = transforms.Compose( [ transforms.Resize(resize_size, interpolation=interpolation), diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 70ac270299e..12efe725a0e 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -42,6 +42,7 @@ from packaging import version from torch import nn from torch.utils.data.dataloader import DataLoader, default_collate +from torchvision.datasets import DTD, FGVCAircraft, Flowers102 # noqa: F401 from torchvision.transforms.functional import 
InterpolationMode import click @@ -278,11 +279,25 @@ def _get_cache_path(filepath): def load_data(traindir, valdir, args): + if args.transfer_dataset is not None: + if args.transfer_dataset not in ("FGVCAircraft", "DTD", "Flowers102"): + raise ValueError( + "FGVCAircraft, DTD, and Flowers102 are allowed as transfer_datasets." + ) # Data loading code _LOGGER.info("Loading data") - val_resize_size, val_crop_size, train_crop_size = ( + if len(args.train_resize_size) > 2: + raise ValueError("--train-resize-size must be of length 1 or 2") + if args.train_resize_size[0] is None: + args.train_resize_size = None + if len(args.val_resize_size) > 2: + raise ValueError("--val-resize-size must be of length 1 or 2") + if args.val_resize_size[0] is None: + args.val_resize_size = None + val_resize_size, val_crop_size, train_resize_size, train_crop_size = ( args.val_resize_size, args.val_crop_size, + args.train_resize_size, args.train_crop_size, ) interpolation = InterpolationMode(args.interpolation) @@ -299,19 +314,26 @@ def load_data(traindir, valdir, args): random_erase_prob = getattr(args, "random_erase", 0.0) ra_magnitude = args.ra_magnitude augmix_severity = args.augmix_severity - dataset = torchvision.datasets.ImageFolder( - traindir, - presets.ClassificationPresetTrain( - crop_size=train_crop_size, - mean=args.rgb_mean, - std=args.rgb_std, - interpolation=interpolation, - auto_augment_policy=auto_augment_policy, - random_erase_prob=random_erase_prob, - ra_magnitude=ra_magnitude, - augmix_severity=augmix_severity, - ), + train_transforms = presets.ClassificationPresetTrain( + crop_size=train_crop_size, + resize_size=train_resize_size, + mean=args.rgb_mean, + std=args.rgb_std, + interpolation=interpolation, + auto_augment_policy=auto_augment_policy, + random_erase_prob=random_erase_prob, + ra_magnitude=ra_magnitude, + augmix_severity=augmix_severity, ) + if args.transfer_dataset is None: + dataset = torchvision.datasets.ImageFolder(traindir, train_transforms) + else: + 
dataset = eval(args.transfer_dataset)( + root=f"/tmp/{args.transfer_dataset}", + split=args.transfer_dataset_train_split, + transform=train_transforms, + download=True, + ) if args.cache_dataset: _LOGGER.info(f"Saving dataset_train to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) @@ -333,10 +355,18 @@ def load_data(traindir, valdir, args): interpolation=interpolation, ) - dataset_test = torchvision.datasets.ImageFolder( - valdir, - preprocessing, - ) + if args.transfer_dataset is None: + dataset_test = torchvision.datasets.ImageFolder( + valdir, + preprocessing, + ) + else: + dataset_test = eval(args.transfer_dataset)( + root=f"/tmp/{args.transfer_dataset}", + split=args.transfer_dataset_test_split, + transform=preprocessing, + download=True, + ) if args.cache_dataset: _LOGGER.info(f"Saving dataset_test to {cache_path}") utils.mkdir(os.path.dirname(cache_path)) @@ -389,9 +419,15 @@ def main(args): dataset, dataset_test, train_sampler, test_sampler = load_data( train_dir, val_dir, args ) - collate_fn = None - num_classes = len(dataset.classes) + try: + num_classes = len(dataset.classes) + except: # noqa: E722 + # For some reason, the classes method is not implemented for Flowers102. 
+ if args.transfer_dataset == "Flowers102": + num_classes = 102 + else: + raise ValueError(f"unknown number of classes for {args.transfer_dataset}") mixup_transforms = [] if args.mixup_alpha > 0.0: mixup_transforms.append( @@ -404,7 +440,7 @@ def main(args): if mixup_transforms: mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms) - def collate_fn(batch): + def collate_fn(batch): # noqa: F811 return mixupcutmix(*default_collate(batch)) data_loader = torch.utils.data.DataLoader( @@ -475,9 +511,9 @@ def collate_fn(batch): model, args.weight_decay, norm_weight_decay=args.norm_weight_decay, - custom_keys_weight_decay=custom_keys_weight_decay - if len(custom_keys_weight_decay) > 0 - else None, + custom_keys_weight_decay=( + custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None + ), ) opt_name = args.opt.lower() @@ -743,7 +779,9 @@ def data_loader_builder(**kwargs): lr_scheduler.step() eval_metrics = evaluate(model, criterion, data_loader_test, device) - log_metrics("Test", eval_metrics, epoch, steps_per_epoch) + log_metrics( + args.transfer_dataset_test_split, eval_metrics, epoch, steps_per_epoch + ) top1_acc = eval_metrics.acc1.global_avg if model_ema: @@ -986,6 +1024,26 @@ def new_func(*args, **kwargs): help="json parsable dict of recipe variable names to values to overwrite with", ) @click.option("--dataset-path", required=True, type=str, help="dataset path") +@click.option( + "--transfer-dataset", + required=False, + type=str, + help="Dataset to be loaded using torchvision class.", +) +@click.option( + "--transfer-dataset-train-split", + required=False, + type=str, + default="train", + help="Train split name for transfer dataset", +) +@click.option( + "--transfer-dataset-test-split", + required=False, + type=str, + default="test", + help="Test split name for transfer dataset", +) @click.option( "--arch-key", default=None, @@ -1203,7 +1261,8 @@ def new_func(*args, **kwargs): ) @click.option( "--val-resize-size", - default=256, + 
default=[256, 256], + nargs=2, type=int, help="the resize size used for validation", ) @@ -1213,6 +1272,19 @@ def new_func(*args, **kwargs): type=int, help="the central crop size used for validation", ) +@click.option( + "--resize-square", + is_flag=True, + default=False, + help="whether to resize images to a square", +) +@click.option( + "--train-resize-size", + default=[None, None], + nargs=2, + type=int, + help="If set, the resize size used for training", +) @click.option( "--train-crop-size", default=224, From 4a9f382ef67d5757d605e080eea1e679b570a085 Mon Sep 17 00:00:00 2001 From: Jen Iofinova Date: Tue, 28 May 2024 10:18:17 -0700 Subject: [PATCH 2/4] Update train.py fix accidental search/replace error --- src/sparseml/pytorch/torchvision/train.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 12efe725a0e..f3393d93d6b 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -779,9 +779,7 @@ def data_loader_builder(**kwargs): lr_scheduler.step() eval_metrics = evaluate(model, criterion, data_loader_test, device) - log_metrics( - args.transfer_dataset_test_split, eval_metrics, epoch, steps_per_epoch - ) + log_metrics("Test", eval_metrics, epoch, steps_per_epoch) top1_acc = eval_metrics.acc1.global_avg if model_ema: From ae0c7f76b86efdde33759f5787908f1e1024e3d3 Mon Sep 17 00:00:00 2001 From: Jen Iofinova Date: Tue, 28 May 2024 10:19:21 -0700 Subject: [PATCH 3/4] Update train.py remove unused flag. 
--- src/sparseml/pytorch/torchvision/train.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index f3393d93d6b..9d4393e7ee2 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -1270,12 +1270,6 @@ def new_func(*args, **kwargs): type=int, help="the central crop size used for validation", ) -@click.option( - "--resize-square", - is_flag=True, - default=False, - help="whether to resize images to a square", -) @click.option( "--train-resize-size", default=[None, None], From 2700feda275f92cc6d89265d44673bebf2d89953 Mon Sep 17 00:00:00 2001 From: Eugenia Iofinova Date: Wed, 29 May 2024 19:13:45 +0200 Subject: [PATCH 4/4] simplify resize arguments --- src/sparseml/pytorch/torchvision/train.py | 24 +++++++++-------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/src/sparseml/pytorch/torchvision/train.py b/src/sparseml/pytorch/torchvision/train.py index 9d4393e7ee2..8eb080a8534 100644 --- a/src/sparseml/pytorch/torchvision/train.py +++ b/src/sparseml/pytorch/torchvision/train.py @@ -286,14 +286,10 @@ def load_data(traindir, valdir, args): ) # Data loading code _LOGGER.info("Loading data") - if len(args.train_resize_size) > 2: - raise ValueError("--train-resize-size must be of length 1 or 2") - if args.train_resize_size[0] is None: - args.train_resize_size = None - if len(args.val_resize_size) > 2: - raise ValueError("--val-resize-size must be of length 1 or 2") - if args.val_resize_size[0] is None: - args.val_resize_size = None + if args.train_resize_size is not None: + args.train_resize_size = [args.train_resize_size, args.train_resize_size] + if args.val_resize_size is not None: + args.val_resize_size = [args.val_resize_size, args.val_resize_size] val_resize_size, val_crop_size, train_resize_size, train_crop_size = ( args.val_resize_size, args.val_crop_size, @@ -1259,10 +1255,9 @@ def new_func(*args, 
**kwargs): ) @click.option( "--val-resize-size", - default=[256, 256], - nargs=2, + default=256, type=int, - help="the resize size used for validation", + help="the resize size used for validation (always square)", ) @click.option( "--val-crop-size", @@ -1272,16 +1267,15 @@ def new_func(*args, **kwargs): ) @click.option( "--train-resize-size", - default=[None, None], - nargs=2, + default=None, type=int, - help="If set, the resize size used for training", + help="If set, the resize size used for training (always square)", ) @click.option( "--train-crop-size", default=224, type=int, - help="the random crop size used for training", + help="the random crop size used for training (always square)", ) @click.option( "--clip-grad-norm",