Skip to content

Commit

Permalink
Update regression_emul.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tarantula-leo authored Dec 6, 2023
1 parent c9aa6c3 commit e1c5fb8
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion sml/metrics/regression/regression_emul.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@


def emul_d2_tweedie_score(mode: emulation.Mode.MULTIPROCESS):
power_list = [-1, 0 , 1, 2, 3]
power_list = [-1, 0, 1, 2, 3]
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test d2_tweedie_score
Expand All @@ -49,6 +49,7 @@ def emul_d2_tweedie_score(mode: emulation.Mode.MULTIPROCESS):
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)


def emul_explained_variance_score(mode: emulation.Mode.MULTIPROCESS):
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

Expand All @@ -68,6 +69,7 @@ def emul_explained_variance_score(mode: emulation.Mode.MULTIPROCESS):
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)


def emul_mean_squared_error(mode: emulation.Mode.MULTIPROCESS):
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

Expand All @@ -83,6 +85,7 @@ def emul_mean_squared_error(mode: emulation.Mode.MULTIPROCESS):
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)


def emul_mean_poisson_deviance(mode: emulation.Mode.MULTIPROCESS):
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

Expand All @@ -94,6 +97,7 @@ def emul_mean_poisson_deviance(mode: emulation.Mode.MULTIPROCESS):
spu_result = emulator.run(mean_poisson_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)


def emul_mean_gamma_deviance(mode: emulation.Mode.MULTIPROCESS):
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

Expand Down

0 comments on commit e1c5fb8

Please sign in to comment.