Skip to content

Commit

Permalink
Update regression_test.py
Browse files Browse the repository at this point in the history
  • Loading branch information
tarantula-leo authored Dec 6, 2023
1 parent 7cd6077 commit 32912d2
Showing 1 changed file with 66 additions and 34 deletions.
100 changes: 66 additions & 34 deletions sml/metrics/regression/regression_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,66 +35,98 @@


class UnitTests(unittest.TestCase):
def test_regression(self):
def test_d2_tweedie_score(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

power_list = [-1, 0 , 1, 2, 3]
weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test d2_tweedie_score
y_true = jnp.array([0.5, 1, 2.5, 7])
y_pred = jnp.array([1, 1, 5, 3.5])
weight = None
sk_result = metrics.d2_tweedie_score(
y_true, y_pred, sample_weight=weight, power=1
)
spu_result = spsim.sim_jax(sim, d2_tweedie_score, static_argnums=(3,))(
y_true, y_pred, weight, 1
for p in power_list:
for weight in weight_list:
sk_result = metrics.d2_tweedie_score(
y_true, y_pred, sample_weight=weight, power=p
)
spu_result = spsim.sim_jax(sim, d2_tweedie_score, static_argnums=(3,))(
y_true, y_pred, weight, p
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_explained_variance_score(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test explained_variance_score
y_true = jnp.array([3, -0.5, 2, 7])
y_pred = jnp.array([2.5, 0.0, 2, 8])
weight = None
sk_result = metrics.explained_variance_score(
y_true,
y_pred,
sample_weight=weight,
multioutput="variance_weighted",
force_finite=True,
)
spu_result = spsim.sim_jax(sim, explained_variance_score, static_argnums=(3,))(
y_true, y_pred, weight, "variance_weighted"
for weight in weight_list:
sk_result = metrics.explained_variance_score(
y_true,
y_pred,
sample_weight=weight,
multioutput="variance_weighted",
force_finite=True,
)
spu_result = spsim.sim_jax(sim, explained_variance_score, static_argnums=(3,))(
y_true, y_pred, weight, "variance_weighted"
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_mean_squared_error(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test mean_squared_error
y_true = jnp.array([3, -0.5, 2, 7])
y_pred = jnp.array([2.5, 0.0, 2, 8])
weight = None
sk_result = metrics.mean_squared_error(
y_true, y_pred, sample_weight=None, squared=False
)
spu_result = spsim.sim_jax(sim, mean_squared_error, static_argnums=(3, 4))(
y_true, y_pred, weight, "uniform_average", False
for weight in weight_list:
sk_result = metrics.mean_squared_error(
y_true, y_pred, sample_weight=weight, squared=False
)
spu_result = spsim.sim_jax(sim, mean_squared_error, static_argnums=(3, 4))(
y_true, y_pred, weight, "uniform_average", False
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_mean_poisson_deviance(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test mean_poisson_deviance
y_true = jnp.array([2, 0, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
weight = None
sk_result = metrics.mean_poisson_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_poisson_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)
for weight in weight_list:
sk_result = metrics.mean_poisson_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_poisson_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)

def test_mean_gamma_deviance(self):
sim = spsim.Simulator.simple(
3, spu_pb2.ProtocolKind.ABY3, spu_pb2.FieldType.FM128
)

weight_list = [None, jnp.array([0.5, 0.5, 0.5, 0.5]), jnp.array([0.5, 1, 2, 0.5])]

# Test mean_gamma_deviance
y_true = jnp.array([2, 0.5, 1, 4])
y_pred = jnp.array([0.5, 0.5, 2.0, 2.0])
weight = None
sk_result = metrics.mean_gamma_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_gamma_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)
for weight in weight_list:
sk_result = metrics.mean_gamma_deviance(y_true, y_pred, sample_weight=weight)
spu_result = spsim.sim_jax(sim, mean_gamma_deviance)(y_true, y_pred, weight)
np.testing.assert_allclose(sk_result, spu_result, rtol=0, atol=1e-4)


if __name__ == "__main__":
Expand Down

0 comments on commit 32912d2

Please sign in to comment.