Skip to content

Commit

Permalink
Introduce the chain argument to propose (#29)
Browse files Browse the repository at this point in the history
* support chain argument in propose

* ping oryx to a dev version

* new attempt to ping oryx version
  • Loading branch information
fehiepsi authored Apr 12, 2024
1 parent f00cbfd commit b54af4a
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 5 deletions.
9 changes: 6 additions & 3 deletions coix/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ def _fold_in_key(key, i):
return key_new.reshape(key.shape)


def propose(p, q, *, loss_fn=None, detach=False):
def propose(p, q, *, loss_fn=None, detach=False, chain=False):
"""Returns a new program with important weight.
We assume the leftmost batch dimension is the particle dimension. You can add
Expand All @@ -132,19 +132,22 @@ def propose(p, q, *, loss_fn=None, detach=False):
q: a proposal program
loss_fn: a function that computes loss of this propose combinator
detach: whether to detach `value` of the returned program
chain: if True, we will use output of `q` as input of `p`
Returns:
q_new: the proposed program
"""

def wrapped(*args, **kwargs):
if util.can_extract_key(args):
if util.can_extract_key(args) and not chain:
key_p, key_q = _split_key(args[0])
p_args = (key_p,) + args[1:]
q_args = (key_q,) + args[1:]
else:
p_args = q_args = args
_, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs)
q_out, q_trace, q_metrics = core.traced_evaluate(q)(*q_args, **kwargs)
if chain:
p_args = q_out
metrics = q_metrics.copy()
q_latents = {
name: util.get_site_value(site)
Expand Down
2 changes: 1 addition & 1 deletion examples/bmnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,7 @@ def kernel_what(network, inputs, T=10):


# %%
# Finally, we create the dmm inference program, define the loss function,
# Finally, we create the bmnist inference program, define the loss function,
# run the training loop, and plot the results.


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ doc = [
"sphinx-gallery",
]
oryx = [
"oryx@git+https://github.com/jax-ml/oryx",
"oryx@git+https://github.com/jax-ml/oryx.git@b59ab020780cd53d488bc7dcad3696be9fdca0a5",
]

[tool.pyink]
Expand Down

0 comments on commit b54af4a

Please sign in to comment.