
Add optimistix optimizers.
PiperOrigin-RevId: 595975938
ColCarroll authored and The bayeux Authors committed Jan 5, 2024
1 parent eb6be2c commit 1cf2a7c
Showing 4 changed files with 150 additions and 3 deletions.
110 changes: 110 additions & 0 deletions bayeux/_src/optimize/optimistix.py
@@ -0,0 +1,110 @@
# Copyright 2023 The bayeux Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""optimistix specific code."""
from bayeux._src.optimize import shared
import optimistix


class _OptimistixOptimizer(shared.Optimizer):
  """Base class for optimistix optimizers."""

  def get_kwargs(self, **kwargs):
    kwargs = self.default_kwargs() | kwargs
    solver = getattr(optimistix, self.optimizer)
    minimise_kwargs = shared.get_optimizer_kwargs(
        optimistix.minimise, kwargs, ignore_required={"y0", "solver", "fn"})
    for k in minimise_kwargs:
      if k in kwargs:
        minimise_kwargs[k] = kwargs[k]
    extra_parameters = shared.get_extra_kwargs(kwargs)
    _ = extra_parameters.pop("num_iters")
    return {solver: shared.get_optimizer_kwargs(solver, kwargs),
            optimistix.minimise: minimise_kwargs,
            "extra_parameters": extra_parameters}

  def default_kwargs(self) -> dict[str, float]:
    return {"rtol": 1e-5, "atol": 1e-5}

  def _prep_args(self, seed, kwargs):
    fun, initial_state, apply_transform = super()._prep_args(seed, kwargs)
    def f(x, _):
      return fun(x)
    return f, initial_state, apply_transform

  def __call__(self, seed, **kwargs):
    kwargs = self.get_kwargs(**kwargs)
    fun, initial_state, apply_transform = self._prep_args(seed, kwargs)

    solver_fn = getattr(optimistix, self.optimizer)
    def run(x0):
      solver = solver_fn(**kwargs[solver_fn])
      return optimistix.minimise(
          fn=fun,
          solver=solver,
          y0=x0,
          **kwargs[optimistix.minimise]).value
    chain_method = kwargs["extra_parameters"]["chain_method"]
    mapped_run = self._map_optimizer(chain_method, run)
    ret = mapped_run(initial_state)
    if apply_transform:
      return shared.OptimizerResults(
          params=self.transform_fn(ret), state=None, loss=None)
    else:
      return shared.OptimizerResults(ret, state=None, loss=None)


class BFGS(_OptimistixOptimizer):
  name = "optimistix_bfgs"
  optimizer = "BFGS"


class Chord(_OptimistixOptimizer):
  name = "optimistix_chord"
  optimizer = "Chord"


class Dogleg(_OptimistixOptimizer):
  name = "optimistix_dogleg"
  optimizer = "Dogleg"


class GaussNewton(_OptimistixOptimizer):
  name = "optimistix_gauss_newton"
  optimizer = "GaussNewton"


class IndirectLevenbergMarquardt(_OptimistixOptimizer):
  name = "optimistix_indirect_levenberg_marquardt"
  optimizer = "IndirectLevenbergMarquardt"


class LevenbergMarquardt(_OptimistixOptimizer):
  name = "optimistix_levenberg_marquardt"
  optimizer = "LevenbergMarquardt"


class NelderMead(_OptimistixOptimizer):
  name = "optimistix_nelder_mead"
  optimizer = "NelderMead"


class Newton(_OptimistixOptimizer):
  name = "optimistix_newton"
  optimizer = "Newton"


class NonlinearCG(_OptimistixOptimizer):
  name = "optimistix_nonlinear_cg"
  optimizer = "NonlinearCG"
22 changes: 22 additions & 0 deletions bayeux/optimize/__init__.py
@@ -26,6 +26,28 @@
  from bayeux._src.optimize.jaxopt import NonlinearCG
  __all__.extend(["BFGS", "GradientDescent", "LBFGS", "NonlinearCG"])

if importlib.util.find_spec("optimistix") is not None:
  from bayeux._src.optimize.optimistix import BFGS as optimistix_BFGS
  from bayeux._src.optimize.optimistix import Chord
  from bayeux._src.optimize.optimistix import Dogleg
  from bayeux._src.optimize.optimistix import GaussNewton
  from bayeux._src.optimize.optimistix import IndirectLevenbergMarquardt
  from bayeux._src.optimize.optimistix import LevenbergMarquardt
  from bayeux._src.optimize.optimistix import NelderMead
  from bayeux._src.optimize.optimistix import Newton
  from bayeux._src.optimize.optimistix import NonlinearCG as optimistix_NonlinearCG

  __all__.extend([
      "optimistix_BFGS",
      "Chord",
      "Dogleg",
      "GaussNewton",
      "IndirectLevenbergMarquardt",
      "LevenbergMarquardt",
      "NelderMead",
      "Newton",
      "optimistix_NonlinearCG"])

if importlib.util.find_spec("optax") is not None:
  from bayeux._src.optimize.optax import AdaBelief
  from bayeux._src.optimize.optax import Adafactor
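
As with the other optional backends in this file, the optimistix classes are only registered when the package can be imported. A minimal sketch of checking that gating at runtime, assuming only that bayeux.optimize re-exports the names added above:

import importlib.util

from bayeux import optimize

# Mirrors the find_spec guard above: the aliased name is present only when
# optimistix is installed.
if importlib.util.find_spec("optimistix") is not None:
  assert "optimistix_BFGS" in optimize.__all__
  assert "LevenbergMarquardt" in optimize.__all__
else:
  assert "optimistix_BFGS" not in optimize.__all__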
20 changes: 17 additions & 3 deletions bayeux/tests/optimize_test.py
@@ -67,14 +67,28 @@ def test_optimizers(method, linear_model):  # pylint: disable=redefined-outer-name
  else:
    num_iters = 1_000

+  if method.startswith("optimistix"):
+    num_iters = 10_000  # should stop automatically before then
+    atol = 0.2
+  else:
+    atol = 1e-2
+
  assert optimizer.debug(seed=seed, verbosity=0)
  num_particles = 6
  params = optimizer(
-      seed=seed, num_particles=num_particles, num_iters=num_iters).params
+      seed=seed,
+      num_particles=num_particles,
+      num_iters=num_iters,
+      atol=atol,
+      max_steps=num_iters,
+      throw=False).params
  expected = np.repeat(solution[..., np.newaxis], num_particles, axis=-1).T

-  if method != "optax_adafactor":
-    np.testing.assert_allclose(expected, params.w, atol=1e-2)
+  if method not in {
+      "optax_adafactor",
+      "optimistix_chord",
+      "optimistix_nelder_mead"}:
+    np.testing.assert_allclose(expected, params.w, atol=atol)


def test_initial_state():
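
For context on the looser test settings above (max_steps=10_000, atol=0.2, throw=False), here is a bare optimistix call with the same default tolerances the wrapper uses. This is a sketch relying only on optimistix's public BFGS and minimise API, independent of bayeux:

import jax.numpy as jnp
import optimistix

def fn(y, args):
  # optimistix objectives take (y, args); this mirrors the f(x, _) wrapper
  # added in _prep_args above.
  return jnp.sum(jnp.square(y - 3.0))

solver = optimistix.BFGS(rtol=1e-5, atol=1e-5)
sol = optimistix.minimise(fn, solver, y0=jnp.zeros(2),
                          max_steps=10_000, throw=False)
print(sol.value)  # expected to be close to [3., 3.]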
1 change: 1 addition & 0 deletions pyproject.toml
@@ -19,6 +19,7 @@ dependencies = [
"oryx>=0.2.5",
"arviz",
"optax",
"optimistix",
"blackjax",
"numpyro",
"jaxopt",
