Skip to content

Commit

Permalink
add desuffix to trace to extract forward/reverse variables
Browse files Browse the repository at this point in the history
  • Loading branch information
fehiepsi committed May 3, 2024
1 parent 9f8a493 commit 63d94b1
Show file tree
Hide file tree
Showing 3 changed files with 5 additions and 29 deletions.
2 changes: 2 additions & 0 deletions coix/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,6 +117,8 @@ def desuffix(trace):
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
if raw_name != name and isinstance(trace[name], dict):
trace[name]["desuffix"] = True
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace

Expand Down
6 changes: 3 additions & 3 deletions coix/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ def fkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
if isinstance(site, dict) and "desuffix" in site
]

proposal_lp = sum(
Expand Down Expand Up @@ -171,7 +171,7 @@ def rkl_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
if isinstance(site, dict) and "desuffix" in site
]

proposal_lp = sum(
Expand Down Expand Up @@ -208,7 +208,7 @@ def rws_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight):
proposal_sites = [
name
for name, site in q_trace.items()
if util.is_observed_site(site) or name.endswith("_PREV_")
if isinstance(site, dict) and "desuffix" in site
]

proposal_lp = sum(
Expand Down
26 changes: 0 additions & 26 deletions coix/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,32 +223,6 @@ def step_fn(params, opt_state, *args, **kwargs):
return params, metrics


def _remove_suffix(name):
i = 0
while name.endswith("_PREV_"):
i += len("_PREV_")
name = name[: -len("_PREV_")]
return name, i


def desuffix(trace):
"""Remove unnecessary suffix terms added to the trace."""
names_to_raw_names = {}
num_suffix_min = {}
for name in trace:
raw_name, num_suffix = _remove_suffix(name)
names_to_raw_names[name] = raw_name
if raw_name in num_suffix_min:
num_suffix_min[raw_name] = min(num_suffix_min[raw_name], num_suffix)
else:
num_suffix_min[raw_name] = num_suffix
new_trace = {}
for name in trace:
raw_name = names_to_raw_names[name]
new_trace[name[: len(name) - num_suffix_min[raw_name]]] = trace[name]
return new_trace


def get_batch_ndims(xs):
"""Gets the number of same-size leading dimensions of the elements in xs."""
if not xs:
Expand Down

0 comments on commit 63d94b1

Please sign in to comment.