diff --git a/coix/loss.py b/coix/loss.py index 1a10fec..995ec98 100644 --- a/coix/loss.py +++ b/coix/loss.py @@ -34,7 +34,7 @@ def apg_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """RWS objective that exploits conditional dependency.""" del incoming_log_weight, incremental_log_weight @@ -104,7 +104,7 @@ def avo_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Annealed Variational Objective.""" del q_trace, p_trace @@ -124,7 +124,7 @@ def elbo_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Evidence Lower Bound objective.""" del q_trace, p_trace @@ -144,7 +144,7 @@ def fkl_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Forward KL objective. Here we do not optimize p.""" del p_trace @@ -185,7 +185,7 @@ def iwae_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Importance Weighted Autoencoder objective.""" del q_trace, p_trace @@ -206,7 +206,7 @@ def rkl_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Reverse KL objective.""" batch_ndims = incoming_log_weight.ndim @@ -252,7 +252,7 @@ def rws_loss( p_trace, incoming_log_weight, incremental_log_weight, - aggregate=False, + aggregate=True, ): """Reweighted Wake-Sleep objective.""" batch_ndims = incoming_log_weight.ndim