diff --git a/sml/metrics/regression/regression.py b/sml/metrics/regression/regression.py index 503f1f52..e480be50 100644 --- a/sml/metrics/regression/regression.py +++ b/sml/metrics/regression/regression.py @@ -14,8 +14,6 @@ from sklearn import metrics import jax.numpy as jnp -import spu.spu_pb2 as spu_pb2 -import spu.utils.simulation as spsim def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power, d2_score=False): @@ -45,7 +43,7 @@ def _mean_tweedie_deviance(y_true, y_pred, sample_weight, power, d2_score=False) - temp_power * (y_true / temp - y_pred / (temp + 1)) ) - if d2_score & (sample_weight is None): + if d2_score and (sample_weight is None): # When weight is none and the d2 fraction is calculated, the numerator and denominator are divided by shape, and only sum is used return jnp.sum(dev) else: