-
Notifications
You must be signed in to change notification settings - Fork 122
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* TST: add some tests for `LRFinder` This is a draft of unit tests for this package. For details of how test cases are written, please check out "tests/README.md". * TST: replace env vars with command line arguments for pytest runner Other requested changes mentioned in PR #27 are also done in this commit. * TST: remove decorator for making metaclass work on Py2k * TST: remove local import statements Local imports in `collect_task_classes()` is not necessary since module `task` has been imported in global. * STY: format code with black
- Loading branch information
1 parent
66e23d7
commit 34e2380
Showing
8 changed files
with
365 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -2,3 +2,4 @@ matplotlib | |
numpy | ||
torch>=0.4.1 | ||
tqdm | ||
pytest |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,66 @@ | ||
## Requirements | ||
- pytest | ||
|
||
## Run tests | ||
- normal (use GPU if it's available) | ||
```bash | ||
# in root directory of this package | ||
$ python -mpytest ./tests | ||
``` | ||
|
||
- forcibly run all tests on CPU | ||
```bash | ||
# in root directory of this package | ||
$ python -mpytest --cpu_only ./tests | ||
``` | ||
|
||
## How to add new test cases | ||
To make it easy to create test cases and reuse settings, the basic elements needed to run a training task are packaged into classes that inherit from `BaseTask` in `task.py`. | ||
|
||
A `BaseTask` has the following members: | ||
- `batch_size` | ||
- `model` | ||
- `optimizer` | ||
- `criterion` (loss function) | ||
- `device` (`cpu`, `cuda`, etc.) | ||
- `train_loader` (`torch.utils.data.DataLoader` for training set) | ||
- `val_loader` (`torch.utils.data.DataLoader` for validation set) | ||
|
||
If you want to create a new task, just write a new class inheriting `BaseTask` and add your configuration in `__init__`. | ||
|
||
Note-1: Any class in `task.py` that inherits `BaseTask` and whose name ends with `Task` will be collected by the function `test_lr_finder.py::collect_task_classes()`. | ||
|
||
Note-2: The model and dataset are created when a task class is **instantiated**, so collecting a lot of task **objects** at once is not recommended. | ||
|
||
|
||
### Directly use specific task in a test case | ||
```python | ||
from . import task as mod_task | ||
def test_run(): | ||
task = mod_task.FooTask() | ||
... | ||
``` | ||
|
||
### Use `pytest.mark.parametrize` | ||
- Use specified task in a test case | ||
```python | ||
@pytest.mark.parametrize( | ||
'cls_task, arg', # names of parameters (see also the signature of the following function) | ||
[ | ||
(task.FooTask, 'foo'), | ||
(task.BarTask, 'bar'), | ||
], # list of parameters | ||
) | ||
def test_run(cls_task, arg): | ||
... | ||
``` | ||
|
||
- Use all existing tasks in a test case | ||
```python | ||
@pytest.mark.parametrize( | ||
'cls_task', | ||
collect_task_classes(), | ||
) | ||
def test_run(cls_task): | ||
... | ||
``` |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,41 @@ | ||
import pytest | ||
|
||
|
||
class CustomCommandLineOption(object):
    """An object for storing command line options parsed by pytest.

    Since `pytest.config` global object is deprecated and removed in version
    5.0, this class is made to work as a store of command line options for
    those components which are not able to access them via `request.config`.

    Stored options can be read back as attributes, e.g. `store.cpu_only`.
    """

    def __init__(self):
        self._content = {}

    def __str__(self):
        return str(self._content)

    def add(self, key, value):
        """Store `value` under `key`, replacing any previous value."""
        self._content[key] = value

    def delete(self, key):
        """Remove `key` from the store.

        Raises:
            KeyError: if `key` has not been stored.
        """
        del self._content[key]

    def __getattr__(self, key):
        """Return the stored option named `key`.

        This hook only runs after normal attribute lookup has failed.

        Raises:
            AttributeError: if no option named `key` has been stored.
        """
        # Go through `__dict__` directly: touching `self._content` here would
        # re-enter `__getattr__` forever whenever `_content` itself is missing
        # (e.g. on an instance created without running `__init__`).
        content = self.__dict__.get("_content", {})
        if key in content:
            return content[key]
        # `object` defines no `__getattr__`, so delegating to `super()` would
        # only produce a misleading message; raise the proper error instead.
        raise AttributeError(
            "{!r} object has no attribute {!r}".format(type(self).__name__, key)
        )
|
||
|
||
def pytest_addoption(parser):
    """Register the custom command line flags used by this test suite.

    Arguments:
        parser: the option parser object handed over by pytest.
    """
    cpu_only_flag = {
        "action": "store_true",
        "help": "Forcibly run all tests on CPU.",
    }
    parser.addoption("--cpu_only", **cpu_only_flag)
|
||
|
||
def pytest_configure(config):
    """Expose the parsed command line options as `pytest.custom_cmdopt`.

    Arguments:
        config: the pytest config object; only `getoption` is used here.
    """
    options = CustomCommandLineOption()
    # Attach the store to the `pytest` module so that code without access to
    # `request.config` can still read the options.
    pytest.custom_cmdopt = options

    options.add("cpu_only", config.getoption("--cpu_only"))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
import numpy as np | ||
import torch | ||
from torch.utils.data import Dataset | ||
|
||
|
||
class XORDataset(Dataset):
    """Random 0/1 samples labelled with the XOR of their values along axis 1."""

    def __init__(self, length, shape=None):
        """
        Arguments:
            length (int): number of samples, which equals `len(self)`.
            shape (list, tuple, optional): per-sample shape appended after
                the leading `length` axis. When it is omitted (or empty),
                the data has shape `(length, 8)`. Default: None.
        """
        if shape:
            full_shape = (length,) + tuple(shape)
        else:
            full_shape = (length, 8)

        bits = np.random.randint(0, 2, full_shape)
        # XOR-reduce along axis 1, then add a trailing axis of size 1.
        parity = np.bitwise_xor.reduce(bits, axis=1)

        self.data = torch.FloatTensor(bits)
        self.label = torch.tensor(parity).unsqueeze(dim=1).float()

    def __getitem__(self, index):
        """Return the `(data, label)` pair stored at `index`."""
        sample, target = self.data[index], self.label[index]
        return sample, target

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
|
||
|
||
class ExtraXORDataset(XORDataset):
    """An XOR dataset whose items can carry additional random values."""

    def __init__(self, length, shape=None, extra_dims=1):
        """
        Arguments:
            length (int): number of samples, which equals `len(self)`.
            shape (list, tuple, optional): forwarded unchanged to
                `XORDataset.__init__`. Default: None.
            extra_dims (int, optional): number of extra values attached to
                each sample. A falsy value disables the extras.
                Default: 1.
        """
        super(ExtraXORDataset, self).__init__(length, shape=shape)
        # Extras are random 0/1 integers of shape `(length, extra_dims)`.
        self.extras = torch.randint(0, 2, (length, extra_dims)) if extra_dims else None

    def __getitem__(self, index):
        """Return the item at `index`.

        Without extras this is the tuple `(data, label)`; with extras it is
        a list `[data, label, *extras]`.
        """
        pair = (self.data[index], self.label[index])
        if self.extras is None:
            return pair

        item = list(pair)
        item.extend(self.extras[index])
        return item

    def __len__(self):
        """Return the number of samples."""
        return len(self.data)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,15 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
|
||
|
||
class LinearMLP(nn.Module):
    """A stack of `nn.Linear` layers applied one after another."""

    def __init__(self, layer_dim):
        """
        Arguments:
            layer_dim (list, tuple): layer widths; each consecutive pair
                `(layer_dim[i], layer_dim[i + 1])` becomes one linear layer.
        """
        super(LinearMLP, self).__init__()
        stack = []
        for n_in, n_out in zip(layer_dim, layer_dim[1:]):
            stack.append(nn.Linear(n_in, n_out))
        self.net = nn.Sequential(*stack)

    def forward(self, x):
        """Run `x` through every layer in order."""
        out = self.net(x)
        return out
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,112 @@ | ||
import torch | ||
import torch.nn as nn | ||
import torch.nn.functional as F | ||
import torch.optim as optim | ||
from torch.utils.data import DataLoader, Subset | ||
import pytest | ||
|
||
from .model import LinearMLP | ||
from .dataset import XORDataset, ExtraXORDataset | ||
|
||
|
||
def use_cuda():
    """Tell whether tests should run on CUDA.

    Returns:
        bool: False when the `cpu_only` option stored in
        `pytest.custom_cmdopt` is set; otherwise whatever
        `torch.cuda.is_available()` reports.
    """
    # `and` short-circuits, so CUDA is only probed when it is not ruled out.
    return (not pytest.custom_cmdopt.cpu_only) and torch.cuda.is_available()
|
||
|
||
class TaskTemplate(type):
    """Metaclass that runs an optional `__post_init__` hook after construction."""

    def __call__(cls, *args, **kwargs):
        """Create an instance of `cls`, then call its `__post_init__` if it has one."""
        instance = type.__call__(cls, *args, **kwargs)
        if not hasattr(instance, "__post_init__"):
            return instance

        instance.__post_init__()
        return instance
|
||
|
||
class BaseTask(metaclass=TaskTemplate):
    """Container for the objects needed to run one training task.

    Subclasses fill in the attributes below in their own `__init__`;
    `__post_init__` then normalizes `device`.
    """

    def __init__(self):
        self.model = None
        self.optimizer = None
        self.criterion = None
        self.train_loader = None
        self.val_loader = None
        self.device = None
        self.batch_size = -1

    def __post_init__(self):
        """Normalize `self.device` once the instance has been built.

        The device is reset to None when CUDA is not going to be used.
        Otherwise a string is converted to `torch.device`, so that operations
        related to moving tensors work fine later.

        Raises:
            TypeError: if `device` is neither None, a `str` nor a
                `torch.device`.
        """
        if not use_cuda():
            self.device = None
            return

        requested = self.device
        if requested is None or isinstance(requested, torch.device):
            return
        if isinstance(requested, str):
            self.device = torch.device(requested)
            return
        raise TypeError("Invalid type of device.")
|
||
|
||
class XORTask(BaseTask):
    """Task that fits a small `LinearMLP` on `XORDataset` with plain SGD."""

    def __init__(self, validate=False):
        """
        Arguments:
            validate (bool, optional): when True, build separate training
                and validation loaders; otherwise build a single training
                loader over the whole dataset and leave `val_loader` as None.
                Default: False.
        """
        super(XORTask, self).__init__()
        bs, steps = 8, 64
        dataset = XORDataset(bs * steps)
        # `batch_size` is passed explicitly because `DataLoader` defaults to a
        # batch size of 1, which would disagree with `self.batch_size`.
        if validate:
            # NOTE(review): this split only touches the first `steps` samples
            # of the `bs * steps` available - confirm that this is intended.
            self.train_loader = DataLoader(
                Subset(dataset, range(steps - bs)), batch_size=bs
            )
            self.val_loader = DataLoader(
                Subset(dataset, range(steps - bs, steps)), batch_size=bs
            )
        else:
            self.train_loader = DataLoader(dataset, batch_size=bs)
            self.val_loader = None

        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        # Requested device; `BaseTask.__post_init__` resets it to None when
        # CUDA is not going to be used.
        self.device = torch.device("cuda")
|
||
|
||
class ExtraXORTask(BaseTask):
    """Task like the plain XOR one, but fed by `ExtraXORDataset`.

    Each item of the dataset carries two extra values besides the
    `(data, label)` pair.
    """

    def __init__(self, validate=False):
        """
        Arguments:
            validate (bool, optional): when True, build separate training
                and validation loaders; otherwise build a single training
                loader over the whole dataset and leave `val_loader` as None.
                Default: False.
        """
        super(ExtraXORTask, self).__init__()
        bs, steps = 8, 64
        dataset = ExtraXORDataset(bs * steps, extra_dims=2)
        # `batch_size` is passed explicitly because `DataLoader` defaults to a
        # batch size of 1, which would disagree with `self.batch_size`.
        if validate:
            # NOTE(review): this split only touches the first `steps` samples
            # of the `bs * steps` available - confirm that this is intended.
            self.train_loader = DataLoader(
                Subset(dataset, range(steps - bs)), batch_size=bs
            )
            self.val_loader = DataLoader(
                Subset(dataset, range(steps - bs, steps)), batch_size=bs
            )
        else:
            self.train_loader = DataLoader(dataset, batch_size=bs)
            self.val_loader = None

        # Record the batch size instead of leaving the `BaseTask` default (-1).
        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        self.optimizer = optim.SGD(self.model.parameters(), lr=1e-3)
        self.criterion = nn.MSELoss()
        # Requested device; `BaseTask.__post_init__` resets it to None when
        # CUDA is not going to be used.
        self.device = torch.device("cuda")
|
||
|
||
class DiscriminativeLearningRateTask(BaseTask):
    """XOR task whose optimizer uses a different learning rate per layer."""

    def __init__(self, validate=False):
        """
        Arguments:
            validate (bool, optional): when True, build separate training
                and validation loaders; otherwise build a single training
                loader over the whole dataset and leave `val_loader` as None.
                Default: False.
        """
        super(DiscriminativeLearningRateTask, self).__init__()
        bs, steps = 8, 64
        dataset = XORDataset(bs * steps)
        # `batch_size` is passed explicitly because `DataLoader` defaults to a
        # batch size of 1, which would disagree with `self.batch_size`.
        if validate:
            # NOTE(review): this split only touches the first `steps` samples
            # of the `bs * steps` available - confirm that this is intended.
            self.train_loader = DataLoader(
                Subset(dataset, range(steps - bs)), batch_size=bs
            )
            self.val_loader = DataLoader(
                Subset(dataset, range(steps - bs, steps)), batch_size=bs
            )
        else:
            self.train_loader = DataLoader(dataset, batch_size=bs)
            self.val_loader = None

        # Record the batch size instead of leaving the `BaseTask` default (-1).
        self.batch_size = bs
        self.model = LinearMLP([8, 4, 1])
        # One parameter group per layer, each with its own learning rate; the
        # top-level `lr` is only the default for groups that do not set one.
        self.optimizer = optim.SGD(
            [
                {"params": self.model.net[0].parameters(), "lr": 0.01},
                {"params": self.model.net[1].parameters(), "lr": 0.001},
            ],
            lr=1e-3,
            momentum=0.5,
        )
        self.criterion = nn.MSELoss()
        # Requested device; `BaseTask.__post_init__` resets it to None when
        # CUDA is not going to be used.
        self.device = torch.device("cuda")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,72 @@ | ||
import pytest | ||
from torch_lr_finder import LRFinder | ||
|
||
from . import task as mod_task | ||
|
||
|
||
def collect_task_classes():
    """Collect the task classes exposed by the `task` module.

    An attribute of `mod_task` is collected when its name ends with "Task",
    it is not named "BaseTask", and it is a class deriving from
    `mod_task.BaseTask`.

    Returns:
        list: the matching classes, in the order their names appear in
        `dir(mod_task)`.
    """
    classes = []
    for name in dir(mod_task):
        if not name.endswith("Task") or name == "BaseTask":
            continue
        attr = getattr(mod_task, name)
        # `issubclass` raises `TypeError` for anything that is not a class
        # (e.g. a helper function called `make_task`... or `fooTask`), so
        # check the type first and simply skip such attributes.
        if isinstance(attr, type) and issubclass(attr, mod_task.BaseTask):
            classes.append(attr)
    return classes
|
||
|
||
def prepare_lr_finder(task, **kwargs):
    """Build an `LRFinder` from the model, optimizer and criterion of `task`.

    Arguments:
        task: object providing `model`, `optimizer` and `criterion`.
        **kwargs: only `device` (default None), `memory_cache` (default True)
            and `cache_dir` (default None) are forwarded to `LRFinder`; any
            other keyword is ignored.

    Returns:
        LRFinder: the configured finder.
    """
    defaults = {"device": None, "memory_cache": True, "cache_dir": None}
    config = {key: kwargs.get(key, fallback) for key, fallback in defaults.items()}
    return LRFinder(task.model, task.optimizer, task.criterion, **config)
|
||
|
||
def get_optim_lr(optimizer):
    """Return the `"lr"` value of every parameter group of `optimizer`, in order."""
    rates = []
    for group in optimizer.param_groups:
        rates.append(group["lr"])
    return rates
|
||
|
||
class TestRangeTest:
    """Smoke tests running `LRFinder.range_test` on every collected task."""

    @pytest.mark.parametrize("cls_task", collect_task_classes())
    def test_run(self, cls_task):
        """Run the range test with a training loader only."""
        task = cls_task()
        lrs_before = get_optim_lr(task.optimizer)

        finder = prepare_lr_finder(task)
        finder.range_test(task.train_loader)

        # check whether lr is actually changed
        highest_lr = max(finder.history["lr"])
        assert highest_lr >= lrs_before[0]

    @pytest.mark.parametrize("cls_task", collect_task_classes())
    def test_run_with_val_loader(self, cls_task):
        """Run the range test with both training and validation loaders."""
        task = cls_task(validate=True)
        lrs_before = get_optim_lr(task.optimizer)

        finder = prepare_lr_finder(task)
        finder.range_test(task.train_loader, val_loader=task.val_loader)

        # check whether lr is actually changed
        highest_lr = max(finder.history["lr"])
        assert highest_lr >= lrs_before[0]
|
||
|
||
class TestReset:
    """Checks that `LRFinder.reset` restores the optimizer's learning rates."""

    @pytest.mark.parametrize(
        "cls_task",
        [mod_task.XORTask, mod_task.DiscriminativeLearningRateTask],
    )
    def test_reset(self, cls_task):
        """Learning rates after `reset()` must equal the ones before the range test."""
        task = cls_task()
        lrs_before = get_optim_lr(task.optimizer)

        finder = prepare_lr_finder(task)
        finder.range_test(task.train_loader, val_loader=task.val_loader)
        finder.reset()

        lrs_after = get_optim_lr(task.optimizer)
        assert lrs_before == lrs_after