diff --git a/.github/workflows/pytest_and_autopublish.yml b/.github/workflows/pytest_and_autopublish.yml index 986a296..cd4c7ee 100644 --- a/.github/workflows/pytest_and_autopublish.yml +++ b/.github/workflows/pytest_and_autopublish.yml @@ -14,7 +14,7 @@ jobs: # Install deps - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' - run: sudo apt install -y pandoc gsfonts - run: pip --version @@ -41,14 +41,13 @@ jobs: # Install deps - uses: actions/setup-python@v4 with: - python-version: '3.10' + python-version: '3.11' # Uncomment to cache of pip dependencies (if tests too slow) # cache: pip # cache-dependency-path: '**/pyproject.toml' - - run: - pip install -e .[dev] - pip install "git+https://github.com/jax-ml/oryx.git@b59ab020780cd53d488bc7dcad3696be9fdca0a5" + - run: pip install -e .[dev] + - run: pip install "git+https://github.com/jax-ml/oryx.git" # Run tests (in parallel) - name: Run core tests diff --git a/.pylintrc b/.pylintrc index bac9a2e..2ce4dd1 100644 --- a/.pylintrc +++ b/.pylintrc @@ -137,6 +137,7 @@ disable=abstract-method, too-many-instance-attributes, too-many-locals, too-many-nested-blocks, + too-many-positional-arguments, too-many-public-methods, too-many-return-statements, too-many-statements, diff --git a/coix/oryx.py b/coix/oryx.py index 9636907..ba622bb 100644 --- a/coix/oryx.py +++ b/coix/oryx.py @@ -250,7 +250,6 @@ def distribution_rule(state, *args, **kwargs): _, dist = jax.tree.unflatten(kwargs["in_tree"], flat_args) dist_flat, dist_tree = jax.tree.flatten(dist) state[name] = {dist_tree: dist_flat} - args = jax.tree.map(jax.core.raise_as_much_as_possible, args) return random_variable_p.bind(*args, **kwargs), state diff --git a/coix/oryx_test.py b/coix/oryx_test.py index 65a490b..6df092e 100644 --- a/coix/oryx_test.py +++ b/coix/oryx_test.py @@ -113,7 +113,7 @@ def expected_fn(x): def test_observed(): def model(a): - return coryx.rv(dist.Delta(2.0, 3.0), obs=1.0, name="x") + a + return coryx.rv(dist.Delta(a, 3.0), obs=1.0, name="x") + a _, trace, _ = coix.traced_evaluate(model)(2.0) assert "x" in trace diff --git a/docs/conf.py b/docs/conf.py index ee9077b..15ebae7 100644 --- a/docs/conf.py +++ b/docs/conf.py @@ -234,7 +234,6 @@ # a list of builtin themes. # html_theme = "sphinx_rtd_theme" -html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the