diff --git a/sml/metrics/regression/regression_emul.py b/sml/metrics/regression/regression_emul.py index c0d6f185..6adc0808 100644 --- a/sml/metrics/regression/regression_emul.py +++ b/sml/metrics/regression/regression_emul.py @@ -33,7 +33,7 @@ def emul_d2_tweedie_score(mode: emulation.Mode.MULTIPROCESS): - power_list = [-1, 0 , 1, 2, 3] + power_list = [-1, 0, 1, 2, 3] weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])] # Test d2_tweedie_score @@ -49,6 +49,7 @@ def emul_d2_tweedie_score(mode: emulation.Mode.MULTIPROCESS): ) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4) + def emul_explained_variance_score(mode: emulation.Mode.MULTIPROCESS): weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])] @@ -68,6 +69,7 @@ def emul_explained_variance_score(mode: emulation.Mode.MULTIPROCESS): ) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4) + def emul_mean_squared_error(mode: emulation.Mode.MULTIPROCESS): weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])] @@ -83,6 +85,7 @@ def emul_mean_squared_error(mode: emulation.Mode.MULTIPROCESS): ) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4) + def emul_mean_poisson_deviance(mode: emulation.Mode.MULTIPROCESS): weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])] @@ -94,6 +97,7 @@ def emul_mean_poisson_deviance(mode: emulation.Mode.MULTIPROCESS): spu_result = emulator.run(mean_poisson_deviance)(y_true, y_pred, weight) np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4) + def emul_mean_gamma_deviance(mode: emulation.Mode.MULTIPROCESS): weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]