diff --git a/bayeux/_src/optimize/optax.py b/bayeux/_src/optimize/optax.py index 3d7d452..4ccfdc4 100644 --- a/bayeux/_src/optimize/optax.py +++ b/bayeux/_src/optimize/optax.py @@ -18,14 +18,16 @@ from bayeux._src.optimize import shared import jax import optax +import optax.contrib class _OptaxOptimizer(shared.Optimizer): """Base class for optax optimizers.""" + _base = optax def get_kwargs(self, **kwargs): kwargs = self.default_kwargs() | kwargs - optimizer = getattr(optax, self.optimizer) + optimizer = getattr(self._base, self.optimizer) return {optimizer: shared.get_optimizer_kwargs(optimizer, kwargs), "extra_parameters": shared.get_extra_kwargs(kwargs)} @@ -33,7 +35,7 @@ def __call__(self, seed, **kwargs): kwargs = self.get_kwargs(**kwargs) fun, initial_state, apply_transform = self._prep_args(seed, kwargs) - optimizer_fn = getattr(optax, self.optimizer) + optimizer_fn = getattr(self._base, self.optimizer) optimizer = optimizer_fn(**kwargs[optimizer_fn]) num_iters = kwargs["extra_parameters"]["num_iters"] optimizer = functools.partial( @@ -183,6 +185,19 @@ def default_kwargs(self) -> dict[str, float]: return kwargs +class ScheduleFree(_OptaxOptimizer): + _base = optax.contrib + name = "optax_schedule_free" + optimizer = "schedule_free" + + def default_kwargs(self) -> dict[str, float]: + kwargs = super().default_kwargs() + base_optimizer = optax.adam( + **shared.get_optimizer_kwargs(optax.adam, kwargs)) + kwargs["base_optimizer"] = base_optimizer + return kwargs + + class Sgd(_OptaxOptimizer): name = "optax_sgd" optimizer = "sgd" diff --git a/bayeux/optimize/__init__.py b/bayeux/optimize/__init__.py index df1fb5b..890b370 100644 --- a/bayeux/optimize/__init__.py +++ b/bayeux/optimize/__init__.py @@ -44,6 +44,7 @@ # from bayeux._src.optimize.optax import OptimisticGradientDescent # pylint: disable=line-too-long from bayeux._src.optimize.optax import Radam from bayeux._src.optimize.optax import Rmsprop + from bayeux._src.optimize.optax import ScheduleFree from bayeux._src.optimize.optax import Sgd from bayeux._src.optimize.optax import Sm3 from bayeux._src.optimize.optax import Yogi @@ -66,6 +67,7 @@ # "Dpsgd", "Radam", "Rmsprop", + "ScheduleFree", "Sgd", "Sm3", "Yogi",