Skip to content

Commit

Permalink
Update regression_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tarantula-leo authored Dec 6, 2023
1 parent e1c5fb8 commit 5d8c587
Showing 1 changed file with 38 additions and 12 deletions.
50 changes: 38 additions & 12 deletions sml/metrics/regression/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,12 @@ def test_d2_tweedie_score(self):
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

power_list = [-1, 0 , 1, 2, 3]
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]
power_list = [-1, 0, 1, 2, 3]
weight_list = [
None,
jnp.array([0.5, 0.5, 0.5, 0.5]),
jnp.array([0.5, 1, 2, 0.5]),
]

# Test d2_tweedie_score
y_true = jnp.array([0.5, 1, 2.5, 7])
Expand All @@ -61,7 +65,11 @@ def test_explained_variance_score(self):
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]
weight_list = [
None,
jnp.array([0.5, 0.5, 0.5, 0.5]),
jnp.array([0.5, 1, 2, 0.5]),
]

# Test explained_variance_score
y_true = jnp.array([3, -0.5, 2, 7])
Expand All @@ -74,17 +82,21 @@ def test_explained_variance_score(self):
multioutput="variance_weighted",
force_finite=True,
)
spu_result = spsim.sim_jax(sim, explained_variance_score, static_argnums=(3,))(
y_true, y_pred, weight, "variance_weighted"
)
spu_result = spsim.sim_jax(
sim, explained_variance_score, static_argnums=(3,)
)(y_true, y_pred, weight, "variance_weighted")
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_mean_squared_error(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]
weight_list = [
None,
jnp.array([0.5, 0.5, 0.5, 0.5]),
jnp.array([0.5, 1, 2, 0.5]),
]

# Test mean_squared_error
y_true = jnp.array([3, -0.5, 2, 7])
Expand All @@ -103,28 +115,42 @@ def test_mean_poisson_deviance(self):
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]
weight_list = [
None,
jnp.array([0.5, 0.5, 0.5, 0.5]),
jnp.array([0.5, 1, 2, 0.5]),
]

# Test mean_poisson_deviance
y_true = jnp.array([2, 0, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
for weight in weight_list:
sk_result = metrics.mean_poisson_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_poisson_deviance)(y_true, y_pred, weight)
sk_result = metrics.mean_poisson_deviance(
y_true, y_pred, sample_weight=weight
)
spu_result = spsim.sim_jax(sim, mean_poisson_deviance)(
y_true, y_pred, weight
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_mean_gamma_deviance(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]
weight_list = [
None,
jnp.array([0.5, 0.5, 0.5, 0.5]),
jnp.array([0.5, 1, 2, 0.5]),
]

# Test mean_gamma_deviance
y_true = jnp.array([2, 0.5, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
for weight in weight_list:
sk_result = metrics.mean_gamma_deviance(y_true, y_pred, sample_weight=weight)
sk_result = metrics.mean_gamma_deviance(
y_true, y_pred, sample_weight=weight
)
spu_result = spsim.sim_jax(sim, mean_gamma_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

Expand Down

0 comments on commit 5d8c587

Please sign in to comment.