Skip to content

Commit

Permalink
Update regression.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tarantula-leo authored Dec 4, 2023
1 parent 7a2a5e7 commit 6e39c8a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion sml/metrics/regression/regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power):
)
return jnp.average(dev, weights=sample_weight)


def _d2_tweedie_score(y_true, y_pred, sample_weight, power, re_use=None):
p = power
re_value = []
Expand Down Expand Up @@ -147,6 +147,7 @@ def explained_variance_score(

return jnp.average(output_scores, weights=avg_weights)


def mean_squared_error(
y_true, y_pred, sample_weight=None, multioutput="uniform_average", squared=True
):
Expand All @@ -167,5 +168,6 @@ def mean_squared_error(
def mean_poisson_deviance(y_true, y_pred, sample_weight=None):
return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=1)


def mean_gamma_deviance(y_true, y_pred, sample_weight=None):
return _mean_tweedie_deviance(y_true, y_pred, sample_weight=sample_weight, power=2)

0 comments on commit 6e39c8a

Please sign in to comment.