Commit 171b0b4
Code update
PiperOrigin-RevId: 624292817
fehiepsi authored and The coix Authors committed May 6, 2024
1 parent 199103f commit 171b0b4
Showing 7 changed files with 118 additions and 93 deletions.
2 changes: 1 addition & 1 deletion coix/__init__.py
@@ -55,4 +55,4 @@

# A new PyPI release will be pushed everytime `__version__` is increased
# When changing this, also update the CHANGELOG.md
__version__ = "0.1.0"
__version__ = "0.0.1"
128 changes: 63 additions & 65 deletions coix/algo.py
@@ -65,9 +65,9 @@ def _use_fori_loop(targets, num_targets, *fns):
def aft(targets, flows, *, num_targets=None):
"""Annealed Flow Transport.
- 1. *Annealed Flow Transport Monte Carlo*,
-    Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
-    https://arxiv.org/abs/2102.07501
+ [1] Annealed Flow Transport Monte Carlo,
+     Michael Arbel, Alexander G. D. G. Matthews, Arnaud Doucet
+     https://arxiv.org/abs/2102.07501
Args:
targets: a list of target programs
@@ -94,10 +94,10 @@ def body_fun(i, q):
def apgs(target, kernels, *, num_sweeps=1):
"""Amortized Population Gibbs Sampler.
- 1. *Amortized Population Gibbs Samplers with Neural Sufficient Statistics*,
-    Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
-    Meent
-    https://arxiv.org/abs/1911.01382
+ [1] Amortized Population Gibbs Samplers with Neural Sufficient Statistics,
+     Hao Wu, Heiko Zimmermann, Eli Sennesh, Tuan Anh Le, Jan-Willem van de
+     Meent
+     https://arxiv.org/abs/1911.01382
Args:
target: the target program
@@ -123,13 +123,13 @@ def body_fn(_, q):
def dais(targets, momentum, leapfrog, refreshment, *, num_targets=None):
"""Differentiable Annealed Importance Sampling.
- 1. *MCMC Variational Inference via Uncorrected Hamiltonian Annealing*,
-    Tomas Geffner, Justin Domke
-    https://arxiv.org/abs/2107.04150
- 2. *Differentiable Annealed Importance Sampling and the Perils of Gradient
-    Noise*,
-    Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
-    https://arxiv.org/abs/2107.10211
+ [1] MCMC Variational Inference via Uncorrected Hamiltonian Annealing,
+     Tomas Geffner, Justin Domke
+     https://arxiv.org/abs/2107.04150
+ [2] Differentiable Annealed Importance Sampling and the Perils of Gradient
+     Noise,
+     Guodong Zhang, Kyle Hsu, Jianing Li, Chelsea Finn, Roger Grosse
+     https://arxiv.org/abs/2107.10211
Args:
targets: a list of target programs
@@ -166,9 +166,9 @@ def body_fun(i, q):
def nasmc(targets, proposals, *, num_targets=None):
"""Neural Adaptive Sequential Monte Carlo.
- 1. *Neural Adaptive Sequential Monte Carlo*,
-    Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
-    https://arxiv.org/abs/1506.03338
+ [1] Neural Adaptive Sequential Monte Carlo,
+     Shixiang Gu, Zoubin Ghahramani, Richard E. Turner
+     https://arxiv.org/abs/1506.03338
Args:
targets: a list of target programs
@@ -196,10 +196,10 @@ def body_fun(i, q):
def nvi_avo(targets, forwards, reverses, *, num_targets=None):
"""AIS with Annealed Variational Objective.
- 1. *Improving Explorability in Variational Inference with Annealed Variational
-    Objectives*,
-    Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
-    https://arxiv.org/abs/1809.01818
+ [1] Improving Explorability in Variational Inference with Annealed Variational
+     Objectives,
+     Chin-Wei Huang, Shawn Tan, Alexandre Lacoste, Aaron Courville
+     https://arxiv.org/abs/1809.01818
Args:
targets: a list of target programs
@@ -231,9 +231,9 @@ def nvi_fkl(targets, proposals, *, num_targets=None):
This is different from `nasmc`, where we assume that the targets are filtering
distributions. We also assume that the final target does not have parameters.
- 1. *Nested Variational Inference*,
-    Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-    https://arxiv.org/abs/2106.11302
+ [1] Nested Variational Inference,
+     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+     https://arxiv.org/abs/2106.11302
Args:
targets: a list of target programs
@@ -270,9 +270,9 @@ def nvi_rkl(targets, forwards, reverses, *, num_targets=None):
initial target to the final target. Here we use ELBO loss in the last step
to also maximize likelihood in case there are parameters in the final target.
- 1. *Nested Variational Inference*,
-    Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
-    https://arxiv.org/abs/2106.11302
+ [1] Nested Variational Inference,
+     Heiko Zimmermann, Hao Wu, Babak Esmaeili, Jan-Willem van de Meent
+     https://arxiv.org/abs/2106.11302
Args:
targets: a list of target programs
@@ -302,12 +302,12 @@ def body_fun(i, q):
def rws(target, proposal):
"""Reweighted Wake-Sleep.
- 1. *Reweighted Wake-Sleep*,
-    Jörg Bornschein, Yoshua Bengio
-    https://arxiv.org/abs/1406.2751
- 2. *Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow*,
-    Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
-    https://arxiv.org/abs/1805.10469
+ [1] Reweighted Wake-Sleep,
+     Jörg Bornschein, Yoshua Bengio
+     https://arxiv.org/abs/1406.2751
+ [2] Revisiting Reweighted Wake-Sleep for Models with Stochastic Control Flow,
+     Tuan Anh Le, Adam R. Kosiorek, N. Siddharth, Yee Whye Teh, Frank Wood
+     https://arxiv.org/abs/1805.10469
Args:
target: the target program
@@ -322,13 +322,13 @@ def svi(target, proposal):
def svi(target, proposal):
"""Stochastic Variational Inference.
- 1. *Auto-Encoding Variational Bayes*,
-    Diederik P Kingma, Max Welling
-    https://arxiv.org/abs/1312.6114
- 2. *Stochastic Backpropagation and Approximate Inference in Deep Generative
-    Models*,
-    Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
-    https://arxiv.org/abs/1401.4082
+ [1] Auto-Encoding Variational Bayes,
+     Diederik P Kingma, Max Welling
+     https://arxiv.org/abs/1312.6114
+ [2] Stochastic Backpropagation and Approximate Inference in Deep Generative
+     Models,
+     Danilo Jimenez Rezende, Shakir Mohamed, Daan Wierstra
+     https://arxiv.org/abs/1401.4082
Args:
target: the target program
@@ -343,9 +343,9 @@ def svi_iwae(target, proposal):
def svi_iwae(target, proposal):
"""SVI with Important Weighted Autoencoder objective.
- 1. *Importance Weighted Autoencoders*,
-    Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
-    https://arxiv.org/abs/1509.00519
+ [1] Importance Weighted Autoencoders,
+     Yuri Burda, Roger Grosse, Ruslan Salakhutdinov
+     https://arxiv.org/abs/1509.00519
Args:
target: the target program
@@ -360,10 +360,10 @@ def svi_stl(target, proposal):
def svi_stl(target, proposal):
"""SVI with Sticking-the-Landing objective.
- 1. *Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
-    Variational Inference*,
-    Geoffrey Roeder, Yuhuai Wu, David Duvenaud
-    https://arxiv.org/abs/1703.09194
+ [1] Sticking the Landing: Simple, Lower-Variance Gradient Estimators for
+     Variational Inference,
+     Geoffrey Roeder, Yuhuai Wu, David Duvenaud
+     https://arxiv.org/abs/1703.09194
Args:
target: the target program
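Reviewer note: the three SVI variants above appear to differ only in which objective is attached to the final propose step. As a hedged sketch (their bodies are outside this diff): judging from the `loss_fn=iwae_loss` call visible in the vsmc hunk below and the `detach` flag in `propose`'s signature in coix/api.py, the variants plausibly reduce to the following, with `elbo_loss` and `iwae_loss` coming from coix/loss.py:

def svi(target, proposal):
  return propose(target, proposal, loss_fn=elbo_loss)  # plain ELBO

def svi_iwae(target, proposal):
  return propose(target, proposal, loss_fn=iwae_loss)  # IWAE bound

def svi_stl(target, proposal):
  # STL: stop gradients through the proposal's density parameters.
  return propose(target, proposal, loss_fn=elbo_loss, detach=True)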
@@ -382,20 +382,20 @@ def vsmc(targets, proposals, *, num_targets=None):
masking) during SMC steps. The targets can be filtering distributions or
smoothing distributions (as in [4]).
- 1. *Filtering Variational Objectives*,
-    Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
-    Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
-    https://arxiv.org/abs/1705.09279
- 2. *Auto-Encoding Sequential Monte Carlo*,
-    Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
-    https://arxiv.org/abs/1705.10306
- 3. *Variational Sequential Monte Carlo*,
-    Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
-    https://arxiv.org/abs/1705.11140
- 4. *Twisted Variational Sequential Monte Carlo*,
-    Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
-    Ryan P Adams, Yee Whye Teh
-    http://bayesiandeeplearning.org/2018/papers/111.pdf
+ [1] Filtering Variational Objectives,
+     Chris J. Maddison, Dieterich Lawson, George Tucker, Nicolas Heess,
+     Mohammad Norouzi, Andriy Mnih, Arnaud Doucet, Yee Whye Teh
+     https://arxiv.org/abs/1705.09279
+ [2] Auto-Encoding Sequential Monte Carlo,
+     Tuan Anh Le, Maximilian Igl, Tom Rainforth, Tom Jin, Frank Wood
+     https://arxiv.org/abs/1705.10306
+ [3] Variational Sequential Monte Carlo,
+     Christian A. Naesseth, Scott W. Linderman, Rajesh Ranganath, David M. Blei
+     https://arxiv.org/abs/1705.11140
+ [4] Twisted Variational Sequential Monte Carlo,
+     Dieterich Lawson, George Tucker, Christian A Naesseth, Chris J Maddison,
+     Ryan P Adams, Yee Whye Teh
+     http://bayesiandeeplearning.org/2018/papers/111.pdf
Args:
targets: a list of target programs
Expand All @@ -408,11 +408,9 @@ def vsmc(targets, proposals, *, num_targets=None):
if _use_fori_loop(targets, num_targets, proposals):

def body_fun(i, q):
- return propose(targets(i), compose(proposals(i), resample(q)))
+ return compose(proposals(i + 1), resample(propose(targets(i), q)))

- q = propose(targets(0), proposals(0))
- q = fori_loop(1, num_targets - 1, body_fun, q)
- q = compose(proposals(num_targets - 1), resample(q))
+ q = fori_loop(0, num_targets - 1, body_fun, proposals(0))
return propose(targets(num_targets - 1), q, loss_fn=iwae_loss)

q = propose(targets[0], proposals[0])
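Reviewer note: the only behavioral change in coix/algo.py is the `vsmc` fori_loop branch; everything else in the file is citation reformatting. A hypothetical unrolling for num_targets = 3 (writing t0 = targets(0), p0 = proposals(0), and so on) checks that the rewrite builds the same program with a single loop:

# before: q = propose(t0, p0)
#         q = propose(t1, compose(p1, resample(q)))   # fori_loop, i = 1
#         q = compose(p2, resample(q))
#         return propose(t2, q, loss_fn=iwae_loss)
#
# after:  q = compose(p1, resample(propose(t0, p0)))  # fori_loop, i = 0
#         q = compose(p2, resample(propose(t1, q)))   # fori_loop, i = 1
#         return propose(t2, q, loss_fn=iwae_loss)
#
# Substituting each line into the next yields the identical composite
# program in both versions, so this reads as a simplification rather
# than a semantic fix.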
7 changes: 5 additions & 2 deletions coix/api.py
@@ -124,6 +124,9 @@ def propose(p, q, *, loss_fn=None, detach=False, chain=False):
additional batch dimensions to the whole program by using `vmap`, e.g.
`vmap(propose(p, q))`.
+ Note: We assume superfluous variables, which appear in `q` but not in `p`,
+ implicitly follow Delta distribution in `p`.
Args:
p: a target program
q: a proposal program
@@ -176,10 +179,10 @@ def wrapped(*args, **kwargs):
for name, lp in p_log_probs.items()
if util.is_observed_site(p_trace[name]) or (name in q_trace)
)
+ # Note: We include superfluous variables, whose `name in p_trace`.
q_log_weight = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
- for name, lp in q_log_probs.items()
- if util.is_observed_site(q_trace[name]) or (name in p_trace)
+ for lp in q_log_probs.values()
)
incremental_log_weight = p_log_weight - q_log_weight
log_weight = in_log_weight + incremental_log_weight
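Reviewer note: taken together, the two api.py hunks say that every site sampled by `q` now contributes to the proposal's log-weight, and a site that `q` samples but `p` never defines is treated as if `p` scored it with a Delta. A toy, self-contained illustration of the resulting incremental weight (plain Python scalars; the real code reshapes and sums over batch dimensions, and `name == "obs"` stands in for `util.is_observed_site`):

p_log_probs = {"x": -1.2, "obs": -0.7}  # sites in the target p
q_log_probs = {"x": -1.0, "aux": -2.3}  # "aux" appears only in q

# p contributes observed sites and sites that q also proposed.
p_log_weight = sum(
    lp for name, lp in p_log_probs.items()
    if name == "obs" or name in q_log_probs
)
# New rule: all q sites count, including the superfluous "aux".
q_log_weight = sum(q_log_probs.values())
print(p_log_weight - q_log_weight)  # (-1.9) - (-3.3) = 1.4

Under the old rule "aux" would have been skipped, so the weight would not have accounted for sampling it.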
2 changes: 0 additions & 2 deletions coix/core.py
@@ -117,8 +117,6 @@ def desuffix(trace):
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
- if raw_name != name and isinstance(trace[name], dict):
-   trace[name]["suffix"] = True
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace

44 changes: 22 additions & 22 deletions coix/loss.py
@@ -115,19 +115,29 @@ def elbo_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
return loss


+ def _proposal_and_target_sites(q_trace, p_trace):
+   """Gets current proposal sites and current target sites."""
+   proposal_sites = []
+   target_sites = []
+   for name in p_trace:
+     if not name.endswith("_PREV_"):
+       target_sites.append(name)
+       if name in q_trace:
+         while name + "_PREV_" in q_trace:
+           name += "_PREV_"
+         proposal_sites.append(name)
+   if not any(name.endswith("_PREV_") for name in proposal_sites):
+     proposal_sites = []
+   return proposal_sites, target_sites


def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
"""Forward KL objective. Here we do not optimize p."""
del p_trace
"""Forward KL objective."""
batch_ndims = incoming_log_weight.ndim
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
- proposal_sites = [
-     name
-     for name, site in q_trace.items()
-     if name.endswith("_PREV_")
-     or (isinstance(site, dict) and "suffix" in site)
- ]
+ proposal_sites, _ = _proposal_and_target_sites(q_trace, p_trace)

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
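Reviewer note: a toy run of the new helper, using dicts as stand-in site objects and the `_PREV_` naming convention from this commit:

p_trace = {"z": {}, "obs": {}, "z_PREV_": {}}
q_trace = {"z": {}, "obs": {}, "z_PREV_": {}}

proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)
# target_sites   == ["z", "obs"]        # the non-_PREV_ sites of p
# proposal_sites == ["z_PREV_", "obs"]  # the deepest _PREV_ copy in q

If no site in `q_trace` carries a `_PREV_` suffix (a single-step proposal), the final guard resets `proposal_sites` to an empty list, so the proposal log-prob term below sums to zero.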
@@ -170,12 +180,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
- proposal_sites = [
-     name
-     for name, site in q_trace.items()
-     if name.endswith("_PREV_")
-     or (isinstance(site, dict) and "suffix" in site)
- ]
+ proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
@@ -185,7 +190,7 @@
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
- if not name.endswith("_PREV_")
+ if name in target_sites
)

w1 = jax.lax.stop_gradient(jax.nn.softmax(incoming_log_weight, axis=0))
@@ -208,12 +213,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
q_log_probs = {
name: util.get_site_log_prob(site) for name, site in q_trace.items()
}
- proposal_sites = [
-     name
-     for name, site in q_trace.items()
-     if name.endswith("_PREV_")
-     or (isinstance(site, dict) and "suffix" in site)
- ]
+ proposal_sites, target_sites = _proposal_and_target_sites(q_trace, p_trace)

proposal_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
Expand All @@ -228,7 +228,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
target_lp = sum(
lp.reshape(lp.shape[:batch_ndims] + (-1,)).sum(-1)
for name, lp in p_log_probs.items()
- if not name.endswith("_PREV_")
+ if name in target_sites
)

surrogate_loss = (target_lp - proposal_lp) + forward_lp
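Reviewer note: the updated losses share one weighting pattern, visible in the rkl/rws hunks: normalized particle weights are frozen with `stop_gradient` and multiply a surrogate built from the selected target and proposal site log-probabilities. A simplified schematic of that pattern (not the exact coix code; the real losses also carry forward-kernel and incremental-weight terms, omitted here):

import jax

def weighted_surrogate(log_weight, target_lp, proposal_lp):
  # Freeze the normalized weights so gradients flow only through the
  # log-density surrogate, as in rkl_loss/rws_loss above.
  w = jax.lax.stop_gradient(jax.nn.softmax(log_weight, axis=0))
  return -(w * (target_lp - proposal_lp)).sum(0)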
2 changes: 1 addition & 1 deletion coix/loss_test.py
@@ -25,7 +25,7 @@
}
q_trace = {
"x": {"log_prob": np.ones((3, 2))},
"y": {"log_prob": np.array([1.0, 1.0, 0.0]), "suffix": True},
"y": {"log_prob": np.array([1.0, 1.0, 0.0])},
"x_PREV_": {"log_prob": np.full((3, 2), 3.0)},
}
incoming_weight = np.zeros(3)
26 changes: 26 additions & 0 deletions coix/util.py
@@ -223,6 +223,32 @@ def step_fn(params, opt_state, *args, **kwargs):
return params, metrics


+ def _remove_suffix(name):
+   i = 0
+   while name.endswith("_PREV_"):
+     i += len("_PREV_")
+     name = name[: -len("_PREV_")]
+   return name, i
+
+
+ def desuffix(trace):
+   """Remove unnecessary suffix terms added to the trace."""
+   names_to_raw_names = {}
+   num_suffix_min = {}
+   for name in trace:
+     raw_name, num_suffix = _remove_suffix(name)
+     names_to_raw_names[name] = raw_name
+     if raw_name in num_suffix_min:
+       num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
+     else:
+       num_suffix_min[raw_name] = num_suffix
+   new_trace = {}
+   for name in trace:
+     raw_name = names_to_raw_names[name]
+     new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
+   return new_trace


def get_batch_ndims(xs):
"""Gets the number of same-size leading dimensions of the elements in xs."""
if not xs:
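Reviewer note: `desuffix` moves here from coix/core.py and now infers suffix counts purely from site names instead of the dropped `"suffix"` dict marker. A quick check of the intended behavior (integers stand in for site values, which `desuffix` never inspects):

trace = {
    "x_PREV__PREV_": 1,  # two steps old
    "x_PREV_": 2,        # one step old
    "y_PREV_": 3,
}
print(desuffix(trace))
# {'x_PREV_': 1, 'x': 2, 'y': 3}: the minimum suffix count per raw name
# (one _PREV_ for both "x" and "y") is stripped from every key.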
