From 171b0b4d2eecbfe4686b22c24ea81aacfbf0161e Mon Sep 17 00:00:00 2001
From: Du Phan
Date: Fri, 12 Apr 2024 14:26:28 -0700
Subject: [PATCH] Code update

PiperOrigin-RevId: 624292817
---
 coix/__init__.py  |   2 +-
 coix/algo.py      | 128 +++++++++++++++++++++++-----------------------
 coix/api.py       |   7 ++-
 coix/core.py      |   2 -
 coix/loss.py      |  44 ++++++++--------
 coix/loss_test.py |   2 +-
 coix/util.py      |  26 ++++++++++
 7 files changed, 118 insertions(+), 93 deletions(-)

diff --git a/coix/__init__.py b/coix/__init__.py
index f52e881..7188797 100644
--- a/coix/__init__.py
+++ b/coix/__init__.py
@@ -55,4 +55,4 @@
 
 # A new PyPI release will be pushed every time `__version__` is increased
 # When changing this, also update the CHANGELOG.md
-__version__ = "0.1.0"
+__version__ = "0.0.1"
diff --git a/coix/algo.py b/coix/algo.py
index 2da2eb2..f600fbe 100644
--- a/coix/algo.py
+++ b/coix/algo.py
@@ -65,9 +65,9 @@ def _use_fori_loop(targets, num_targets, *fns):
 def aft(targets, flows, *, num_targets=None):
   """Annealed Flow Transport.
 
-  1. *Annealed Flow Transport Monte Carlo*,
-     Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
-     https://arxiv.org/abs/2102.07501
+  [1] Annealed Flow Transport Monte Carlo,
+      Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
+      https://arxiv.org/abs/2102.07501
 
   Args:
     targets: a list of target programs
@@ -94,10 +94,10 @@ def body_fun(i, q):
 def apgs(target, kernels, *, num_sweeps=1):
   """Amortized Population Gibbs Sampler.
 
-  1. *Amortized Population Gibbs Samplers with Neural Sufficient Statistics*,
-     Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
-     Meent
-     https://arxiv.org/abs/1911.01382
+  [1] Amortized Population Gibbs Samplers with Neural Sufficient Statistics,
+      Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
+      Meent
+      https://arxiv.org/abs/1911.01382
 
   Args:
     target: the target program
@@ -123,13 +123,13 @@ def body_fn(_, q):
 def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
   """Differentiable Annealed Importance Sampling.
 
-  1. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*,
-     Tomas Geffner, Justin Domke
-     https://arxiv.org/abs/2107.04150
-  2. *Differentiable Annealed Importance Sampling and the Perils of Gradient
-     Noise*,
-     Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
-     https://arxiv.org/abs/2107.10211
+  [1] MCMC Variational Inference via Uncorrected Hamiltonian Annealing,
+      Tomas Geffner, Justin Domke
+      https://arxiv.org/abs/2107.04150
+  [2] Differentiable Annealed Importance Sampling and the Perils of Gradient
+      Noise,
+      Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
+      https://arxiv.org/abs/2107.10211
 
   Args:
     targets: a list of target programs
@@ -166,9 +166,9 @@ def body_fun(i, q):
 def nasmc(targets, proposals, *, num_targets=None):
   """Neural Adaptive Sequential Monte Carlo.
 
-  1. *Neural Adaptive Sequential Monte Carlo*,
-     Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
-     https://arxiv.org/abs/1506.03338
+  [1] Neural Adaptive Sequential Monte Carlo,
+      Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
+      https://arxiv.org/abs/1506.03338
 
   Args:
     targets: a list of target programs
@@ -196,10 +196,10 @@ def body_fun(i, q):
 def nvi_avo(targets, forwards, reverses, *, num_targets=None):
   """AIS with Annealed Variational Objective.
 
-  1. *Improving Explorability in Variational Inference with Annealed Variational
-     Objectives*,
-     Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
-     https://arxiv.org/abs/1809.01818
+  [1] Improving Explorability in Variational Inference with Annealed Variational
+      Objectives,
+      Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
+      https://arxiv.org/abs/1809.01818
 
   Args:
     targets: a list of target programs
@@ -231,9 +231,9 @@ def nvi_fkl(targets, proposals, *, num_targets=None):
   This is different from `nasmc`, where we assume that the targets are filtering
   distributions. We also assume that the final target does not have parameters.
 
-  1. *Nested Variational Inference*,
-     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-     https://arxiv.org/abs/2106.11302
+  [1] Nested Variational Inference,
+      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+      https://arxiv.org/abs/2106.11302
 
   Args:
     targets: a list of target programs
@@ -270,9 +270,9 @@ def nvi_rkl(targets, forwards, reverses, *, num_targets=None):
   initial target to the final target. Here we use ELBO loss in the last step
   to also maximize likelihood in case there are parameters in the final target.
 
-  1. *Nested Variational Inference*,
-     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-     https://arxiv.org/abs/2106.11302
+  [1] Nested Variational Inference,
+      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+      https://arxiv.org/abs/2106.11302
 
   Args:
     targets: a list of target programs
@@ -302,12 +302,12 @@ def body_fun(i, q):
 def rws(target, proposal):
   """Reweighted Wake-Sleep.
 
-  1. *Reweighted Wake-Sleep*,
-     Jörg Bornschein, Yoshua Bengio
-     https://arxiv.org/abs/1406.2751
-  2. *Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow*,
-     Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
-     https://arxiv.org/abs/1805.10469
+  [1] Reweighted Wake-Sleep,
+      Jörg Bornschein, Yoshua Bengio
+      https://arxiv.org/abs/1406.2751
+  [2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow,
+      Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
+      https://arxiv.org/abs/1805.10469
 
   Args:
     target: the target program
@@ -322,13 +322,13 @@ def rws(target, proposal):
 def svi(target, proposal):
   """Stochastic Variational Inference.
 
-  1. *Auto-Encoding Variational Bayes*,
-     Diederik P Kingma, Max Welling
-     https://arxiv.org/abs/1312.6114
-  2. *Stochastic Backpropagation and Approximate Inference in Deep Generative
-     Models*,
-     Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
-     https://arxiv.org/abs/1401.4082
+  [1] Auto-Encoding Variational Bayes,
+      Diederik P Kingma, Max Welling
+      https://arxiv.org/abs/1312.6114
+  [2] Stochastic Backpropagation and Approximate Inference in Deep Generative
+      Models,
+      Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
+      https://arxiv.org/abs/1401.4082
 
   Args:
     target: the target program
@@ -343,9 +343,9 @@ def svi_iwae(target, proposal):
   """SVI with Importance Weighted Autoencoder objective.
 
-  1. *Importance Weighted Autoencoders*,
-     Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
-     https://arxiv.org/abs/1509.00519
+  [1] Importance Weighted Autoencoders,
+      Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
+      https://arxiv.org/abs/1509.00519
 
   Args:
     target: the target program
@@ -360,10 +360,10 @@ def svi_stl(target, proposal):
   """SVI with Sticking-the-Landing objective.
 
-  1. *Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
-     Variational Inference*,
-     Geoffrey Roeder, Yuhuai Wu, David Duvenaud
-     https://arxiv.org/abs/1703.09194
+  [1] Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
+      Variational Inference,
+      Geoffrey Roeder, Yuhuai Wu, David Duvenaud
+      https://arxiv.org/abs/1703.09194
 
   Args:
     target: the target program
@@ -382,20 +382,20 @@ def vsmc(targets, proposals, *, num_targets=None):
   masking) during SMC steps. The targets can be filtering distributions or
   smoothing distributions (as in [4]).
 
-  1. *Filtering Variational Objectives*,
-     Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
-     Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
-     https://arxiv.org/abs/1705.09279
-  2. *Auto-Encoding Sequential Monte Carlo*,
-     Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
-     https://arxiv.org/abs/1705.10306
-  3. *Variational Sequential Monte Carlo*,
-     Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
-     https://arxiv.org/abs/1705.11140
-  4. *Twisted Variational Sequential Monte Carlo*,
-     Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
-     Ryan P Adams, Yee Whye Teh
-     http://bayesiandeeplearning.org/2018/papers/111.pdf
+  [1] Filtering Variational Objectives,
+      Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
+      Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
+      https://arxiv.org/abs/1705.09279
+  [2] Auto-Encoding Sequential Monte Carlo,
+      Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
+      https://arxiv.org/abs/1705.10306
+  [3] Variational Sequential Monte Carlo,
+      Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
+      https://arxiv.org/abs/1705.11140
+  [4] Twisted Variational Sequential Monte Carlo,
+      Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
+      Ryan P Adams, Yee Whye Teh
+      http://bayesiandeeplearning.org/2018/papers/111.pdf
 
   Args:
     targets: a list of target programs
@@ -408,11 +408,9 @@ def vsmc(targets, proposals, *, num_targets=None):
   if _use_fori_loop(targets, num_targets, proposals):
 
     def body_fun(i, q):
-      return propose(targets(i), compose(proposals(i), resample(q)))
+      return compose(proposals(i + 1), resample(propose(targets(i), q)))
 
-    q = propose(targets(0), proposals(0))
-    q = fori_loop(1, num_targets - 1, body_fun, q)
-    q = compose(proposals(num_targets - 1), resample(q))
+    q = fori_loop(0, num_targets - 1, body_fun, proposals(0))
     return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)
 
   q = propose(targets[0], proposals[0])
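The refactored `fori_loop` in the `vsmc` hunk above is equivalent to the three-statement version it replaces: the initial `propose` and the trailing `compose` are folded into the loop body. A sketch of what the new loop unrolls to for `num_targets = 3` (illustration only, reusing the `targets`/`proposals` callables and the combinators from `coix.api`):

  q = proposals(0)
  q = compose(proposals(1), resample(propose(targets(0), q)))  # body_fun(0, q)
  q = compose(proposals(2), resample(propose(targets(1), q)))  # body_fun(1, q)
  q = propose(targets(2), q, loss_fn=iwae_loss)                # final target

The old version unrolls to exactly the same program; only the loop bounds and the placement of `propose` inside the body shift.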
diff --git a/coix/api.py b/coix/api.py
index 861b9a4..2d09a4b 100644
--- a/coix/api.py
+++ b/coix/api.py
@@ -124,6 +124,9 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
   additional batch dimensions to the whole program by using `vmap`, e.g.
   `vmap(propose(p, q))`.
 
+  Note: We assume superfluous variables, which appear in `q` but not in `p`,
+  implicitly follow a Delta distribution in `p`.
+
   Args:
     p: a target program
     q: a proposal program
@@ -176,10 +179,10 @@ def wrapped(*args, **kwargs):
       for name, lp in p_log_probs.items()
       if util.is_observed_site(p_trace[name]) or (name in q_trace)
   )
+  # Note: We include superfluous variables, whose `name` is not in `p_trace`.
   q_log_weight = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
-      for name, lp in q_log_probs.items()
-      if util.is_observed_site(q_trace[name]) or (name in p_trace)
+      for lp in q_log_probs.values()
   )
   incremental_log_weight = p_log_weight - q_log_weight
   log_weight = in_log_weight + incremental_log_weight
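A toy numeric reading of the new `q_log_weight` in `propose` (one particle; the site names and numbers here are made up for illustration). The superfluous site `aux` is sampled in `q` but absent from `p`; per the Delta note above, its p-side log-density is implicitly zero, so it now enters only the q side of the weight:

  log_p = {"x": -0.5, "z": -1.2}    # `x` observed in p; `z` shared with q
  log_q = {"z": -0.9, "aux": -0.4}  # `aux` is superfluous (absent from p)

  p_log_weight = log_p["x"] + log_p["z"]  # -1.7: observed or shared sites
  q_log_weight = sum(log_q.values())      # -1.3: every q site, `aux` included
  incremental_log_weight = p_log_weight - q_log_weight  # -0.4
  # The old filter skipped `aux`, giving -1.7 - (-0.9) = -0.8 instead.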
diff --git a/coix/core.py b/coix/core.py
index 9121552..e34abed 100644
--- a/coix/core.py
+++ b/coix/core.py
@@ -117,8 +117,6 @@ def desuffix(trace):
   new_trace = {}
   for name in trace:
     raw_name = names_to_raw_names[name]
-    if raw_name != name and isinstance(trace[name], dict):
-      trace[name]["suffix"] = True
     new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
   return new_trace
 
diff --git a/coix/loss.py b/coix/loss.py
index 2989971..bc7fda4 100644
--- a/coix/loss.py
+++ b/coix/loss.py
@@ -115,19 +115,29 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   return loss
 
 
+def _proposal_and_target_sites(q_trace, p_trace):
+  """Gets current proposal sites and current target sites."""
+  proposal_sites = []
+  target_sites = []
+  for name in p_trace:
+    if not name.endswith("_PREV_"):
+      target_sites.append(name)
+      if name in q_trace:
+        while name + "_PREV_" in q_trace:
+          name += "_PREV_"
+        proposal_sites.append(name)
+  if not any(name.endswith("_PREV_") for name in proposal_sites):
+    proposal_sites = []
+  return proposal_sites, target_sites
+
+
 def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
-  """Forward KL objective. Here we do not optimize p."""
-  del p_trace
+  """Forward KL objective."""
   batch_ndims = incoming_log_weight.ndim
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites = [
-      name
-      for name, site in q_trace.items()
-      if name.endswith("_PREV_")
-      or (isinstance(site, dict) and "suffix" in site)
-  ]
+  proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -170,12 +180,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites = [
-      name
-      for name, site in q_trace.items()
-      if name.endswith("_PREV_")
-      or (isinstance(site, dict) and "suffix" in site)
-  ]
+  proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -185,7 +190,7 @@
   target_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
       for name, lp in p_log_probs.items()
-      if not name.endswith("_PREV_")
+      if name in target_sites
   )
 
   w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
@@ -208,12 +213,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites = [
-      name
-      for name, site in q_trace.items()
-      if name.endswith("_PREV_")
-      or (isinstance(site, dict) and "suffix" in site)
-  ]
+  proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -228,7 +228,7 @@
   target_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
       for name, lp in p_log_probs.items()
-      if not name.endswith("_PREV_")
+      if name in target_sites
   )
 
   surrogate_loss = (target_lp - proposal_lp) + forward_lp
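A walk-through of the new `_proposal_and_target_sites` helper on traces shaped like the ones in the `loss_test.py` hunk below (the `p_trace` sites here are an assumption; site values are elided with `...`). For each non-suffixed target site, the most deeply `_PREV_`-suffixed copy in `q_trace` is taken as its proposal, and `proposal_sites` collapses to `[]` when no suffixed copy exists at all:

  from coix.loss import _proposal_and_target_sites  # private helper added above

  q_trace = {"x": ..., "y": ..., "x_PREV_": ...}
  p_trace = {"x": ..., "y": ...}

  proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
  assert target_sites == ["x", "y"]          # non-suffixed sites of p
  assert proposal_sites == ["x_PREV_", "y"]  # deepest copies found in q

This name-based bookkeeping is what lets the losses drop the old `"suffix"` marker that `core.desuffix` used to write into each site dict.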
diff --git a/coix/loss_test.py b/coix/loss_test.py
index d822860..46fde42 100644
--- a/coix/loss_test.py
+++ b/coix/loss_test.py
@@ -25,7 +25,7 @@
 }
 q_trace = {
     "x": {"log_prob": np.ones((3, 2))},
-    "y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True},
+    "y": {"log_prob": np.array([1.0, 1.0, 0.0])},
     "x_PREV_": {"log_prob": np.full((3, 2), 3.0)},
 }
 incoming_weight = np.zeros(3)
diff --git a/coix/util.py b/coix/util.py
index f356513..2ebf2ac 100644
--- a/coix/util.py
+++ b/coix/util.py
@@ -223,6 +223,32 @@ def step_fn(params, opt_state, *args, **kwargs):
   return params, metrics
 
 
+def _remove_suffix(name):
+  i = 0
+  while name.endswith("_PREV_"):
+    i += len("_PREV_")
+    name = name[: -len("_PREV_")]
+  return name, i
+
+
+def desuffix(trace):
+  """Remove unnecessary suffix terms added to the trace."""
+  names_to_raw_names = {}
+  num_suffix_min = {}
+  for name in trace:
+    raw_name, num_suffix = _remove_suffix(name)
+    names_to_raw_names[name] = raw_name
+    if raw_name in num_suffix_min:
+      num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
+    else:
+      num_suffix_min[raw_name] = num_suffix
+  new_trace = {}
+  for name in trace:
+    raw_name = names_to_raw_names[name]
+    new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
+  return new_trace
+
+
 def get_batch_ndims(xs):
   """Gets the number of same-size leading dimensions of the elements in xs."""
   if not xs:
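Finally, a usage sketch for the `desuffix` added to `coix/util.py` (the trace is hypothetical). Only the minimum number of `_PREV_` suffixes common to all copies of a name is stripped, so the relative depth between copies is preserved:

  from coix.util import desuffix

  trace = {
      "z_PREV__PREV_": {"log_prob": 0.1},
      "z_PREV_": {"log_prob": 0.2},
  }
  # Every copy of "z" carries at least one "_PREV_" suffix, so exactly one
  # suffix is removed from each name:
  assert set(desuffix(trace)) == {"z_PREV_", "z"}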