Skip to content

Commit

Permalink
Update regression_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tarantula-leo authored Dec 4, 2023
1 parent 7ced559 commit 11c66b4
Showing 1 changed file with 24 additions and 8 deletions.
32 changes: 24 additions & 8 deletions sml/metrics/regression/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,37 +44,53 @@ def test_regression(self):
y_true = jnp.array([0.5, 1, 2.5, 7])
y_pred = jnp.array([1, 1, 5, 3.5])
weight = None
sk_result = metrics.d2_tweedie_score(y_true, y_pred, sample_weight=weight, power=1)
spu_result = spsim.sim_jax(sim, d2_tweedie_score, static_argnums=(3, ))(y_true, y_pred, weight, 1)
sk_result = metrics.d2_tweedie_score(
y_true, y_pred, sample_weight=weight, power=1
)
spu_result = spsim.sim_jax(sim, d2_tweedie_score, static_argnums=(3,))(
y_true, y_pred, weight, 1
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

# Test explained_variance_score
y_true = jnp.array([3, -0.5, 2, 7])
y_pred = jnp.array([2.5, 0.0, 2, 8])
weight = None
sk_result = metrics.explained_variance_score(y_true, y_pred, sample_weight=weight, multioutput="variance_weighted", force_finite=True)
spu_result = spsim.sim_jax(sim, explained_variance_score, static_argnums=(3, ))(y_true, y_pred, weight, "variance_weighted")
sk_result = metrics.explained_variance_score(
y_true,
y_pred,
sample_weight=weight,
multioutput="variance_weighted",
force_finite=True,
)
spu_result = spsim.sim_jax(sim, explained_variance_score, static_argnums=(3,))(
y_true, y_pred, weight, "variance_weighted"
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

# Test mean_squared_error
y_true = jnp.array([3, -0.5, 2, 7])
y_pred = jnp.array([2.5, 0.0, 2, 8])
weight = None
sk_result = metrics.mean_squared_error(y_true, y_pred, sample_weight=None, squared=False)
spu_result = spsim.sim_jax(sim, mean_squared_error, static_argnums=(3, 4))(y_true, y_pred, weight, "uniform_average", False)
sk_result = metrics.mean_squared_error(
y_true, y_pred, sample_weight=None, squared=False
)
spu_result = spsim.sim_jax(sim, mean_squared_error, static_argnums=(3, 4))(
y_true, y_pred, weight, "uniform_average", False
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

# Test mean_poisson_deviance
y_true = jnp.array([2, 0, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2., 2.])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
weight = None
sk_result = metrics.mean_poisson_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_poisson_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

# Test mean_gamma_deviance
y_true = jnp.array([2, 0.5, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2., 2.])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
weight = None
sk_result = metrics.mean_gamma_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_gamma_deviance)(y_true, y_pred, weight)
Expand Down

0 comments on commit 11c66b4

Please sign in to comment.