From b54af4a9abc06f9b7015c9e50fbf3973b51b014f Mon Sep 17 00:00:00 2001 From: Du Phan Date: Fri, 12 Apr 2024 14:18:47 -0400 Subject: [PATCH] Introduce the `chain` argument to propose (#29) * support chain argument in propose * ping oryx to a dev version * new attempt to ping oryx version --- coix/api.py | 9 ++++++--- examples/bmnist.py | 2 +- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/coix/api.py b/coix/api.py index f6f55f5..2d09a4b 100644 --- a/coix/api.py +++ b/coix/api.py @@ -117,7 +117,7 @@ def _fold_in_key(key, i): return key_new.reshape(key.shape) -def propose(p, q, *, loss_fn=None, detach=False): +def propose(p, q, *, loss_fn=None, detach=False, chain=False): """Returns a new program with important weight. We assume the leftmost batch dimension is the particle dimension. You can add @@ -132,19 +132,22 @@ def propose(p, q, *, loss_fn=None, detach=False): q: a proposal program loss_fn: a function that computes loss of this propose combinator detach: whether to detach `value` of the returned program + chain: if True, we will use output of `q` as input of `p` Returns: q_new: the proposed program """ def wrapped(*args, **kwargs): - if util.can_extract_key(args): + if util.can_extract_key(args) and not chain: key_p, key_q = _split_key(args[0]) p_args = (key_p,) + args[1:] q_args = (key_q,) + args[1:] else: p_args = q_args = args - _, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs) + q_out, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs) + if chain: + p_args = q_out metrics = q_metrics.copy() q_latents = { name: util.get_site_value(site) diff --git a/examples/bmnist.py b/examples/bmnist.py index 520721d..d964b0a 100644 --- a/examples/bmnist.py +++ b/examples/bmnist.py @@ -285,7 +285,7 @@ def kernel_what(network, inputs, T=10): # %% -# Finally, we create the dmm inference program, define the loss function, +# Finally, we create the bmnist inference program, define the loss function, # run the training loop, and plot the results. diff --git a/pyproject.toml b/pyproject.toml index 58ef935..ebd65d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -57,7 +57,7 @@ doc = [ "sphinx-gallery", ] oryx = [ - "oryx@git+https://github.com/jax-ml/oryx", + "oryx@git+https://github.com/jax-ml/oryx.git@b59ab020780cd53d488bc7dcad3696be9fdca0a5", ] [tool.pyink]