Add a centered variance option to the ClippedAdam optimizer (#3415)
* Add option to use centered variance in the ClippedAdam optimizer.

* Add test for the centered ClippedAdam optimizer.

* Calculate convergence iteration for the centered ClippedAdam optimizer.

* Added a reference for the centered Adam optimizer.

* Add option to use the ClippedAdam optimizer with centered variance in the Latent Dirichlet Allocation example.

* Added more detailed comments on ClippedAdam with centered variance and its tests.

* Shortened the ClippedAdam centered variance test and added an option to run the full test with plots via a pytest command line option.
BenZickel authored Jan 25, 2025
1 parent 86277f3 commit 66f2a2f
Showing 4 changed files with 152 additions and 6 deletions.
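
For orientation, here is a minimal usage sketch of the new option (not part of the commit), following the calling pattern introduced in examples/lda.py below; the learning rate value is illustrative, and the optimizer is the pyro.optim.ClippedAdam wrapper that takes an argument dict.

from pyro.optim import ClippedAdam

# Enable the centered second-moment estimate; all other ClippedAdam
# arguments (betas, eps, weight_decay, clip_norm, lrd) keep their defaults.
optim = ClippedAdam({"lr": 0.01, "centered_variance": True})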
5 changes: 4 additions & 1 deletion examples/lda.py
@@ -137,7 +137,9 @@ def main(args):
guide = functools.partial(parametrized_guide, predictor)
Elbo = JitTraceEnum_ELBO if args.jit else TraceEnum_ELBO
elbo = Elbo(max_plate_nesting=2)
- optim = ClippedAdam({"lr": args.learning_rate})
+ optim = ClippedAdam(
+     {"lr": args.learning_rate, "centered_variance": args.centered_variance}
+ )
svi = SVI(model, guide, optim, elbo)
logging.info("Step\tLoss")
for step in range(args.num_steps):
@@ -160,6 +162,7 @@ def main(args):
parser.add_argument("-n", "--num-steps", default=1000, type=int)
parser.add_argument("-l", "--layer-sizes", default="100-100")
parser.add_argument("-lr", "--learning-rate", default=0.01, type=float)
parser.add_argument("-cv", "--centered-variance", default=False, type=bool)
parser.add_argument("-b", "--batch-size", default=32, type=int)
parser.add_argument("--jit", action="store_true")
args = parser.parse_args()
20 changes: 15 additions & 5 deletions pyro/optim/clipped_adam.py
@@ -19,14 +19,21 @@ class ClippedAdam(Optimizer):
:param weight_decay: weight decay (L2 penalty) (default: 0)
:param clip_norm: magnitude of norm to which gradients are clipped (default: 10.0)
:param lrd: rate at which learning rate decays (default: 1.0)
+ :param centered_variance: use centered variance (default: False)

Small modification to the Adam algorithm implemented in torch.optim.Adam
- to include gradient clipping and learning rate decay.
+ to include gradient clipping, learning rate decay, and an option to use
+ the centered variance (see equation 2 in [2]).

- Reference
+ **References**

- `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
- https://arxiv.org/abs/1412.6980
+ [1] `A Method for Stochastic Optimization`, Diederik P. Kingma, Jimmy Ba
+     https://arxiv.org/abs/1412.6980
+ [2] `A Two-Step Machine Learning Method for Predicting the Formation Energy of Ternary Compounds`,
+     Varadarajan Rengaraj, Sebastian Jost, Franz Bethke, Christian Plessl,
+     Hossein Mirhosseini, Andrea Walther, Thomas D. Kühne
+     https://doi.org/10.3390/computation11050095
"""

def __init__(
@@ -38,6 +45,7 @@ def __init__(
weight_decay=0,
clip_norm: float = 10.0,
lrd: float = 1.0,
+ centered_variance: bool = False,
):
defaults = dict(
lr=lr,
@@ -46,6 +54,7 @@ def __init__(
weight_decay=weight_decay,
clip_norm=clip_norm,
lrd=lrd,
+ centered_variance=centered_variance,
)
super().__init__(params, defaults)

@@ -87,7 +96,8 @@ def step(self, closure: Optional[Callable] = None) -> Optional[Any]:

# Decay the first and second moment running average coefficient
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)
- exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)
+ grad_var = (grad - exp_avg) if group["centered_variance"] else grad
+ exp_avg_sq.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

denom = exp_avg_sq.sqrt().add_(group["eps"])

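
To make the change above concrete, here is a small standalone sketch (not part of the commit) contrasting the uncentered and centered second-moment updates; the gradient values and hyperparameters are made up for illustration, and the tensor method calls mirror the ones in the hunk above.

import torch

beta1, beta2, eps = 0.9, 0.999, 1e-8
grad = torch.tensor([0.5, -0.2])      # illustrative gradient
exp_avg = torch.zeros(2)              # first moment estimate (Adam's m)
exp_avg_sq = torch.zeros(2)           # uncentered second moment (Adam's v)
exp_avg_sq_centered = torch.zeros(2)  # centered second moment

# First moment update is unchanged.
exp_avg.mul_(beta1).add_(grad, alpha=1 - beta1)

# Default behaviour: accumulate a running average of g * g.
exp_avg_sq.mul_(beta2).addcmul_(grad, grad, value=1 - beta2)

# centered_variance=True: accumulate (g - m) * (g - m), the variance
# around the first moment, as in equation 2 of reference [2].
grad_var = grad - exp_avg
exp_avg_sq_centered.mul_(beta2).addcmul_(grad_var, grad_var, value=1 - beta2)

# Either estimate then feeds the denominator exp_avg_sq.sqrt().add_(eps).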
10 changes: 10 additions & 0 deletions tests/optim/conftest.py
@@ -11,3 +11,13 @@ def pytest_collection_modifyitems(items):
item.add_marker(pytest.mark.stage("unit"))
if "init" not in item.keywords:
item.add_marker(pytest.mark.init(rng_seed=123))


def pytest_addoption(parser):
parser.addoption("--plot", action="store", default="FALSE")


def pytest_generate_tests(metafunc):
option_value = metafunc.config.option.plot != "FALSE"
if "plot" in metafunc.fixturenames and option_value is not None:
metafunc.parametrize("plot", [option_value])
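
A note on how this hook feeds the test added below: pytest_addoption registers a --plot option (default "FALSE"), and pytest_generate_tests parametrizes any test that declares a plot argument with True whenever a different value is passed on the command line. A minimal sketch of a consuming test (the test name here is illustrative, not from the commit):

# Sketch only: any test with a `plot` parameter is parametrized from --plot.
def test_something(plot):
    if plot:
        pass  # generate diagnostic figures only when `--plot True` was given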
123 changes: 123 additions & 0 deletions tests/optim/test_optim.py
@@ -435,3 +435,126 @@ def step(svi, optimizer):
actual.append(step(svi, optimizer))

assert_equal(actual, expected)


def test_centered_clipped_adam(plot):
"""
Test the centered variance option of the ClippedAdam optimizer.
    In order to create plots, run pytest with the plot command line
    option set to True, i.e. by executing
    'pytest tests/optim/test_optim.py::test_centered_clipped_adam --plot True'
"""
if not plot:
lr_vec = [0.1, 0.001]
else:
lr_vec = [0.1, 0.05, 0.02, 0.01, 0.005, 0.002, 0.001]

w = torch.Tensor([1, 500])

def loss_fn(p):
return (1 + w * p * p).sqrt().sum() - len(w)

def fit(lr, centered_variance, num_iter=5000):
loss_vec = []
p = torch.nn.Parameter(torch.Tensor([10, 1]))
optim = pyro.optim.clipped_adam.ClippedAdam(
lr=lr, params=[p], centered_variance=centered_variance
)
for count in range(num_iter):
optim.zero_grad()
loss = loss_fn(p)
loss.backward()
optim.step()
loss_vec.append(loss)
return torch.Tensor(loss_vec)

def calc_convergence(loss_vec, tail_len=100, threshold=0.01):
"""
Calculate the number of iterations needed in order to reach the
ultimate loss plus a small threshold, and the convergence rate
which is the mean per iteration improvement of the gap between
the loss and the ultimate loss.
"""
ultimate_loss = loss_vec[-tail_len:].mean()
convergence_iter = (loss_vec < (ultimate_loss + threshold)).nonzero().min()
convergence_vec = loss_vec[:convergence_iter] - ultimate_loss
convergence_rate = (convergence_vec[:-1] / convergence_vec[1:]).log().mean()
return ultimate_loss, convergence_rate, convergence_iter

def get_convergence_vec(lr_vec, centered_variance):
"""
Fit parameters for a vector of learning rates, with or without centered variance,
and calculate the convergence properties for each learning rate.
"""
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = [], [], []
for lr in lr_vec:
loss_vec = fit(lr=lr, centered_variance=centered_variance)
ultimate_loss, convergence_rate, convergence_iter = calc_convergence(
loss_vec
)
ultimate_loss_vec.append(ultimate_loss)
convergence_rate_vec.append(convergence_rate)
convergence_iter_vec.append(convergence_iter)
return (
torch.Tensor(ultimate_loss_vec),
torch.Tensor(convergence_rate_vec),
convergence_iter_vec,
)

(
centered_ultimate_loss_vec,
centered_convergence_rate_vec,
centered_convergence_iter_vec,
) = get_convergence_vec(lr_vec=lr_vec, centered_variance=True)
ultimate_loss_vec, convergence_rate_vec, convergence_iter_vec = get_convergence_vec(
lr_vec=lr_vec, centered_variance=False
)

    # All centered variance results should converge
assert (centered_ultimate_loss_vec < 0.01).all()
# Some uncentered variance results do not converge
assert (ultimate_loss_vec > 0.01).any()
# Verify convergence rate improvement
assert (
(centered_convergence_rate_vec / convergence_rate_vec)
> ((0.12 / torch.Tensor(lr_vec)).log() * 1.08)
).all()

if plot:
from matplotlib import pyplot as plt

plt.figure(figsize=(6, 8))
plt.subplot(3, 1, 1)
plt.loglog(
lr_vec, centered_convergence_iter_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_iter_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Iteration")
plt.title("Convergence Iteration vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 2)
plt.loglog(
lr_vec, centered_convergence_rate_vec, "b.-", label="Centered Variance"
)
plt.loglog(lr_vec, convergence_rate_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Convergence Rate")
plt.title("Convergence Rate vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.subplot(3, 1, 3)
plt.semilogx(
lr_vec, centered_ultimate_loss_vec, "b.-", label="Centered Variance"
)
plt.semilogx(lr_vec, ultimate_loss_vec, "r.-", label="Uncentered Variance")
plt.xlabel("Learning Rate")
plt.ylabel("Ultimate Loss")
plt.title("Ultimate Loss vs Learning Rate")
plt.grid()
plt.legend(loc="best")
plt.tight_layout()
plt.savefig("test_centered_variance.png")
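
As a side note on the thresholds in the assertions above, the toy objective used by the test attains its minimum of exactly 0 at p = 0, which is why an ultimate loss below 0.01 is treated as convergence; a quick standalone check (values copied from the test, not part of the commit):

import torch

w = torch.Tensor([1, 500])
p = torch.zeros(2)
# Each term sqrt(1 + w_i * p_i^2) is at least 1, so the summed loss minus
# len(w) has minimum 0, attained at p = 0.
assert float((1 + w * p * p).sqrt().sum() - len(w)) == 0.0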
