From 66f2a2ffb6fb0f4c421d913ee95c4b156cc1417b Mon Sep 17 00:00:00 2001
From: Ben Zickel <35469979+BenZickel@users.noreply.github.com>
Date: Sat, 25 Jan 2025 20:13:59 +0200
Subject: [PATCH] Add a centered variance option to the ClippedAdam optimizer
 (#3415)

* Add option to use centered variance in the ClippedAdam optimizer.

* Add test for the centered ClippedAdam optimizer.

* Calculate convergence iteration for the centered ClippedAdam optimizer.

* Added reference for the centered Adam optimizer.

* Add option to use the ClippedAdam optimizer with centered variance in the
  Latent Dirichlet Allocation example.

* Added more detailed comments on ClippedAdam with centered variance and its
  tests.

* Shortened the ClippedAdam centered variance test and added an option to run
  the full test with plots via a pytest command line option.
---
 examples/lda.py            |   5 +-
 pyro/optim/clipped_adam.py |  20 ++++--
 tests/optim/conftest.py    |  10 +++
 tests/optim/test_optim.py  | 123 +++++++++++++++++++++++++++++++++++++
 4 files changed, 152 insertions(+), 6 deletions(-)

diff --git a/examples/lda.py b/examples/lda.py
index 16fc09ad0b..00d3ac3bef 100644
--- a/examples/lda.py
+++ b/examples/lda.py
@@ -137,7 +137,9 @@ def main(args):
     guide = functools.partial(parametrized_guide, predictor)
     Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
     elbo = Elbo(max_plate_nesting=2)
-    optim = ClippedAdam({"lr": args.learning_rate})
+    optim = ClippedAdam(
+        {"lr": args.learning_rate, "centered_variance": args.centered_variance}
+    )
     svi = SVI(model, guide, optim, elbo)
     logging.info("Step\tLoss")
     for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
     parser.add_argument("-n", "--num-steps", default=1000, type=int)
     parser.add_argument("-l", "--layer-sizes", default="100-100")
     parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
+    parser.add_argument("-cv", "--centered-variance", default=False, type=bool)
     parser.add_argument("-b", "--batch-size", default=32, type=int)
     parser.add_argument("--jit", action="store_true")
     args = parser.parse_args()
diff --git a/pyro/optim/clipped_adam.py b/pyro/optim/clipped_adam.py
index 14a6a06656..14ac268129 100644
--- a/pyro/optim/clipped_adam.py
+++ b/pyro/optim/clipped_adam.py
@@ -19,14 +19,21 @@ class ClippedAdam(Optimizer):
     :param weight_decay: weight decay (L2 penalty) (default: 0)
     :param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
     :param lrd: rate at which learning rate decays (default: 1.0)
+    :param centered_variance: use centered variance (default: False)

     Small modification to the Adam algorithm implemented in torch.optim.Adam
-    to include gradient clipping and learning rate decay.
+    to include gradient clipping and learning rate decay, as well as an option
+    to use the centered variance (see equation 2 in [2]).

-    Reference
+    **References**

-    `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
-    https://arxiv.org/abs/1412.6980
+    [1] `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
+        https://arxiv.org/abs/1412.6980
+
+    [2] `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
+        Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
+        Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
+        https://doi.org/10.3390/computation11050095
     """

     def __init__(
@@ -38,6 +45,7 @@ def __init__(
         weight_decay=0,
         clip_norm: float = 10.0,
         lrd: float = 1.0,
+        centered_variance: bool = False,
     ):
         defaults = dict(
             lr=lr,
@@ -46,6 +54,7 @@ def __init__(
             weight_decay=weight_decay,
             clip_norm=clip_norm,
             lrd=lrd,
+            centered_variance=centered_variance,
         )
         super().__init__(params, defaults)

@@ -87,7 +96,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:

                 # Decay the first and second moment running average coefficient
                 exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
-                exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+                grad_var = (grad - exp_avg) if group["centered_variance"] else grad
+                exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

                 denom = exp_avg_sq.sqrt().add_(group["eps"])

diff --git a/tests/optim/conftest.py b/tests/optim/conftest.py
index 55dd44d1d5..238deb176e 100644
--- a/tests/optim/conftest.py
+++ b/tests/optim/conftest.py
@@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
             item.add_marker(pytest.mark.stage("unit"))
         if "init" not in item.keywords:
             item.add_marker(pytest.mark.init(rng_seed=123))
+
+
+def pytest_addoption(parser):
+    parser.addoption("--plot", action="store", default="FALSE")
+
+
+def pytest_generate_tests(metafunc):
+    option_value = metafunc.config.option.plot != "FALSE"
+    if "plot" in metafunc.fixturenames and option_value is not None:
+        metafunc.parametrize("plot", [option_value])
diff --git a/tests/optim/test_optim.py b/tests/optim/test_optim.py
index 6b6dc59d8a..c0acefd7ef 100644
--- a/tests/optim/test_optim.py
+++ b/tests/optim/test_optim.py
@@ -435,3 +435,126 @@ def step(svi, optimizer):
         actual.append(step(svi, optimizer))

     assert_equal(actual, expected)
+
+
+def test_centered_clipped_adam(plot):
+    """
+    Test the centered variance option of the ClippedAdam optimizer.
+    In order to create plots, run pytest with the plot command line
+    option set to True, i.e. by executing
+
+    'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
+
+    """
+    if not plot:
+        lr_vec = [0.1, 0.001]
+    else:
+        lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]
+
+    w = torch.Tensor([1, 500])
+
+    def loss_fn(p):
+        return (1 + w * p * p).sqrt().sum() - len(w)
+
+    def fit(lr, centered_variance, num_iter=5000):
+        loss_vec = []
+        p = torch.nn.Parameter(torch.Tensor([10, 1]))
+        optim = pyro.optim.clipped_adam.ClippedAdam(
+            lr=lr, params=[p], centered_variance=centered_variance
+        )
+        for count in range(num_iter):
+            optim.zero_grad()
+            loss = loss_fn(p)
+            loss.backward()
+            optim.step()
+            loss_vec.append(loss)
+        return torch.Tensor(loss_vec)
+
+    def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
+        """
+        Calculate the number of iterations needed in order to reach the
+        ultimate loss plus a small threshold, and the convergence rate,
+        which is the mean per-iteration improvement of the gap between
+        the loss and the ultimate loss.
+        """
+        ultimate_loss = loss_vec[-tail_len:].mean()
+        convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
+        convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
+        convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
+        return ultimate_loss, convergence_rate, convergence_iter
+
+    def get_convergence_vec(lr_vec, centered_variance):
+        """
+        Fit parameters for a vector of learning rates, with or without centered variance,
+        and calculate the convergence properties for each learning rate.
+        """
+        ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
+        for lr in lr_vec:
+            loss_vec = fit(lr=lr, centered_variance=centered_variance)
+            ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
+                loss_vec
+            )
+            ultimate_loss_vec.append(ultimate_loss)
+            convergence_rate_vec.append(convergence_rate)
+            convergence_iter_vec.append(convergence_iter)
+        return (
+            torch.Tensor(ultimate_loss_vec),
+            torch.Tensor(convergence_rate_vec),
+            convergence_iter_vec,
+        )
+
+    (
+        centered_ultimate_loss_vec,
+        centered_convergence_rate_vec,
+        centered_convergence_iter_vec,
+    ) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
+    ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
+        lr_vec=lr_vec, centered_variance=False
+    )
+
+    # All centered variance results should converge
+    assert (centered_ultimate_loss_vec < 0.01).all()
+    # Some uncentered variance results do not converge
+    assert (ultimate_loss_vec > 0.01).any()
+    # Verify convergence rate improvement
+    assert (
+        (centered_convergence_rate_vec / convergence_rate_vec)
+        > ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
+    ).all()
+
+    if plot:
+        from matplotlib import pyplot as plt
+
+        plt.figure(figsize=(6, 8))
+        plt.subplot(3, 1, 1)
+        plt.loglog(
+            lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Iteration")
+        plt.title("Convergence Iteration vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 2)
+        plt.loglog(
+            lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
+        )
+        plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Convergence Rate")
+        plt.title("Convergence Rate vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.subplot(3, 1, 3)
+        plt.semilogx(
+            lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
+        )
+        plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
+        plt.xlabel("Learning Rate")
+        plt.ylabel("Ultimate Loss")
+        plt.title("Ultimate Loss vs Learning Rate")
+        plt.grid()
+        plt.legend(loc="best")
+        plt.tight_layout()
+        plt.savefig("test_centered_variance.png")
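
Usage note (illustrative sketch, not part of the patch): the snippet below shows one way the new centered_variance flag could be used end to end with Pyro's SVI, mirroring the examples/lda.py change above. The toy model, guide, and parameter names (mu, loc, scale) are assumptions made for this example only.

    # Illustrative sketch; the model and guide are toy assumptions, not PR code.
    import torch

    import pyro
    import pyro.distributions as dist
    from pyro.infer import SVI, Trace_ELBO
    from pyro.optim import ClippedAdam

    data = torch.randn(100) + 3.0  # synthetic observations

    def model(data):
        mu = pyro.sample("mu", dist.Normal(0.0, 10.0))
        with pyro.plate("data", len(data)):
            pyro.sample("obs", dist.Normal(mu, 1.0), obs=data)

    def guide(data):
        loc = pyro.param("loc", torch.tensor(0.0))
        scale = pyro.param(
            "scale", torch.tensor(1.0), constraint=dist.constraints.positive
        )
        pyro.sample("mu", dist.Normal(loc, scale))

    # With centered_variance=True the second-moment estimate is built from
    # (grad - exp_avg) ** 2 instead of grad ** 2, per the grad_var line added above.
    optim = ClippedAdam({"lr": 0.01, "clip_norm": 10.0, "centered_variance": True})
    svi = SVI(model, guide, optim, Trace_ELBO())
    for step in range(1000):
        svi.step(data)

The raw torch optimizer can also be driven directly on a list of tensors, as the new test does with pyro.optim.clipped_adam.ClippedAdam(params=[p], lr=lr, centered_variance=True).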