ENH add scoring.decompose (#28)
* ENH add scoring.decompose
* DOC add score decomposition to example
lorentzenchr authored Feb 26, 2023
1 parent 71b4abf commit d54a1b8
Showing 6 changed files with 517 additions and 12 deletions.
190 changes: 187 additions & 3 deletions docs/examples/regression_on_workers_compensation.ipynb
@@ -1273,13 +1273,197 @@
"ax.set_title(ax.get_title() + \" GLM_Poisson\")"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "26e23556",
"metadata": {},
"source": [
"## 5. Predictive Model Performance\n",
"We use the Gamma deviance and evaluate on the test set.\n",
"In addition, we additively decompose the score into miscalibration (smaller is better), discrimination (larger is better) and uncertainty (a property of the data, the same for all models)."
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 26,
"id": "acd89e59-0f24-4ee4-8582-c19f5ebbc56a",
"metadata": {},
"outputs": [],
"source": []
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style>\n",
".pl-dataframe > thead > tr > th {\n",
" text-align: right;\n",
"}\n",
"</style>\n",
"\n",
"<table border=\"1\" class=\"dataframe pl-dataframe\">\n",
"<small>shape: (4, 5)</small>\n",
"<thead>\n",
"<tr>\n",
"<th>\n",
"model\n",
"</th>\n",
"<th>\n",
"miscalibration\n",
"</th>\n",
"<th>\n",
"discrimination\n",
"</th>\n",
"<th>\n",
"uncertainty\n",
"</th>\n",
"<th>\n",
"score\n",
"</th>\n",
"</tr>\n",
"<tr>\n",
"<td>\n",
"str\n",
"</td>\n",
"<td>\n",
"f64\n",
"</td>\n",
"<td>\n",
"f64\n",
"</td>\n",
"<td>\n",
"f64\n",
"</td>\n",
"<td>\n",
"f64\n",
"</td>\n",
"</tr>\n",
"</thead>\n",
"<tbody>\n",
"<tr>\n",
"<td>\n",
"&quot;Trivial&quot;\n",
"</td>\n",
"<td>\n",
"0.000901\n",
"</td>\n",
"<td>\n",
"8.8818e-16\n",
"</td>\n",
"<td>\n",
"5.086002\n",
"</td>\n",
"<td>\n",
"5.086903\n",
"</td>\n",
"</tr>\n",
"<tr>\n",
"<td>\n",
"&quot;OLS&quot;\n",
"</td>\n",
"<td>\n",
"0.271994\n",
"</td>\n",
"<td>\n",
"1.605604\n",
"</td>\n",
"<td>\n",
"5.086002\n",
"</td>\n",
"<td>\n",
"3.752392\n",
"</td>\n",
"</tr>\n",
"<tr>\n",
"<td>\n",
"&quot;GLM_Poisson&quot;\n",
"</td>\n",
"<td>\n",
"0.470965\n",
"</td>\n",
"<td>\n",
"1.624191\n",
"</td>\n",
"<td>\n",
"5.086002\n",
"</td>\n",
"<td>\n",
"3.932776\n",
"</td>\n",
"</tr>\n",
"<tr>\n",
"<td>\n",
"&quot;HGBT_Poisson&quot;\n",
"</td>\n",
"<td>\n",
"0.088114\n",
"</td>\n",
"<td>\n",
"1.696757\n",
"</td>\n",
"<td>\n",
"5.086002\n",
"</td>\n",
"<td>\n",
"3.477359\n",
"</td>\n",
"</tr>\n",
"</tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
"shape: (4, 5)\n",
"┌──────────────┬────────────────┬────────────────┬─────────────┬──────────┐\n",
"│ model        ┆ miscalibration ┆ discrimination ┆ uncertainty ┆ score    │\n",
"│ ---          ┆ ---            ┆ ---            ┆ ---         ┆ ---      │\n",
"│ str          ┆ f64            ┆ f64            ┆ f64         ┆ f64      │\n",
"╞══════════════╪════════════════╪════════════════╪═════════════╪══════════╡\n",
"│ Trivial      ┆ 0.000901       ┆ 8.8818e-16     ┆ 5.086002    ┆ 5.086903 │\n",
"│ OLS          ┆ 0.271994       ┆ 1.605604       ┆ 5.086002    ┆ 3.752392 │\n",
"│ GLM_Poisson  ┆ 0.470965       ┆ 1.624191       ┆ 5.086002    ┆ 3.932776 │\n",
"│ HGBT_Poisson ┆ 0.088114       ┆ 1.696757       ┆ 5.086002    ┆ 3.477359 │\n",
"└──────────────┴────────────────┴────────────────┴─────────────┴──────────┘"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import polars as pl\n",
"from model_diagnostics.scoring import GammaDeviance, decompose\n",
"\n",
"\n",
"gamma_deviance = GammaDeviance()\n",
"\n",
"df_list = []\n",
"for model in [\"Trivial\", \"OLS\", \"GLM_Poisson\", \"HGBT_Poisson\"]:\n",
"    df = decompose(\n",
"        scoring_function=gamma_deviance,\n",
"        y_obs=y_test,\n",
"        y_pred=df_pred_test[model],\n",
"    )\n",
"    df = df.with_columns(pl.lit(model).alias(\"model\"))\n",
"    df_list.append(df)\n",
"df = pl.concat(df_list).select(\n",
"    [\"model\", \"miscalibration\", \"discrimination\", \"uncertainty\", \"score\"]\n",
")\n",
"df"
]
},
{
"attachments": {},
"cell_type": "markdown",
"id": "8a34fbe9",
"metadata": {},
"source": [
"The HGBT model has the best (smallest) overall Gamma deviance (score).\n",
"The uncertainty term is identical for all models because it depends only on the data and the scoring function.\n",
"The differences in score stem mainly from the miscalibration term.\n",
"Apart from the trivial model, the HGBT model has the smallest (best) miscalibration.\n",
"The discrimination term (larger is better) is quite similar among the non-trivial models."
]
}
],
"metadata": {
10 changes: 9 additions & 1 deletion docs/index.md
@@ -9,7 +9,15 @@

**Tools for diagnostics and assessment of (machine learning) models**

Read more in the [documentation](https://lorentzenchr.github.io/model-diagnostics/).
Highlights:

- Assess model calibration with identification functions (generalized residuals).
- Assess calibration graphically
    - reliability diagrams for auto-calibration
    - bias plots for conditional calibration
- Assess the predictive performance of models
    - strictly consistent, homogeneous scoring functions
    - score decomposition into miscalibration, discrimination and uncertainty
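
The score decomposition named in the last highlight can be illustrated with a small standard-library sketch. It assumes the common construction in which the predictions are recalibrated by isotonic regression and the three terms are differences of mean scores. The implementation of `decompose` is not visible on this page, so the helper names below (`squared_error`, `isotonic_recalibration`, `decompose_sketch`) are hypothetical, and squared error stands in for the Gamma deviance to keep the example short.

```python
from statistics import fmean


def squared_error(y_obs, y_pred):
    """Mean squared error, used here as the scoring function."""
    return fmean((y - p) ** 2 for y, p in zip(y_obs, y_pred))


def isotonic_recalibration(y_obs, y_pred):
    """Fit y_obs as a non-decreasing function of y_pred (pool adjacent violators)."""
    order = sorted(range(len(y_pred)), key=lambda i: y_pred[i])
    # Group observations that share the same prediction: [sum of y, count, indices].
    groups = []
    for i in order:
        if groups and y_pred[groups[-1][2][-1]] == y_pred[i]:
            groups[-1][0] += y_obs[i]
            groups[-1][1] += 1
            groups[-1][2].append(i)
        else:
            groups.append([y_obs[i], 1, [i]])
    # Merge neighbouring blocks while their means violate monotonicity.
    stack = []
    for total, count, idx in groups:
        stack.append([total, count, idx])
        while len(stack) > 1 and (
            stack[-2][0] / stack[-2][1] > stack[-1][0] / stack[-1][1]
        ):
            t, c, ix = stack.pop()
            stack[-1][0] += t
            stack[-1][1] += c
            stack[-1][2] = stack[-1][2] + ix
    fitted = [0.0] * len(y_obs)
    for total, count, idx in stack:
        for i in idx:
            fitted[i] = total / count
    return fitted


def decompose_sketch(y_obs, y_pred, score=squared_error):
    """Hypothetical sketch: score = miscalibration - discrimination + uncertainty."""
    recalibrated = isotonic_recalibration(y_obs, y_pred)
    marginal = [fmean(y_obs)] * len(y_obs)
    s = score(y_obs, y_pred)
    s_recal = score(y_obs, recalibrated)
    s_marg = score(y_obs, marginal)
    return {
        "miscalibration": s - s_recal,
        "discrimination": s_marg - s_recal,
        "uncertainty": s_marg,
        "score": s,
    }


# Toy data, made up purely for illustration.
y_obs = [1.0, 2.0, 4.0, 3.0, 6.0]
y_pred = [1.5, 1.0, 3.0, 5.0, 4.0]
terms = decompose_sketch(y_obs, y_pred)
print(terms)
```

Because the isotonic fit is the best non-decreasing transformation of the predictions in-sample, both the miscalibration and the discrimination term are non-negative in this sketch, and the three terms recombine to the original score by construction.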

**Contributions**

13 changes: 12 additions & 1 deletion src/model_diagnostics/_utils/array.py
@@ -62,7 +62,18 @@ def get_second_dimension(a: npt.ArrayLike, i: int) -> npt.ArrayLike:
def validate_2_arrays(
a: npt.ArrayLike, b: npt.ArrayLike
) -> tuple[np.ndarray, np.ndarray]:
"""Validate 2 arrays."""
"""Validate 2 arrays.
Both arrays are checked to have the same dimensions and shapes and are returned as
numpy arrays.
Returns
-------
a : ndarray
Input as an ndarray
b : ndarray
Input as an ndarray
"""
# Note: If the input is a pyarrow array, np.asarray produces a read-only ndarray.
a = np.asarray(a)
b = np.asarray(b)
2 changes: 2 additions & 0 deletions src/model_diagnostics/scoring/__init__.py
@@ -6,6 +6,7 @@
PinballLoss,
PoissonDeviance,
SquaredError,
decompose,
)

__all__ = [
@@ -16,4 +17,5 @@
"PinballLoss",
"PoissonDeviance",
"SquaredError",
"decompose",
]
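
As a sanity check on the notebook output in this commit, the additive identity behind the table (score = miscalibration - discrimination + uncertainty) can be verified directly from the printed numbers. The snippet below is a small standalone check, not part of the commit; the values are copied from the table and therefore carry rounding to six decimals, hence the loose tolerance.

```python
# Values copied from the notebook output table:
# (miscalibration, discrimination, uncertainty, score)
rows = {
    "Trivial": (0.000901, 8.8818e-16, 5.086002, 5.086903),
    "OLS": (0.271994, 1.605604, 5.086002, 3.752392),
    "GLM_Poisson": (0.470965, 1.624191, 5.086002, 3.932776),
    "HGBT_Poisson": (0.088114, 1.696757, 5.086002, 3.477359),
}

max_abs_error = 0.0
for model, (mcb, dsc, unc, score) in rows.items():
    reconstructed = mcb - dsc + unc
    max_abs_error = max(max_abs_error, abs(reconstructed - score))
    print(f"{model:12s} reconstructed={reconstructed:.6f} reported={score:.6f}")

# The model with the smallest Gamma deviance on the test set.
best_model = min(rows, key=lambda name: rows[name][3])
print("best:", best_model)
```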