diff --git a/.flake8 b/.flake8 index 8dd399a..8e515a8 100644 --- a/.flake8 +++ b/.flake8 @@ -1,3 +1,4 @@ [flake8] max-line-length = 88 extend-ignore = E203 +per-file-ignores = __init__.py:F401 diff --git a/README.md b/README.md index 45855c4..ef5d563 100644 --- a/README.md +++ b/README.md @@ -70,6 +70,7 @@ lr_finder.reset() - `LRFinder.range_test()` will change the model weights and the optimizer parameters. Both can be restored to their initial state with `LRFinder.reset()`. - The learning rate and loss history can be accessed through `lr_finder.history`. This will return a dictionary with `lr` and `loss` keys. - When using `step_mode="linear"` the learning rate range should be within the same order of magnitude. +- `LRFinder.range_test()` expects a pair of `input, label` to be returned from the `DataLoader` objects passed to it. The `input` must be ready to be passed to the `model` and the `label` must be ready to be passed to the `criterion` without any further data processing/handling/conversion. If you find yourself needing a workaround, you can make use of the classes `TrainDataLoaderIter` and `ValDataLoaderIter` to perform any data processing/handling/conversion in between the `DataLoader` and the training/evaluation loop. You can find an example of how to use these classes in [examples/lrfinder_cifar10_dataloader_iter](examples/lrfinder_cifar10_dataloader_iter.ipynb). ## Additional support for training diff --git a/examples/lrfinder_cifar10_dataloader_iter.ipynb b/examples/lrfinder_cifar10_dataloader_iter.ipynb new file mode 100644 index 0000000..ffd7c1e --- /dev/null +++ b/examples/lrfinder_cifar10_dataloader_iter.ipynb @@ -0,0 +1,455 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# CIFAR10: TrainDataLoaderIter and ValDataLoaderIter\n", + "\n", + "This example demonstrates the usage of `TrainDataLoaderIter` and `ValDataLoaderIter` with a ResNet56 on the CIFAR10 dataset.\n", + "\n", + "\n", + "`LRFinder.range_test()` assumes that the `DataLoader` objects passed to it return a tuple of the form `input, label, *`, where `input` is the tensor passed to the `model` and `label` is the tensor passed to the `criterion`. If this is not the case for your application, you'll probably need to use `TrainDataLoaderIter` and `ValDataLoaderIter`. These two classes essentially wrap the `DataLoader` class for the training set and validation set, respectively, and allow you to customize how the batches are formed for each of them.\n", + "\n", + "\n", + "Common use-cases for `TrainDataLoaderIter` and `ValDataLoaderIter`:\n", + "1. Your `DataLoader`/`Dataset.__getitem__()` returns a `dict` or another container that differs from the above (`input, label, *`)\n", + "2. Your `Dataset.__getitem__()` doesn't perform all the data processing/handling/conversion needed for `input` and `label` to be ready to be passed to the `model` and `criterion`, respectively.\n",
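+ "\n", + "For use-case 1 it is usually enough to override `inputs_labels_from_batch()` so that it unpacks whatever the `DataLoader` yields into an `(inputs, labels)` pair. Below is a minimal sketch of the idea (here `my_loader` is a hypothetical `DataLoader` whose batches are dictionaries with `\"img\"` and `\"target\"` keys, like the one built later in this notebook):\n", + "\n", + "```python\n", + "from torch_lr_finder import TrainDataLoaderIter\n", + "\n", + "class DictTrainIter(TrainDataLoaderIter):\n", + " def inputs_labels_from_batch(self, batch_data):\n", + " # Unpack the dict batch into the (inputs, labels) pair\n", + " # that LRFinder.range_test() expects\n", + " return batch_data[\"img\"], batch_data[\"target\"]\n", + "\n", + "train_iter = DictTrainIter(my_loader) # pass this to LRFinder.range_test()\n", + "```\n", + "\n", + "The same override point also covers use-case 2: any remaining conversion work (e.g. casting `target` to the dtype the `criterion` expects) can be done inside `inputs_labels_from_batch()` before returning the pair."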
+ ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "%matplotlib inline\n", + "\n", + "import torch.nn as nn\n", + "import torch.optim as optim\n", + "import torchvision.transforms as transforms\n", + "from torch.utils.data import DataLoader\n", + "from torchvision.datasets import CIFAR10\n", + "import cifar10_resnet as rc10\n", + "from PIL import Image\n", + "\n", + "try:\n", + " from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter\n", + "except ImportError:\n", + " # Run from source\n", + " import sys\n", + " sys.path.insert(0, '..')\n", + " from torch_lr_finder import LRFinder, TrainDataLoaderIter, ValDataLoaderIter" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Loading CIFAR10" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "To demonstrate how `TrainDataLoaderIter` and `ValDataLoaderIter` can be used, we'll create a new CIFAR10 dataset that, instead of returning `img, target` like [the one from torchvision](https://github.com/pytorch/vision/blob/master/torchvision/datasets/cifar.py#L118), returns a dictionary. We are essentially reproducing use-case 1 from above." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "# All we need to do here is inherit from CIFAR10 and override its __getitem__()\n", + "# method to return a dictionary\n", + "class CIFAR10WithDict(CIFAR10):\n", + " def __getitem__(self, index):\n", + " # IMPORTANT: this code works with torchvision 0.6.0; it's not guaranteed to\n", + " # work with other versions\n", + " img, target = self.data[index], self.targets[index]\n", + "\n", + " # doing this so that it is consistent with all other datasets\n", + " # to return a PIL Image\n", + " img = Image.fromarray(img)\n", + "\n", + " if self.transform is not None:\n", + " img = self.transform(img)\n", + "\n", + " if self.target_transform is not None:\n", + " target = self.target_transform(target)\n", + " \n", + " ret_dict = {}\n", + " ret_dict[\"img\"] = img\n", + " ret_dict[\"target\"] = target\n", + "\n", + " return ret_dict" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "cifar_pwd = \"../data\"\n", + "batch_size = 256" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Files already downloaded and verified\n", + "Files already downloaded and verified\n" + ] + } + ], + "source": [ + "transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])\n", + "\n", + "trainset = CIFAR10WithDict(root=cifar_pwd, train=True, download=True, transform=transform)\n", + "trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, num_workers=0)\n", + "\n", + "testset = CIFAR10WithDict(root=cifar_pwd, train=False, download=True, transform=transform)\n", + "testloader = DataLoader(testset, batch_size=batch_size * 2, shuffle=False, num_workers=0)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Model" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "model = rc10.resnet56()" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Training loss (fastai)\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [],
"source": [ + "criterion = nn.CrossEntropyLoss()\n", + "optimizer = optim.Adam(model.parameters(), lr=1e-7, weight_decay=1e-2)\n", + "lr_finder = LRFinder(model, optimizer, criterion, device=\"cuda\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Let's try to do a `range_test` and see what happens:" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c907007fbc1b42549cf7cd4dc720e3b3", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n" + ] + }, + { + "ename": "ValueError", + "evalue": "Your batch type not supported: . Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0maccumulation_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 290\u001b[0;31m \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 291\u001b[0m )\n\u001b[1;32m 292\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, train_iter, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 340\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccumulation_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m 
\u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 342\u001b[0m inputs, labels = self._move_to_device(\n\u001b[1;32m 343\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_reset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m )\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: . Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method." + ] + } + ], + "source": [ + "lr_finder.range_test(trainloader, end_lr=100, num_iter=100, step_mode=\"exp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The problem is that `LRFinder` doesn't know how to build a batch using a dictionary. 
This is where `TrainDataLoaderIter` comes in:" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomTrainIter(TrainDataLoaderIter):\n", + " def inputs_labels_from_batch(self, batch_data):\n", + " return batch_data[\"img\"], batch_data[\"target\"]\n", + " \n", + "custom_train_iter = CustomTrainIter(trainloader)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "`inputs_labels_from_batch()` receives the batch produced by the `DataLoader` (here, a `dict` collated from the return values of `CIFAR10WithDict.__getitem__()`) and must return a tuple `input, label`, or in this case `img, target`." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "3b266207edf64bf1a01331664717384e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Stopping early, the loss has diverged\n", + "Learning rate search finished. See the graph with {finder_name}.plot()\n" + ] + } + ], + "source": [ + "lr_finder.reset()\n", + "lr_finder.range_test(custom_train_iter, end_lr=100, num_iter=100, step_mode=\"exp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Validation loss (Leslie N. Smith)\n", + "\n", + "We find something similar when using the validation loss test. If we try to run `range_test` with `trainloader` and `testloader`, we get an error:" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "5e1f7e37a29648dbbfef1dababa98e01", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "ValueError", + "evalue": "Your batch type not supported: . 
Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrainloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtestloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 288\u001b[0m \u001b[0mtrain_iter\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 289\u001b[0m \u001b[0maccumulation_steps\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 290\u001b[0;31m \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 291\u001b[0m )\n\u001b[1;32m 292\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_train_batch\u001b[0;34m(self, train_iter, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 339\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0moptimizer\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mzero_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 340\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mrange\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0maccumulation_steps\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 341\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtrain_iter\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 342\u001b[0m inputs, labels = self._move_to_device(\n\u001b[1;32m 343\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + 
"\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 56\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 57\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 58\u001b[0;31m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 59\u001b[0m \u001b[0;32mexcept\u001b[0m \u001b[0mStopIteration\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 60\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mauto_reset\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m )\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: . Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method." + ] + } + ], + "source": [ + "lr_finder.reset()\n", + "lr_finder.range_test(trainloader, val_loader=testloader, end_lr=100, num_iter=100, step_mode=\"exp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Even if we replace `trainloader` by `custom_train_iter` we'll still get the error:" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": { + "scrolled": true + }, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "25f4e7c42bc2481eb13cc0900015c808", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "ename": "ValueError", + "evalue": "Your batch type not supported: . 
Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method.", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 2\u001b[0;31m \u001b[0mlr_finder\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mrange_test\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcustom_train_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mtestloader\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mend_lr\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnum_iter\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m100\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstep_mode\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m\"exp\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36mrange_test\u001b[0;34m(self, train_loader, val_loader, start_lr, end_lr, num_iter, step_mode, smooth_f, diverge_th, accumulation_steps, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 292\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mval_loader\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 293\u001b[0m loss = self._validate(\n\u001b[0;32m--> 294\u001b[0;31m \u001b[0mval_iter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mnon_blocking_transfer\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 295\u001b[0m )\n\u001b[1;32m 296\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m_validate\u001b[0;34m(self, val_iter, non_blocking_transfer)\u001b[0m\n\u001b[1;32m 395\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodel\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0meval\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 396\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mno_grad\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 397\u001b[0;31m \u001b[0;32mfor\u001b[0m \u001b[0minputs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mlabels\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mval_iter\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 398\u001b[0m \u001b[0;31m# Move data to the correct device\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 399\u001b[0m inputs, labels = self._move_to_device(\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36m__next__\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 45\u001b[0m \u001b[0;32mdef\u001b[0m 
\u001b[0m__next__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 46\u001b[0m \u001b[0mbatch\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_iterator\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 47\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0minputs_labels_from_batch\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 48\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 49\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;32m~/datascience/pytorch/pytorch-lr-finder/torch_lr_finder/lr_finder.py\u001b[0m in \u001b[0;36minputs_labels_from_batch\u001b[0;34m(self, batch_data)\u001b[0m\n\u001b[1;32m 33\u001b[0m \u001b[0;34m\"Your batch type not supported: {}. Please inherit from \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 34\u001b[0m \u001b[0;34m\"`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine \"\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 35\u001b[0;31m \u001b[0;34m\"`_batch_make_inputs_labels` method.\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mformat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtype\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_data\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 36\u001b[0m )\n\u001b[1;32m 37\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mValueError\u001b[0m: Your batch type not supported: . Please inherit from `TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine `_batch_make_inputs_labels` method." + ] + } + ], + "source": [ + "lr_finder.reset()\n", + "lr_finder.range_test(custom_train_iter, val_loader=testloader, end_lr=100, num_iter=100, step_mode=\"exp\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The solution is to wrap `testloader` in `ValDataLoaderIter` and override `inputs_labels_from_batch()`:" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [], + "source": [ + "class CustomValIter(ValDataLoaderIter):\n", + " def inputs_labels_from_batch(self, batch_data):\n", + " return batch_data[\"img\"], batch_data[\"target\"]\n", + " \n", + "custom_val_iter = CustomValIter(testloader)" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "bed059b7f2834752a52fa1f130cfbd59", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "HBox(children=(IntProgress(value=0), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Learning rate search finished. 
See the graph with {finder_name}.plot()\n" + ] + } + ], + "source": [ + "lr_finder.reset()\n", + "lr_finder.range_test(custom_train_iter, val_loader=custom_val_iter, end_lr=100, num_iter=100, step_mode=\"exp\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "lr_finder_p3", + "language": "python", + "name": "lr_finder_p3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/torch_lr_finder/lr_finder.py b/torch_lr_finder/lr_finder.py index 38c179c..23feecc 100644 --- a/torch_lr_finder/lr_finder.py +++ b/torch_lr_finder/lr_finder.py @@ -29,9 +29,11 @@ def dataset(self): def inputs_labels_from_batch(self, batch_data): if not isinstance(batch_data, list) and not isinstance(batch_data, tuple): - raise ValueError("""Your batch type not supported: {}. Please inherit from `TrainDataLoaderIter` - (or `ValDataLoaderIter`) and redefine - `_batch_make_inputs_labels` method.""".format(type(batch_data))) + raise ValueError( + "Your batch type not supported: {}. Please inherit from " + "`TrainDataLoaderIter` (or `ValDataLoaderIter`) and redefine " + "`_batch_make_inputs_labels` method.".format(type(batch_data)) + ) inputs, labels, *_ = batch_data @@ -154,7 +156,7 @@ def range_test( smooth_f=0.05, diverge_th=5, accumulation_steps=1, - non_blocking_transfer=True + non_blocking_transfer=True, ): """Performs the learning rate range test. @@ -215,10 +217,10 @@ def range_test( redefine method `inputs_labels_from_batch` so that it outputs (inputs, lables) data: >>> import torch_lr_finder >>> class TrainIter(torch_lr_finder.TrainDataLoaderIter): - >>> def batch_make_inputs_labels(self, batch_data): - >>> return (batch_data['user_features'], batch_data['user_history'] ), batch_data['y_labels'] + >>> def inputs_labels_from_batch(self, batch_data): + >>> return (batch_data['user_features'], batch_data['user_history']), batch_data['y_labels'] >>> train_data_iter = TrainIter(train_dl) - >>> finder = torch_lr_finder.LRFinder(model, optimizer, partial(model._train_loss, need_one_hot=False) ) + >>> finder = torch_lr_finder.LRFinder(model, optimizer, partial(model._train_loss, need_one_hot=False)) >>> finder.range_test(train_data_iter, end_lr=10, num_iter=300, diverge_th=10) Reference: @@ -258,9 +260,11 @@ range_test( elif isinstance(train_loader, TrainDataLoaderIter): train_iter = train_loader else: - raise ValueError("""`train_loader` has unsupported type: {}. - Expected types are `torch.utils.data.DataLoader` - or child of `TrainDataLoaderIter`.""".format(type(train_loader))) + raise ValueError( + "`train_loader` has unsupported type: {}. " + "Expected types are `torch.utils.data.DataLoader` " + "or child of `TrainDataLoaderIter`.".format(type(train_loader)) + ) if val_loader: if isinstance(val_loader, DataLoader): @@ -268,9 +272,11 @@ val_iter = ValDataLoaderIter(val_loader) elif isinstance(val_loader, ValDataLoaderIter): val_iter = val_loader else: - raise ValueError("""`val_loader` has unsupported type: {}. - Expected types are `torch.utils.data.DataLoader` - or child of `ValDataLoaderIter`.""".format(type(val_loader))) + raise ValueError( + "`val_loader` has unsupported type: {}. " 
+ "Expected types are `torch.utils.data.DataLoader`" + "or child of `ValDataLoaderIter`.".format(type(val_loader)) + ) for iteration in tqdm(range(num_iter)): # Train on batch and retrieve loss @@ -322,9 +328,7 @@ def _check_for_scheduler(self): if "initial_lr" in param_group: raise RuntimeError("Optimizer already has a scheduler attached to it") - def _train_batch( - self, train_iter, accumulation_steps, non_blocking_transfer=True - ): + def _train_batch(self, train_iter, accumulation_steps, non_blocking_transfer=True): self.model.train() total_loss = None # for late initialization