Skip to content

Commit

Permalink
Avoid depending on JAX internals (#42)
Browse files Browse the repository at this point in the history
* Avoid depending on JAX internals

* ignore too many arguments lint

* using pylint instead pyproject to disable the warning

* remove deprecated docs theme conf

* test with oryx dev

* bump ci to 3.11

* another try to install oryx

* fix oryx test
  • Loading branch information
fehiepsi authored Oct 24, 2024
1 parent c4a4720 commit 0d49417
Show file tree
Hide file tree
Showing 5 changed files with 6 additions and 8 deletions.
9 changes: 4 additions & 5 deletions .github/workflows/pytest_and_autopublish.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ jobs:
# Install deps
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'

- run: sudo apt install -y pandoc gsfonts
- run: pip --version
Expand All @@ -41,14 +41,13 @@ jobs:
# Install deps
- uses: actions/setup-python@v4
with:
python-version: '3.10'
python-version: '3.11'
# Uncomment to cache of pip dependencies (if tests too slow)
# cache: pip
# cache-dependency-path: '**/pyproject.toml'

- run:
pip install -e .[dev]
pip install "git+https://github.com/jax-ml/oryx.git@b59ab020780cd53d488bc7dcad3696be9fdca0a5"
- run: pip install -e .[dev]
- run: pip install "git+https://github.com/jax-ml/oryx.git"

# Run tests (in parallel)
- name: Run core tests
Expand Down
1 change: 1 addition & 0 deletions .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ disable=abstract-method,
too-many-instance-attributes,
too-many-locals,
too-many-nested-blocks,
too-many-positional-arguments,
too-many-public-methods,
too-many-return-statements,
too-many-statements,
Expand Down
1 change: 0 additions & 1 deletion coix/oryx.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,6 @@ def distribution_rule(state, *args, **kwargs):
_, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args)
dist_flat, dist_tree = jax.tree.flatten(dist)
state[name] = {dist_tree: dist_flat}
args = jax.tree.map(jax.core.raise_as_much_as_possible, args)
return random_variable_p.bind(*args, **kwargs), state


Expand Down
2 changes: 1 addition & 1 deletion coix/oryx_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def expected_fn(x):

def test_observed():
def model(a):
return coryx.rv(dist.Delta(2.0, 3.0), obs=1.0, name="x") + a
return coryx.rv(dist.Delta(a, 3.0), obs=1.0, name="x") + a

_, trace, _ = coix.traced_evaluate(model)(2.0)
assert "x" in trace
Expand Down
1 change: 0 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,7 +234,6 @@
# a list of builtin themes.
#
html_theme = "sphinx_rtd_theme"
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]

# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
Expand Down

0 comments on commit 0d49417

Please sign in to comment.