diff --git a/coix/loss.py b/coix/loss.py index cbef1c5..1a10fec 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -29,7 +29,13 @@ ] -def apg_loss(q_trace, p_trace, incoming_log_weight, incremental_log_weight, aggregate=False): +def apg_loss( + q_trace, + p_trace, + incoming_log_weight, + incremental_log_weight, + aggregate=False, +): """RWS objective that exploits conditional dependency.""" del incoming_log_weight, incremental_log_weight p_log_probs = {