diff --git a/coix/__init__.py b/coix/__init__.py
index 7188797..f52e881 100644
--- a/coix/__init__.py
+++ b/coix/__init__.py
@@ -55,4 +55,4 @@
 
 # A new PyPI release will be pushed everytime `__version__` is increased
 # When changing this, also update the CHANGELOG.md
-__version__ = "0.0.1"
+__version__ = "0.1.0"
diff --git a/coix/algo.py b/coix/algo.py
index f600fbe..2da2eb2 100644
--- a/coix/algo.py
+++ b/coix/algo.py
@@ -65,9 +65,9 @@ def _use_fori_loop(targets, num_targets, *fns):
 def aft(targets, flows, *, num_targets=None):
   """Annealed Flow Transport.
 
-  [1] Annealed Flow Transport Monte Carlo,
-      Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
-      https://arxiv.org/abs/2102.07501
+  1. *Annealed Flow Transport Monte Carlo*,
+     Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
+     https://arxiv.org/abs/2102.07501
 
   Args:
     targets: a list of target programs
@@ -94,10 +94,10 @@ def body_fun(i, q):
 def apgs(target, kernels, *, num_sweeps=1):
   """Amortized Population Gibbs Sampler.
 
-  [1] Amortized Population Gibbs Samplers with Neural Sufficient Statistics,
-      Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
-      Meent
-      https://arxiv.org/abs/1911.01382
+  1. *Amortized Population Gibbs Samplers with Neural Sufficient Statistics*,
+     Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
+     Meent
+     https://arxiv.org/abs/1911.01382
 
   Args:
     target: the target program
@@ -123,13 +123,13 @@ def body_fn(_, q):
 def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
   """Differentiable Annealed Importance Sampling.
 
-  [1] MCMC Variational Inference via Uncorrected Hamiltonian Annealing,
-      Tomas Geffner, Justin Domke
-      https://arxiv.org/abs/2107.04150
-  [2] Differentiable Annealed Importance Sampling and the Perils of Gradient
-      Noise,
-      Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
-      https://arxiv.org/abs/2107.10211
+  1. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*,
+     Tomas Geffner, Justin Domke
+     https://arxiv.org/abs/2107.04150
+  2. *Differentiable Annealed Importance Sampling and the Perils of Gradient
+     Noise*,
+     Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
+     https://arxiv.org/abs/2107.10211
 
   Args:
     targets: a list of target programs
@@ -166,9 +166,9 @@ def body_fun(i, q):
 def nasmc(targets, proposals, *, num_targets=None):
   """Neural Adaptive Sequential Monte Carlo.
 
-  [1] Neural Adaptive Sequential Monte Carlo,
-      Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
-      https://arxiv.org/abs/1506.03338
+  1. *Neural Adaptive Sequential Monte Carlo*,
+     Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
+     https://arxiv.org/abs/1506.03338
 
   Args:
     targets: a list of target programs
@@ -196,10 +196,10 @@ def body_fun(i, q):
 def nvi_avo(targets, forwards, reverses, *, num_targets=None):
   """AIS with Annealed Variational Objective.
 
-  [1] Improving Explorability in Variational Inference with Annealed Variational
-      Objectives,
-      Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
-      https://arxiv.org/abs/1809.01818
+  1. *Improving Explorability in Variational Inference with Annealed Variational
+     Objectives*,
+     Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
+     https://arxiv.org/abs/1809.01818
 
   Args:
     targets: a list of target programs
@@ -231,9 +231,9 @@ def nvi_fkl(targets, proposals, *, num_targets=None):
   This is different from `nasmc`, where we assume that the targets are filtering
   distributions. We also assume that the final target does not have parameters.
 
-  [1] Nested Variational Inference,
-      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-      https://arxiv.org/abs/2106.11302
+  1. *Nested Variational Inference*,
+     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+     https://arxiv.org/abs/2106.11302
 
   Args:
     targets: a list of target programs
@@ -270,9 +270,9 @@ def nvi_rkl(targets, forwards, reverses, *, num_targets=None):
   initial target to the final target. Here we use ELBO loss in the last step
   to also maximize likelihood in case there are parameters in the final target.
 
-  [1] Nested Variational Inference,
-      Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-      https://arxiv.org/abs/2106.11302
+  1. *Nested Variational Inference*,
+     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+     https://arxiv.org/abs/2106.11302
 
   Args:
     targets: a list of target programs
@@ -302,12 +302,12 @@ def body_fun(i, q):
 def rws(target, proposal):
   """Reweighted Wake-Sleep.
 
-  [1] Reweighted Wake-Sleep,
-      Jörg Bornschein, Yoshua Bengio
-      https://arxiv.org/abs/1406.2751
-  [2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow,
-      Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
-      https://arxiv.org/abs/1805.10469
+  1. *Reweighted Wake-Sleep*,
+     Jörg Bornschein, Yoshua Bengio
+     https://arxiv.org/abs/1406.2751
+  2. *Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow*,
+     Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
+     https://arxiv.org/abs/1805.10469
 
   Args:
     target: the target program
@@ -322,13 +322,13 @@ def rws(target, proposal):
 def svi(target, proposal):
   """Stochastic Variational Inference.
 
-  [1] Auto-Encoding Variational Bayes,
-      Diederik P Kingma, Max Welling
-      https://arxiv.org/abs/1312.6114
-  [2] Stochastic Backpropagation and Approximate Inference in Deep Generative
-      Models,
-      Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
-      https://arxiv.org/abs/1401.4082
+  1. *Auto-Encoding Variational Bayes*,
+     Diederik P Kingma, Max Welling
+     https://arxiv.org/abs/1312.6114
+  2. *Stochastic Backpropagation and Approximate Inference in Deep Generative
+     Models*,
+     Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
+     https://arxiv.org/abs/1401.4082
 
   Args:
     target: the target program
@@ -343,9 +343,9 @@ def svi(target, proposal):
 def svi_iwae(target, proposal):
   """SVI with Important Weighted Autoencoder objective.
 
-  [1] Importance Weighted Autoencoders,
-      Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
-      https://arxiv.org/abs/1509.00519
+  1. *Importance Weighted Autoencoders*,
+     Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
+     https://arxiv.org/abs/1509.00519
 
   Args:
     target: the target program
@@ -360,10 +360,10 @@ def svi_iwae(target, proposal):
 def svi_stl(target, proposal):
   """SVI with Sticking-the-Landing objective.
 
-  [1] Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
-      Variational Inference,
-      Geoffrey Roeder, Yuhuai Wu, David Duvenaud
-      https://arxiv.org/abs/1703.09194
+  1. *Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
+     Variational Inference*,
+     Geoffrey Roeder, Yuhuai Wu, David Duvenaud
+     https://arxiv.org/abs/1703.09194
 
   Args:
     target: the target program
@@ -382,20 +382,20 @@ def vsmc(targets, proposals, *, num_targets=None):
   masking) during SMC steps. The targets can be filtering distributions or
   smoothing distributions (as in [4]).
 
-  [1] Filtering Variational Objectives,
-      Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
-      Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
-      https://arxiv.org/abs/1705.09279
-  [2] Auto-Encoding Sequential Monte Carlo,
-      Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
-      https://arxiv.org/abs/1705.10306
-  [3] Variational Sequential Monte Carlo,
-      Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
-      https://arxiv.org/abs/1705.11140
-  [4] Twisted Variational Sequential Monte Carlo,
-      Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
-      Ryan P Adams, Yee Whye Teh
-      http://bayesiandeeplearning.org/2018/papers/111.pdf
+  1. *Filtering Variational Objectives*,
+     Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
+     Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
+     https://arxiv.org/abs/1705.09279
+  2. *Auto-Encoding Sequential Monte Carlo*,
+     Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
+     https://arxiv.org/abs/1705.10306
+  3. *Variational Sequential Monte Carlo*,
+     Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
+     https://arxiv.org/abs/1705.11140
+  4. *Twisted Variational Sequential Monte Carlo*,
+     Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
+     Ryan P Adams, Yee Whye Teh
+     http://bayesiandeeplearning.org/2018/papers/111.pdf
 
   Args:
     targets: a list of target programs
@@ -408,9 +408,11 @@
   if _use_fori_loop(targets, num_targets, proposals):
 
     def body_fun(i, q):
-      return compose(proposals(i + 1), resample(propose(targets(i), q)))
+      return propose(targets(i), compose(proposals(i), resample(q)))
 
-    q = fori_loop(0, num_targets - 1, body_fun, proposals(0))
+    q = propose(targets(0), proposals(0))
+    q = fori_loop(1, num_targets - 1, body_fun, q)
+    q = compose(proposals(num_targets - 1), resample(q))
     return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)
 
   q = propose(targets[0], proposals[0])
diff --git a/coix/api.py b/coix/api.py
index 2d09a4b..861b9a4 100644
--- a/coix/api.py
+++ b/coix/api.py
@@ -124,9 +124,6 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
   additional batch dimensions to the whole program by using `vmap`, e.g.
   `vmap(propose(p, q))`.
 
-  Note: We assume superfluous variables, which appear in `q` but not in `p`,
-  implicitly follow Delta distribution in `p`.
-
   Args:
     p: a target program
     q: a proposal program
@@ -179,10 +176,10 @@ def wrapped(*args, **kwargs):
         for name, lp in p_log_probs.items()
         if util.is_observed_site(p_trace[name]) or (name in q_trace)
     )
-    # Note: We include superfluous variables, whose `name in p_trace`.
     q_log_weight = sum(
         lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
-        for lp in q_log_probs.values()
+        for name, lp in q_log_probs.items()
+        if util.is_observed_site(q_trace[name]) or (name in p_trace)
     )
     incremental_log_weight = p_log_weight - q_log_weight
     log_weight = in_log_weight + incremental_log_weight
diff --git a/coix/core.py b/coix/core.py
index e34abed..9121552 100644
--- a/coix/core.py
+++ b/coix/core.py
@@ -117,6 +117,8 @@ def desuffix(trace):
   new_trace = {}
   for name in trace:
     raw_name = names_to_raw_names[name]
+    if raw_name != name and isinstance(trace[name], dict):
+      trace[name]["suffix"] = True
     new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
   return new_trace
 
diff --git a/coix/loss.py b/coix/loss.py
index bc7fda4..2989971 100644
--- a/coix/loss.py
+++ b/coix/loss.py
@@ -115,29 +115,19 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   return loss
 
 
-def _proposal_and_target_sites(q_trace, p_trace):
-  """Gets current proposal sites and current target sites."""
-  proposal_sites = []
-  target_sites = []
-  for name in p_trace:
-    if not name.endswith("_PREV_"):
-      target_sites.append(name)
-      if name in q_trace:
-        while name + "_PREV_" in q_trace:
-          name += "_PREV_"
-        proposal_sites.append(name)
-  if not any(name.endswith("_PREV_") for name in proposal_sites):
-    proposal_sites = []
-  return proposal_sites, target_sites
-
-
 def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
-  """Forward KL objective."""
+  """Forward KL objective. Here we do not optimize p."""
+  del p_trace
   batch_ndims = incoming_log_weight.ndim
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)
+  proposal_sites = [
+      name
+      for name, site in q_trace.items()
+      if name.endswith("_PREV_")
+      or (isinstance(site, dict) and "suffix" in site)
+  ]
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -180,7 +170,12 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
+  proposal_sites = [
+      name
+      for name, site in q_trace.items()
+      if name.endswith("_PREV_")
+      or (isinstance(site, dict) and "suffix" in site)
+  ]
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -190,7 +185,7 @@
   target_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
       for name, lp in p_log_probs.items()
-      if name in target_sites
+      if not name.endswith("_PREV_")
   )
 
   w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
@@ -213,7 +208,12 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
   q_log_probs = {
       name: util.get_site_log_prob(site) for name, site in q_trace.items()
   }
-  proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
+  proposal_sites = [
+      name
+      for name, site in q_trace.items()
+      if name.endswith("_PREV_")
+      or (isinstance(site, dict) and "suffix" in site)
+  ]
 
   proposal_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -228,7 +228,7 @@
   target_lp = sum(
       lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
       for name, lp in p_log_probs.items()
-      if name in target_sites
+      if not name.endswith("_PREV_")
   )
 
   surrogate_loss = (target_lp - proposal_lp) + forward_lp
diff --git a/coix/loss_test.py b/coix/loss_test.py
index 46fde42..d822860 100644
--- a/coix/loss_test.py
+++ b/coix/loss_test.py
@@ -25,7 +25,7 @@
 }
 q_trace = {
     "x": {"log_prob": np.ones((3, 2))},
-    "y": {"log_prob": np.array([1.0, 1.0, 0.0])},
+    "y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True},
     "x_PREV_": {"log_prob": np.full((3, 2), 3.0)},
 }
 incoming_weight = np.zeros(3)
diff --git a/coix/util.py b/coix/util.py
index 2ebf2ac..f356513 100644
--- a/coix/util.py
+++ b/coix/util.py
@@ -223,32 +223,6 @@ def step_fn(params, opt_state, *args, **kwargs):
   return params, metrics
 
 
-def _remove_suffix(name):
-  i = 0
-  while name.endswith("_PREV_"):
-    i += len("_PREV_")
-    name = name[: -len("_PREV_")]
-  return name, i
-
-
-def desuffix(trace):
-  """Remove unnecessary suffix terms added to the trace."""
-  names_to_raw_names = {}
-  num_suffix_min = {}
-  for name in trace:
-    raw_name, num_suffix = _remove_suffix(name)
-    names_to_raw_names[name] = raw_name
-    if raw_name in num_suffix_min:
-      num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
-    else:
-      num_suffix_min[raw_name] = num_suffix
-  new_trace = {}
-  for name in trace:
-    raw_name = names_to_raw_names[name]
-    new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
-  return new_trace
-
-
 def get_batch_ndims(xs):
   """Gets the number of same-size leading dimensions of the elements in xs."""
   if not xs:
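
Review note on the `vsmc` hunk in `coix/algo.py`: the old `fori_loop` branch seeded the loop with `proposals(0)` alone and composed `proposals(i + 1)` after resampling, so the proposal indices ran one step ahead of the targets and the initial `propose(targets(0), proposals(0))` step was never built. Unrolling the new branch for `num_targets = 3` shows the program it now constructs. A minimal sketch, assuming coix's usual top-level re-exports; `t0..t2` and `p0..p2` are hypothetical stand-ins for `targets(i)` and `proposals(i)`:

```python
# Sketch: unrolling the repaired fori_loop branch of vsmc for num_targets = 3.
# t0..t2 / p0..p2 are hypothetical stand-ins for targets(i) / proposals(i).
from coix import compose, propose, resample
from coix.loss import iwae_loss

def unrolled_vsmc(t0, t1, t2, p0, p1, p2):
  q = propose(t0, p0)                        # propose(targets(0), proposals(0))
  q = propose(t1, compose(p1, resample(q)))  # body_fun(1, q): resample, extend, propose
  q = compose(p2, resample(q))               # trailing resample + final proposal
  return propose(t2, q, loss_fn=iwae_loss)   # final target scored with the IWAE loss
```

This matches the list-based branch right below it, which starts from `q = propose(targets[0], proposals[0])`, so the two code paths now build the same program.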
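Review note on the `coix/api.py` hunks: `q_log_weight` now applies the same site filter as `p_log_weight`, so a latent site that appears only in `q` drops out of the incremental weight rather than being treated as implicitly Delta-distributed in `p` (hence the removed docstring note). A minimal sketch of the symmetric rule, with hypothetical `is_observed`/`log_prob` helpers standing in for `util.is_observed_site`/`util.get_site_log_prob` and scalar per-site log probs instead of batched arrays:

```python
# Sketch of the symmetric weight rule after this patch: a site contributes to
# a program's log weight only if it is observed there or also appears in the
# other program's trace.
def incremental_log_weight(p_trace, q_trace, is_observed, log_prob):
  p_log_weight = sum(
      log_prob(site)
      for name, site in p_trace.items()
      if is_observed(site) or name in q_trace
  )
  q_log_weight = sum(
      log_prob(site)
      for name, site in q_trace.items()
      if is_observed(site) or name in p_trace  # new: q-only latent sites drop out
  )
  return p_log_weight - q_log_weight
```

With this rule, `propose(p, q)` treats the two traces symmetrically, which is why the superfluous-variable caveat could be deleted from the docstring.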
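Review note on the `_PREV_` bookkeeping: `core.desuffix` now flags every site whose recorded name carried a `_PREV_` suffix (even when the common suffix is stripped from the returned key), and the loss functions select proposal sites directly from the `q` trace instead of reconstructing them from `p_trace` via the deleted `_proposal_and_target_sites` helper; target sites correspondingly become every `p_trace` name not ending in `_PREV_`. The updated `loss_test.py` fixture and the removal of the duplicate `desuffix` in `coix/util.py` follow from this. A toy illustration mirroring the test fixture, with log-prob values elided:

```python
# Sketch: how the loss functions now pick proposal sites from a q trace.
q_trace = {
    "x": {"log_prob": ...},                  # current-step latent site
    "y": {"log_prob": ..., "suffix": True},  # flagged by core.desuffix
    "x_PREV_": {"log_prob": ...},            # previous-step site, kept by name
}
proposal_sites = [
    name
    for name, site in q_trace.items()
    if name.endswith("_PREV_") or (isinstance(site, dict) and "suffix" in site)
]
assert proposal_sites == ["y", "x_PREV_"]
```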