Skip to content

Commit

Permalink
Update oryx version to 0.2.7 and only expose optax_schedule_free wh…
Browse files Browse the repository at this point in the history
…en it is available.

Fixes #47

PiperOrigin-RevId: 636641008
  • Loading branch information
ColCarroll authored and The bayeux Authors committed May 23, 2024
1 parent cb27471 commit 21f903d
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
7 changes: 4 additions & 3 deletions bayeux/optimize/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"""Imports from submodules."""
# pylint: disable=g-importing-member
# pylint: disable=g-import-not-at-top
import importlib
import importlib.util

__all__ = []

Expand Down Expand Up @@ -44,7 +44,9 @@
# from bayeux._src.optimize.optax import OptimisticGradientDescent # pylint: disable=line-too-long
from bayeux._src.optimize.optax import Radam
from bayeux._src.optimize.optax import Rmsprop
from bayeux._src.optimize.optax import ScheduleFree
if importlib.util.find_spec("optax.contrib._schedule_free") is not None:
from bayeux._src.optimize.optax import ScheduleFree
__all__.append("ScheduleFree")
from bayeux._src.optimize.optax import Sgd
from bayeux._src.optimize.optax import Sm3
from bayeux._src.optimize.optax import Yogi
Expand All @@ -67,7 +69,6 @@
# "Dpsgd",
"Radam",
"Rmsprop",
"ScheduleFree",
"Sgd",
"Sm3",
"Yogi",
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ keywords = []
dependencies = [
"jax>=0.4.6",
"tensorflow-probability[jax]>=0.19.0",
"oryx>=0.2.5",
"oryx>=0.2.7",
"arviz",
"optax",
"optimistix",
Expand Down

0 comments on commit 21f903d

Please sign in to comment.