diff --git a/README.md b/README.md index 82bce02..26eb5e8 100644 --- a/README.md +++ b/README.md @@ -105,21 +105,41 @@ lr_finder.reset() ### Mixed precision training -Currently, we use [`apex`](https://github.com/NVIDIA/apex) as the dependency for mixed precision training. -To enable mixed precision training, you just need to call `amp.initialize()` before running `LRFinder`. e.g. - -```python -from torch_lr_finder import LRFinder -from apex import amp - -# Add this line before running `LRFinder` -model, optimizer = amp.initialize(model, optimizer, opt_level='O1') - -lr_finder = LRFinder(model, optimizer, criterion, device='cuda') -lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') -lr_finder.plot() -lr_finder.reset() -``` +Both `apex.amp` and `torch.amp` are now supported. Here are examples of each: + +- Using [`apex.amp`](https://github.com/NVIDIA/apex): + ```python + from torch_lr_finder import LRFinder + from apex import amp + + # Add this line before running `LRFinder` + model, optimizer = amp.initialize(model, optimizer, opt_level='O1') + + lr_finder = LRFinder(model, optimizer, criterion, device='cuda', amp_backend='apex') + lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') + lr_finder.plot() + lr_finder.reset() + ``` + +- Using [`torch.amp`](https://pytorch.org/docs/stable/notes/amp_examples.html): + ```python + import torch + from torch_lr_finder import LRFinder + + amp_config = { + 'device_type': 'cuda', + 'dtype': torch.float16, + } + grad_scaler = torch.cuda.amp.GradScaler() + + lr_finder = LRFinder( + model, optimizer, criterion, device='cuda', + amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler + ) + lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') + lr_finder.plot() + lr_finder.reset() + ``` Note that the benefit of mixed precision training requires a nvidia GPU with tensor cores (see also: [NVIDIA/apex #297](https://github.com/NVIDIA/apex/issues/297)) diff --git 
a/examples/mnist_with_amp.py b/examples/mnist_with_amp.py new file mode 100644 index 0000000..0b43ca6 --- /dev/null +++ b/examples/mnist_with_amp.py @@ -0,0 +1,219 @@ +""" +Train a simple neural net for MNIST dataset with mixed precision training. + +Examples +-------- +- Run with `torch.amp`: + ```bash + $ python mnist_with_amp.py --batch_size=32 --seed=42 --tqdm --amp_backend=torch + ``` +- Run without mixed precision training: + ```bash + $ python mnist_with_amp.py --batch_size=32 --seed=42 --tqdm --amp_backend="" + ``` +""" + +from argparse import ArgumentParser +import random +import sys +import os +import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.optim as optim +from torch.utils.data import Subset, DataLoader +from torchvision import datasets, transforms + +from torch_lr_finder import LRFinder +from apex import amp + + +SEED = 0 + +def reset_seed(seed): + """ + ref: https://forums.fast.ai/t/accumulating-gradients/33219/28 + """ + random.seed(seed) + os.environ['PYTHONHASHSEED'] = str(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.backends.cudnn.deterministic = True + + +def simple_timer(func): + def wrapper(*args, **kwargs): + st = time.time() + func(*args, **kwargs) + print('--- Time taken from {}: {} seconds'.format( + func.__qualname__, time.time() - st + )) + return wrapper + + +# redirect output from tqdm +def conceal_stdout(enabled): + if enabled: + f = open(os.devnull, 'w') + sys.stdout = f + sys.stderr = f + else: + sys.stdout = sys.__stdout__ + sys.stderr = sys.__stderr__ + + +class ConvNet(nn.Module): + def __init__(self): + super(ConvNet, self).__init__() + self.conv1 = nn.Conv2d(1, 16, kernel_size=5, stride=1) + self.conv2 = nn.Conv2d(16, 32, kernel_size=3, stride=1) + self.conv2_drop = nn.Dropout2d() + self.net = nn.Sequential( + self.conv1, # (24, 24, 16) + nn.MaxPool2d(2), # (12, 12, 16) + nn.ReLU(True), + self.conv2, # (10, 10, 
32) + self.conv2_drop, + nn.MaxPool2d(2), # (5, 5, 32) + nn.ReLU(True), + ) + self.fc1 = nn.Linear(5*5*32, 64) + self.fc2 = nn.Linear(64, 16) + + def forward(self, x): + x = self.net(x) + x = x.view(-1, 5*5*32) + x = F.relu(self.fc1(x)) + x = F.dropout(x, training=self.training) + x = self.fc2(x) + return F.log_softmax(x, dim=1) + + +@simple_timer +def warm_up(trainset): + trainloader = DataLoader(trainset, batch_size=256, shuffle=True) + + device = torch.device('cuda') + model = ConvNet() + model = model.to(device) + optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5) + criterion = nn.NLLLoss() + + conceal_stdout(True) + lr_finder = LRFinder(model, optimizer, criterion, device='cuda') + lr_finder.range_test(trainloader, end_lr=10, num_iter=10, step_mode='exp') + conceal_stdout(False) + + +@simple_timer +def run_normal(trainset, batch_size, no_tqdm=True): + trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + + device = torch.device('cuda') + model = ConvNet() + model = model.to(device) + optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5) + criterion = nn.NLLLoss() + + conceal_stdout(no_tqdm) + lr_finder = LRFinder(model, optimizer, criterion, device='cuda') + lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') + lr_finder.plot() + conceal_stdout(no_tqdm and False) + + +@simple_timer +def run_amp_apex(trainset, batch_size, no_tqdm=True, opt_level='O1'): + trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + + device = torch.device('cuda') + model = ConvNet() + model = model.to(device) + optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5) + criterion = nn.NLLLoss() + + model, optimizer = amp.initialize(model, optimizer, opt_level=opt_level) + + conceal_stdout(no_tqdm) + lr_finder = LRFinder(model, optimizer, criterion, device='cuda', amp_backend='apex') + lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') + lr_finder.plot() + 
conceal_stdout(no_tqdm and False) + +@simple_timer +def run_amp_torch(trainset, batch_size, no_tqdm=True): + trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True) + + device = torch.device('cuda') + model = ConvNet() + model = model.to(device) + optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.5) + criterion = nn.NLLLoss() + + amp_config = { + 'device_type': 'cuda', + 'dtype': torch.float16, + } + grad_scaler = torch.cuda.amp.GradScaler() + + conceal_stdout(no_tqdm) + lr_finder = LRFinder( + model, optimizer, criterion, + amp_backend='torch', amp_config=amp_config, grad_scaler=grad_scaler + ) + lr_finder.range_test(trainloader, end_lr=10, num_iter=100, step_mode='exp') + lr_finder.plot() + conceal_stdout(no_tqdm and False) + +def parse_args(): + parser = ArgumentParser(add_help=True) + parser.add_argument('--amp_backend', type=str, default='', + help='Backend for auto-mixed precision training, available: ' + '[torch, apex]. If not specified, amp is disabled.') + parser.add_argument('--batch_size', type=int, default=32) + parser.add_argument('--seed', type=int, default=0, help='Random seed.') + parser.add_argument('--data_folder', type=str, default='./data', + help='Location of MNIST dataset.') + parser.add_argument('--cudnn_benchmark', action='store_true', + help='Add this flag to make cudnn auto-tuner able to find ' + 'the best algorithm on your machine. This may improve the ' + 'performance when you are running script of mixed precision ' + 'training.') + parser.add_argument('--tqdm', action='store_true', + help='Add this flag to show the output from tqdm.') + parser.add_argument('--warm_up', action='store_true', + help='Add this flag to run a warm-up snippet.') + parser.add_argument('--opt_level', type=str, default='O1', + help='Optimization level for amp. 
(works only for `apex`)') + return parser.parse_args() + + +if __name__ == '__main__': + args = parse_args() + + # turn this mode on may improve the performance on some GPUs + torch.backends.cudnn.benchmark = args.cudnn_benchmark + + transform = transforms.Compose([ + transforms.ToTensor(), + transforms.Normalize((0.1307,), (0.3081,)) + ]) + trainset = datasets.MNIST(args.data_folder, train=True, download=True, transform=transform) + + reset_seed(args.seed) + if args.warm_up: + warm_up(trainset) + + if args.amp_backend == '': + run_normal(trainset, args.batch_size, no_tqdm=not args.tqdm) + elif args.amp_backend == 'apex': + run_amp_apex(trainset, args.batch_size, no_tqdm=not args.tqdm, opt_level=args.opt_level) + elif args.amp_backend == 'torch': + run_amp_torch(trainset, args.batch_size, no_tqdm=not args.tqdm) + else: + print('Unknown amp backend: {}'.format(args.amp_backend)) + diff --git a/tests/test_lr_finder.py b/tests/test_lr_finder.py index 150671c..4577f9a 100644 --- a/tests/test_lr_finder.py +++ b/tests/test_lr_finder.py @@ -10,13 +10,19 @@ import matplotlib.pyplot as plt - +# Check available backends for mixed precision training +AVAILABLE_AMP_BACKENDS = [] try: - from apex import amp + import apex.amp + AVAILABLE_AMP_BACKENDS.append("apex") +except ImportError: + pass - IS_AMP_AVAILABLE = True +try: + import torch.amp + AVAILABLE_AMP_BACKENDS.append("torch") except ImportError: - IS_AMP_AVAILABLE = False + pass def collect_task_classes(): @@ -34,6 +40,9 @@ def prepare_lr_finder(task, **kwargs): "device": kwargs.get("device", None), "memory_cache": kwargs.get("memory_cache", True), "cache_dir": kwargs.get("cache_dir", None), + "amp_backend": kwargs.get("amp_backend", None), + "amp_config": kwargs.get("amp_config", None), + "grad_scaler": kwargs.get("grad_scaler", None), } lr_finder = LRFinder(model, optimizer, criterion, **config) return lr_finder @@ -173,7 +182,7 @@ def test_gradient_accumulation(self, mocker): assert spy.call_count == accum_steps * 
num_iter @pytest.mark.skipif( - not (IS_AMP_AVAILABLE and mod_task.use_cuda()), + not (("apex" in AVAILABLE_AMP_BACKENDS) and mod_task.use_cuda()), reason="`apex` module and gpu is required to run this test." ) def test_gradient_accumulation_with_apex_amp(self, mocker): @@ -186,23 +195,50 @@ def test_gradient_accumulation_with_apex_amp(self, mocker): # CUDA GPU. So we have to move model to GPU first. model, optimizer, device = task.model, task.optimizer, task.device model = model.to(device) - task.model, task.optimizer = amp.initialize(model, optimizer) + task.model, task.optimizer = apex.amp.initialize(model, optimizer) - lr_finder = prepare_lr_finder(task) - spy = mocker.spy(amp, "scale_loss") + lr_finder = prepare_lr_finder(task, amp_backend="apex") + spy = mocker.spy(apex.amp, "scale_loss") lr_finder.range_test( task.train_loader, num_iter=num_iter, accumulation_steps=accum_steps ) assert spy.call_count == accum_steps * num_iter + @pytest.mark.skipif( + not (("torch" in AVAILABLE_AMP_BACKENDS) and mod_task.use_cuda()), + reason="`torch.amp` module and gpu is required to run this test." + ) + def test_gradient_accumulation_with_torch_amp(self, mocker): + desired_bs, accum_steps = 32, 4 + real_bs = desired_bs // accum_steps + num_iter = 10 + task = mod_task.XORTask(batch_size=real_bs) + + # Config for `torch.amp`. Though `torch.amp.autocast` supports various + # device types, we test it with CUDA only. 
+ amp_config = { + "device_type": "cuda", + "dtype": torch.float16, + } + grad_scaler = torch.cuda.amp.GradScaler() + + lr_finder = prepare_lr_finder( + task, amp_backend="torch", amp_config=amp_config, grad_scaler=grad_scaler + ) + spy = mocker.spy(grad_scaler, "scale") + + lr_finder.range_test( + task.train_loader, num_iter=num_iter, accumulation_steps=accum_steps + ) + assert spy.call_count == accum_steps * num_iter @pytest.mark.skipif( - not (IS_AMP_AVAILABLE and mod_task.use_cuda()), + not (("apex" in AVAILABLE_AMP_BACKENDS) and mod_task.use_cuda()), reason="`apex` module and gpu is required to run these tests." ) class TestMixedPrecision: - def test_mixed_precision(self, mocker): + def test_mixed_precision_apex(self, mocker): batch_size = 32 num_iter = 10 task = mod_task.XORTask(batch_size=batch_size) @@ -211,17 +247,37 @@ def test_mixed_precision(self, mocker): # CUDA GPU. So we have to move model to GPU first. model, optimizer, device = task.model, task.optimizer, task.device model = model.to(device) - task.model, task.optimizer = amp.initialize(model, optimizer) + task.model, task.optimizer = apex.amp.initialize(model, optimizer) assert hasattr(task.optimizer, "_amp_stash") - lr_finder = prepare_lr_finder(task) - spy = mocker.spy(amp, "scale_loss") + lr_finder = prepare_lr_finder(task, amp_backend="apex") + spy = mocker.spy(apex.amp, "scale_loss") lr_finder.range_test(task.train_loader, num_iter=num_iter) # NOTE: Here we did not perform gradient accumulation, so that call count # of `amp.scale_loss` should equal to `num_iter`. 
assert spy.call_count == num_iter + def test_mixed_precision_torch(self, mocker): + batch_size = 32 + num_iter = 10 + task = mod_task.XORTask(batch_size=batch_size) + + amp_config = { + "device_type": "cuda", + "dtype": torch.float16, + } + grad_scaler = torch.cuda.amp.GradScaler() + + lr_finder = prepare_lr_finder( + task, amp_backend="torch", amp_config=amp_config, grad_scaler=grad_scaler + ) + spy = mocker.spy(grad_scaler, "scale") + + lr_finder.range_test(task.train_loader, num_iter=num_iter) + # NOTE: Here we did not perform gradient accumulation, so that call count + # of `grad_scaler.scale` should be equal to `num_iter`. + assert spy.call_count == num_iter class TestDataLoaderIter: def test_traindataloaderiter(self): diff --git a/torch_lr_finder/lr_finder.py b/torch_lr_finder/lr_finder.py index 8311bf5..629f71b 100644 --- a/torch_lr_finder/lr_finder.py +++ b/torch_lr_finder/lr_finder.py @@ -11,12 +11,19 @@ PYTORCH_VERSION = version.parse(torch.__version__) +# Check available backends for mixed precision training +AVAILABLE_AMP_BACKENDS = [] try: - from apex import amp + import apex.amp + AVAILABLE_AMP_BACKENDS.append("apex") +except ImportError: + pass - IS_AMP_AVAILABLE = True +try: + import torch.amp + AVAILABLE_AMP_BACKENDS.append("torch") except ImportError: - IS_AMP_AVAILABLE = False + pass class DataLoaderIter(object): @@ -127,6 +134,12 @@ class LRFinder(object): cache_dir (string, optional): path for storing temporary files. If no path is specified, system-wide temporary directory is used. Notice that this parameter will be ignored if `memory_cache` is True. + amp_backend (string, optional): backend for mixed precision training, either + "torch" (`torch.amp`) or "apex" (`apex.amp`). If it's not specified, + mixed precision training is disabled. + amp_config (dict, optional): keyword arguments for `torch.amp.autocast()` only. + grad_scaler (torch.cuda.amp.GradScaler, optional): gradient scaler for + `torch.amp`; required when `amp_backend` is "torch". 
Example: >>> lr_finder = LRFinder(net, optimizer, criterion, device="cuda") @@ -147,6 +160,9 @@ def __init__( device=None, memory_cache=True, cache_dir=None, + amp_backend=None, + amp_config=None, + grad_scaler=None, ): # Check if the optimizer is already attached to a scheduler self.optimizer = optimizer @@ -159,6 +175,18 @@ def __init__( self.memory_cache = memory_cache self.cache_dir = cache_dir + # Settings related to mixed precision training + if amp_backend and (amp_backend not in AVAILABLE_AMP_BACKENDS): + raise ValueError("Unknown amp backend: {}".format(amp_backend)) + + if amp_backend == "torch": + if grad_scaler is None: + raise ValueError("`grad_scaler` is required when using `torch.amp`") + + self.amp_backend = amp_backend + self.amp_config = amp_config + self.grad_scaler = grad_scaler + # Save the original state of the model and optimizer so they can be restored if # needed self.model_device = next(self.model.parameters()).device @@ -374,21 +402,22 @@ def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=Tru ) # Forward pass - outputs = self.model(inputs) - loss = self.criterion(outputs, labels) + if self.amp_backend == "torch": + with torch.amp.autocast(**self.amp_config): + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) + else: + outputs = self.model(inputs) + loss = self.criterion(outputs, labels) # Loss should be averaged in each step loss /= accumulation_steps # Backward pass - if IS_AMP_AVAILABLE and hasattr(self.optimizer, "_amp_stash"): - # For minor performance optimization, see also: - # https://nvidia.github.io/apex/advanced.html#gradient-accumulation-across-iterations - delay_unscale = ((i + 1) % accumulation_steps) != 0 - - with amp.scale_loss( - loss, self.optimizer, delay_unscale=delay_unscale - ) as scaled_loss: + if self.amp_backend == "torch": + self.grad_scaler.scale(loss).backward() + elif self.amp_backend == "apex" and hasattr(self.optimizer, "_amp_stash"): + with 
apex.amp.scale_loss(loss, self.optimizer) as scaled_loss: scaled_loss.backward() else: loss.backward() @@ -398,7 +427,11 @@ def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=Tru else: total_loss += loss - self.optimizer.step() + if self.amp_backend == "torch": + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + self.optimizer.step() return total_loss.item()