-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add test to make sure the main metrics class works
- Loading branch information
1 parent
10fe8e2
commit b61744b
Showing
2 changed files
with
45 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
from test_metrics import run_test_metrics | ||
|
||
def main(): | ||
""" Runs generation and analysis steps back-to-back. """ | ||
run_test_metrics() | ||
|
||
if __name__ == "__main__": | ||
main() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
""" | ||
This module contains test functions for calibration metrics. | ||
The main test function generates synthetic data and tests the CalibrationMetrics class. | ||
""" | ||
|
||
import numpy as np | ||
from calzone.metrics import CalibrationMetrics | ||
|
||
def run_test_metrics(): | ||
""" | ||
Test function for calibration metrics. | ||
Generates synthetic binary classification data and tests the CalibrationMetrics class. | ||
The test includes: | ||
- Generating random binary labels using binomial distribution | ||
- Creating probability predictions using beta distribution | ||
- Testing the CalibrationMetrics class with the generated data | ||
Notice that the test is not exhaustive and doesn't test for correctness of the metrics. | ||
Returns: | ||
None | ||
Raises: | ||
AssertionError: If the results from CalibrationMetrics are not in dictionary format | ||
""" | ||
# Generate sample data | ||
np.random.seed(123) | ||
n_samples = 1000 | ||
y_true = np.random.binomial(1, 0.3, n_samples) | ||
y_proba = np.zeros((n_samples,2)) | ||
y_proba[:, 1] = np.random.beta(0.5, 0.5, (n_samples, )) | ||
y_proba[:, 0] = 1 - y_proba[:, 1] # Ensure probabilities sum to 1 | ||
|
||
# Test CalibrationMetrics class | ||
metrics = CalibrationMetrics(class_to_calculate=1) | ||
results = metrics.calculate_metrics(y_true, y_proba, metrics='all') | ||
assert isinstance(results, dict) |