diff --git a/sml/metrics/regression/regression.py b/sml/metrics/regression/regression.py index 43047342..0ddb251b 100644 --- a/sml/metrics/regression/regression.py +++ b/sml/metrics/regression/regression.py @@ -48,7 +48,7 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power): ) return jnp.average(dev, weights=sample_weight) - + def _d2_tweedie_score(y_true, y_pred, sample_weight, power, re_use=None): p = power re_value = [] @@ -147,6 +147,7 @@ def explained_variance_score( return jnp.average(output_scores, weights=avg_weights) + def mean_squared_error( y_true, y_pred, sample_weight=None, multioutput="uniform_average", squared=True ): @@ -167,5 +168,6 @@ def mean_squared_error( def mean_poisson_deviance(y_true, y_pred, sample_weight=None): return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=1) + def mean_gamma_deviance(y_true, y_pred, sample_weight=None): return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2)