Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Update image training code for downstream datasets (iWildCam, transfer datasets) #2310

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions src/sparseml/pytorch/torchvision/presets.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ def __init__(
self,
*,
crop_size,
resize_size=None,
mean=(0.485, 0.456, 0.406),
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
Expand All @@ -20,7 +21,14 @@ def __init__(
augmix_severity=3,
random_erase_prob=0.0,
):
trans = [transforms.RandomResizedCrop(crop_size, interpolation=interpolation)]
if resize_size is not None:
trans = [
transforms.Resize(resize_size, interpolation=interpolation),
]
else:
trans = [
transforms.RandomResizedCrop(crop_size, interpolation=interpolation)
]
if hflip_prob > 0:
trans.append(transforms.RandomHorizontalFlip(hflip_prob))
if auto_augment_policy is not None:
Expand Down Expand Up @@ -73,7 +81,6 @@ def __init__(
std=(0.229, 0.224, 0.225),
interpolation=InterpolationMode.BILINEAR,
):

self.transforms = transforms.Compose(
[
transforms.Resize(resize_size, interpolation=interpolation),
Expand Down
108 changes: 83 additions & 25 deletions src/sparseml/pytorch/torchvision/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
from packaging import version
from torch import nn
from torch.utils.data.dataloader import DataLoader, default_collate
from torchvision.datasets import DTD, FGVCAircraft, Flowers102 # noqa: F401
from torchvision.transforms.functional import InterpolationMode

import click
Expand Down Expand Up @@ -278,11 +279,21 @@ def _get_cache_path(filepath):


def load_data(traindir, valdir, args):
if args.transfer_dataset is not None:
if args.transfer_dataset not in ("FGVCAircraft", "DTD", "Flowers102"):
raise ValueError(
"FGVCAircraft, DTD, and Flowers102 are allowed as transfer_datasets."
)
# Data loading code
_LOGGER.info("Loading data")
val_resize_size, val_crop_size, train_crop_size = (
if args.train_resize_size is not None:
args.train_resize_size = [args.train_resize_size, args.train_resize_size]
if args.val_resize_size is not None:
args.val_resize_size = [args.val_resize_size, args.val_resize_size]
val_resize_size, val_crop_size, train_resize_size, train_crop_size = (
args.val_resize_size,
args.val_crop_size,
args.train_resize_size,
args.train_crop_size,
)
interpolation = InterpolationMode(args.interpolation)
Expand All @@ -299,19 +310,26 @@ def load_data(traindir, valdir, args):
random_erase_prob = getattr(args, "random_erase", 0.0)
ra_magnitude = args.ra_magnitude
augmix_severity = args.augmix_severity
dataset = torchvision.datasets.ImageFolder(
traindir,
presets.ClassificationPresetTrain(
crop_size=train_crop_size,
mean=args.rgb_mean,
std=args.rgb_std,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
),
train_transforms = presets.ClassificationPresetTrain(
crop_size=train_crop_size,
resize_size=train_resize_size,
mean=args.rgb_mean,
std=args.rgb_std,
interpolation=interpolation,
auto_augment_policy=auto_augment_policy,
random_erase_prob=random_erase_prob,
ra_magnitude=ra_magnitude,
augmix_severity=augmix_severity,
)
if args.transfer_dataset is None:
dataset = torchvision.datasets.ImageFolder(traindir, train_transforms)
else:
dataset = eval(args.transfer_dataset)(
root=f"/tmp/{args.transfer_dataset}",
split=args.transfer_dataset_train_split,
transform=train_transforms,
download=True,
)
if args.cache_dataset:
_LOGGER.info(f"Saving dataset_train to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
Expand All @@ -333,10 +351,18 @@ def load_data(traindir, valdir, args):
interpolation=interpolation,
)

dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
if args.transfer_dataset is None:
dataset_test = torchvision.datasets.ImageFolder(
valdir,
preprocessing,
)
else:
dataset_test = eval(args.transfer_dataset)(
root=f"/tmp/{args.transfer_dataset}",
split=args.transfer_dataset_test_split,
transform=preprocessing,
download=True,
)
if args.cache_dataset:
_LOGGER.info(f"Saving dataset_test to {cache_path}")
utils.mkdir(os.path.dirname(cache_path))
Expand Down Expand Up @@ -389,9 +415,15 @@ def main(args):
dataset, dataset_test, train_sampler, test_sampler = load_data(
train_dir, val_dir, args
)

collate_fn = None
num_classes = len(dataset.classes)
try:
num_classes = len(dataset.classes)
except: # noqa: E722
# For some reason, the classes method is not implemented for Flowers102.
if args.transfer_dataset == "Flowers102":
num_classes = 102
else:
raise ValueError(f"unknown number of classes for {args.transfer_dataset}")
mixup_transforms = []
if args.mixup_alpha > 0.0:
mixup_transforms.append(
Expand All @@ -404,7 +436,7 @@ def main(args):
if mixup_transforms:
mixupcutmix = torchvision.transforms.RandomChoice(mixup_transforms)

def collate_fn(batch):
def collate_fn(batch): # noqa: F811
return mixupcutmix(*default_collate(batch))

data_loader = torch.utils.data.DataLoader(
Expand Down Expand Up @@ -475,9 +507,9 @@ def collate_fn(batch):
model,
args.weight_decay,
norm_weight_decay=args.norm_weight_decay,
custom_keys_weight_decay=custom_keys_weight_decay
if len(custom_keys_weight_decay) > 0
else None,
custom_keys_weight_decay=(
custom_keys_weight_decay if len(custom_keys_weight_decay) > 0 else None
),
)

opt_name = args.opt.lower()
Expand Down Expand Up @@ -986,6 +1018,26 @@ def new_func(*args, **kwargs):
help="json parsable dict of recipe variable names to values to overwrite with",
)
@click.option("--dataset-path", required=True, type=str, help="dataset path")
@click.option(
"--transfer-dataset",
required=False,
type=str,
help="Dataset to be loaded using torchvision class.",
)
@click.option(
"--transfer-dataset-train-split",
required=False,
type=str,
default="train",
help="Train split name for transfer dataset",
)
@click.option(
"--transfer-dataset-test-split",
required=False,
type=str,
default="test",
help="Test split name for transfer dataset",
)
@click.option(
"--arch-key",
default=None,
Expand Down Expand Up @@ -1205,19 +1257,25 @@ def new_func(*args, **kwargs):
"--val-resize-size",
default=256,
type=int,
help="the resize size used for validation",
help="the resize size used for validation (always square)",
)
@click.option(
"--val-crop-size",
default=224,
type=int,
help="the central crop size used for validation",
)
@click.option(
"--train-resize-size",
default=None,
type=int,
help="If set, the resize size used for training (always square)",
)
@click.option(
"--train-crop-size",
default=224,
type=int,
help="the random crop size used for training",
help="the random crop size used for training (always square)",
)
@click.option(
"--clip-grad-norm",
Expand Down
Loading