Skip to content

Commit

Permalink
Adding top-class transformation option
Browse files Browse the repository at this point in the history
  • Loading branch information
jasonfan1997 committed Oct 5, 2024
1 parent 57f64c0 commit 0d4d8bd
Show file tree
Hide file tree
Showing 62 changed files with 1,317 additions and 325 deletions.
3 changes: 3 additions & 0 deletions GUI_cal_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,8 @@ def run_program():
args.append(str(save_plot))
if verbose_checkbox.value:
args.append("--verbose")
if topclass_checkbox.value:
args.append("--topclass")

command = ["python", "cal_metrics.py"] + args
print("Running command:", " ".join(command))
Expand Down Expand Up @@ -109,6 +111,7 @@ def clear_cache():
plot_bins_input = ui.number(label='Number of Bins for Reliability Diagram', value=10, min=2, step=1)
save_plot_input = ui.input(label='Save Plot to', placeholder='Enter file path').classes('w-full')
verbose_checkbox = ui.checkbox('Print Verbose Output', value=True)
topclass_checkbox = ui.checkbox('Transform to Top-class Problem', value=False)
ui.button('Run', on_click=run_program).classes('w-full')
ui.button('Clear Browser Cache', on_click=clear_cache).classes('w-full')

Expand Down
6 changes: 6 additions & 0 deletions cal_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,8 @@ def main():
help="Class to calculate metrics for (default: 1)")
parser.add_argument("--num_bins", type=int, default=10,
help="Number of bins for ECE/MCE/HL calculations (default: 10)")
parser.add_argument("--topclass", default=False, action="store_true",
help="Whether to transform the problem to top-class problem.")
parser.add_argument("--save_metrics", type=str,
help="Save the metrics to a csv file")
parser.add_argument("--plot", default=False, action="store_true",
Expand All @@ -184,6 +186,10 @@ def main():
# Load data from CSV
loader = data_loader(args.csv_file)

# if transofrm it to top-class problem
if args.topclass:
loader = loader.transform_topclass()

# Perform calculations
if not loader.have_subgroup:
perform_calculation(probs=loader.probs, labels=loader.labels, args=args, suffix="")
Expand Down
16 changes: 16 additions & 0 deletions calzone/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import matplotlib.pyplot as plt
from scipy.optimize import minimize_scalar
from scipy.special import softmax
import copy

def make_roc_curve(y_true, y_proba, class_to_plot=None):
"""
Expand Down Expand Up @@ -276,6 +277,7 @@ class data_loader():
Methods:
__init__(self, data_path): Initializes the data_loader object and loads data from a CSV file.
transform_topclass(self): Transforms the data to top class binary problem.
"""

def __init__(self, data_path):
Expand Down Expand Up @@ -325,6 +327,20 @@ def __init__(self, data_path):
for j, subgroup_class in enumerate(self.subgroups_class[i]):
indices.append(np.where(self.data[:, self.subgroup_indices[i]] == subgroup_class)[0])
self.subgroups_index.append(indices)

def transform_topclass(self):
"""
Transforms the data to top class binary problem
Returns:
data_loader: A new data_loader object with transformed data
"""
new_loader = copy.deepcopy(self)
top_class = np.argmax(self.probs, axis=1)
new_loader.probs = np.column_stack((1 - np.max(self.probs, axis=1), np.max(self.probs, axis=1)))
new_loader.labels = (self.labels.flatten() == top_class).astype(int).reshape(-1, 1)
new_loader.data = np.column_stack((new_loader.probs, new_loader.labels))
return new_loader

class fake_binary_data_generator():
"""A class for generating fake binary data and applying miscalibration.
Expand Down
Binary file modified docs/build/doctrees/calzone.doctree
Binary file not shown.
Binary file modified docs/build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/build/doctrees/index.doctree
Binary file not shown.
2 changes: 1 addition & 1 deletion docs/build/doctrees/nbsphinx/notebooks/GUI.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To run GUI, you need to install nicegui using the following command:"
"GUI is currently under development and not all features from the command line are supported. To run GUI, you need to install nicegui using the following command:"
]
},
{
Expand Down
8 changes: 5 additions & 3 deletions docs/build/doctrees/nbsphinx/notebooks/quickstart.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -183,8 +183,9 @@
" [--prevalence_adjustment] [--n_bootstrap N_BOOTSTRAP]\n",
" [--bootstrap_ci BOOTSTRAP_CI]\n",
" [--class_to_calculate CLASS_TO_CALCULATE]\n",
" [--num_bins NUM_BINS] [--save_metrics SAVE_METRICS]\n",
" [--plot] [--plot_bins PLOT_BINS] [--save_plot SAVE_PLOT]\n",
" [--num_bins NUM_BINS] [--topclass]\n",
" [--save_metrics SAVE_METRICS] [--plot]\n",
" [--plot_bins PLOT_BINS] [--save_plot SAVE_PLOT]\n",
" [--save_diagram_output SAVE_DIAGRAM_OUTPUT] [--verbose]\n",
"\n",
"Calculate calibration metrics and visualize reliability diagram.\n",
Expand All @@ -209,6 +210,7 @@
" Class to calculate metrics for (default: 1)\n",
" --num_bins NUM_BINS Number of bins for ECE/MCE/HL calculations (default:\n",
" 10)\n",
" --topclass Whether to transform the problem to top-class problem.\n",
" --save_metrics SAVE_METRICS\n",
" Save the metrics to a csv file\n",
" --plot Plot reliability diagram (default: False)\n",
Expand Down
121 changes: 121 additions & 0 deletions docs/build/doctrees/nbsphinx/notebooks/topclass.ipynb

Large diffs are not rendered by default.

Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified docs/build/doctrees/notebooks/GUI.doctree
Binary file not shown.
Binary file modified docs/build/doctrees/notebooks/quickstart.doctree
Binary file not shown.
Binary file added docs/build/doctrees/notebooks/topclass.doctree
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
1 change: 1 addition & 0 deletions docs/build/html/_modules/calzone/metrics.html
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/spiegelhalter_z.html">Spiegelhalter’s Z-test</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/prevalence_adjustment.html">Prevalence adjustment</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/subgroup.html">Subgroup analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/topclass.html">Multiclass extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/GUI.html">Running GUI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules.html">calzone</a></li>
</ul>
Expand Down
20 changes: 20 additions & 0 deletions docs/build/html/_modules/calzone/utils.html
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/spiegelhalter_z.html">Spiegelhalter’s Z-test</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/prevalence_adjustment.html">Prevalence adjustment</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/subgroup.html">Subgroup analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/topclass.html">Multiclass extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/GUI.html">Running GUI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules.html">calzone</a></li>
</ul>
Expand Down Expand Up @@ -100,6 +101,7 @@ <h1>Source code for calzone.utils</h1><div class="highlight"><pre>
<span class="kn">import</span> <span class="nn">matplotlib.pyplot</span> <span class="k">as</span> <span class="nn">plt</span>
<span class="kn">from</span> <span class="nn">scipy.optimize</span> <span class="kn">import</span> <span class="n">minimize_scalar</span>
<span class="kn">from</span> <span class="nn">scipy.special</span> <span class="kn">import</span> <span class="n">softmax</span>
<span class="kn">import</span> <span class="nn">copy</span>

<div class="viewcode-block" id="make_roc_curve">
<a class="viewcode-back" href="../../calzone.html#calzone.utils.make_roc_curve">[docs]</a>
Expand Down Expand Up @@ -385,6 +387,7 @@ <h1>Source code for calzone.utils</h1><div class="highlight"><pre>

<span class="sd"> Methods:</span>
<span class="sd"> __init__(self, data_path): Initializes the data_loader object and loads data from a CSV file.</span>
<span class="sd"> transform_topclass(self): Transforms the data to top class binary problem.</span>
<span class="sd"> &quot;&quot;&quot;</span>

<div class="viewcode-block" id="data_loader.__init__">
Expand Down Expand Up @@ -436,6 +439,23 @@ <h1>Source code for calzone.utils</h1><div class="highlight"><pre>
<span class="k">for</span> <span class="n">j</span><span class="p">,</span> <span class="n">subgroup_class</span> <span class="ow">in</span> <span class="nb">enumerate</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">subgroups_class</span><span class="p">[</span><span class="n">i</span><span class="p">]):</span>
<span class="n">indices</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">np</span><span class="o">.</span><span class="n">where</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">data</span><span class="p">[:,</span> <span class="bp">self</span><span class="o">.</span><span class="n">subgroup_indices</span><span class="p">[</span><span class="n">i</span><span class="p">]]</span> <span class="o">==</span> <span class="n">subgroup_class</span><span class="p">)[</span><span class="mi">0</span><span class="p">])</span>
<span class="bp">self</span><span class="o">.</span><span class="n">subgroups_index</span><span class="o">.</span><span class="n">append</span><span class="p">(</span><span class="n">indices</span><span class="p">)</span></div>


<div class="viewcode-block" id="data_loader.transform_topclass">
<a class="viewcode-back" href="../../calzone.html#calzone.utils.data_loader.transform_topclass">[docs]</a>
<span class="k">def</span> <span class="nf">transform_topclass</span><span class="p">(</span><span class="bp">self</span><span class="p">):</span>
<span class="w"> </span><span class="sd">&quot;&quot;&quot;</span>
<span class="sd"> Transforms the data to top class binary problem</span>

<span class="sd"> Returns:</span>
<span class="sd"> data_loader: A new data_loader object with transformed data</span>
<span class="sd"> &quot;&quot;&quot;</span>
<span class="n">new_loader</span> <span class="o">=</span> <span class="n">copy</span><span class="o">.</span><span class="n">deepcopy</span><span class="p">(</span><span class="bp">self</span><span class="p">)</span>
<span class="n">top_class</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">argmax</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">probs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)</span>
<span class="n">new_loader</span><span class="o">.</span><span class="n">probs</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">column_stack</span><span class="p">((</span><span class="mi">1</span> <span class="o">-</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">probs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">),</span> <span class="n">np</span><span class="o">.</span><span class="n">max</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">probs</span><span class="p">,</span> <span class="n">axis</span><span class="o">=</span><span class="mi">1</span><span class="p">)))</span>
<span class="n">new_loader</span><span class="o">.</span><span class="n">labels</span> <span class="o">=</span> <span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">labels</span><span class="o">.</span><span class="n">flatten</span><span class="p">()</span> <span class="o">==</span> <span class="n">top_class</span><span class="p">)</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="nb">int</span><span class="p">)</span><span class="o">.</span><span class="n">reshape</span><span class="p">(</span><span class="o">-</span><span class="mi">1</span><span class="p">,</span> <span class="mi">1</span><span class="p">)</span>
<span class="n">new_loader</span><span class="o">.</span><span class="n">data</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">column_stack</span><span class="p">((</span><span class="n">new_loader</span><span class="o">.</span><span class="n">probs</span><span class="p">,</span> <span class="n">new_loader</span><span class="o">.</span><span class="n">labels</span><span class="p">))</span>
<span class="k">return</span> <span class="n">new_loader</span></div>
</div>


Expand Down
1 change: 1 addition & 0 deletions docs/build/html/_modules/calzone/vis.html
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/spiegelhalter_z.html">Spiegelhalter’s Z-test</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/prevalence_adjustment.html">Prevalence adjustment</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/subgroup.html">Subgroup analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/topclass.html">Multiclass extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../notebooks/GUI.html">Running GUI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../../modules.html">calzone</a></li>
</ul>
Expand Down
1 change: 1 addition & 0 deletions docs/build/html/_modules/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
<li class="toctree-l1"><a class="reference internal" href="../notebooks/spiegelhalter_z.html">Spiegelhalter’s Z-test</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notebooks/prevalence_adjustment.html">Prevalence adjustment</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notebooks/subgroup.html">Subgroup analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notebooks/topclass.html">Multiclass extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="../notebooks/GUI.html">Running GUI</a></li>
<li class="toctree-l1"><a class="reference internal" href="../modules.html">calzone</a></li>
</ul>
Expand Down
2 changes: 2 additions & 0 deletions docs/build/html/_sources/index.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ Key features of calzone include:
* Bootstrapping capabilities for confidence interval estimation
* Subgroup analysis for calibration metrics
* Provides command line interface scripts for batch processing
* Multi-class extension by 1-vs-rest or top-class only

To accurately assess the calibration of machine learning models, it is essential to have a comprehensive and reprensative dataset with sufficient coverage of the prediction space. The calibration metrics is not meaningful if the dataset is not representative of true intended population.

Expand All @@ -39,5 +40,6 @@ We hope you find calzone useful in your machine learning projects!
notebooks/spiegelhalter_z.ipynb
notebooks/prevalence_adjustment.ipynb
notebooks/subgroup.ipynb
notebooks/topclass.ipynb
notebooks/GUI.ipynb
modules
2 changes: 1 addition & 1 deletion docs/build/html/_sources/notebooks/GUI.ipynb.txt
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"To run GUI, you need to install nicegui using the following command:"
"GUI is currently under development and not all features from the command line are supported. To run GUI, you need to install nicegui using the following command:"
]
},
{
Expand Down
8 changes: 5 additions & 3 deletions docs/build/html/_sources/notebooks/quickstart.ipynb.txt
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@
},
{
"cell_type": "code",
"execution_count": 7,
"execution_count": 1,
"metadata": {},
"outputs": [
{
Expand All @@ -183,8 +183,9 @@
" [--prevalence_adjustment] [--n_bootstrap N_BOOTSTRAP]\n",
" [--bootstrap_ci BOOTSTRAP_CI]\n",
" [--class_to_calculate CLASS_TO_CALCULATE]\n",
" [--num_bins NUM_BINS] [--save_metrics SAVE_METRICS]\n",
" [--plot] [--plot_bins PLOT_BINS] [--save_plot SAVE_PLOT]\n",
" [--num_bins NUM_BINS] [--topclass]\n",
" [--save_metrics SAVE_METRICS] [--plot]\n",
" [--plot_bins PLOT_BINS] [--save_plot SAVE_PLOT]\n",
" [--save_diagram_output SAVE_DIAGRAM_OUTPUT] [--verbose]\n",
"\n",
"Calculate calibration metrics and visualize reliability diagram.\n",
Expand All @@ -209,6 +210,7 @@
" Class to calculate metrics for (default: 1)\n",
" --num_bins NUM_BINS Number of bins for ECE/MCE/HL calculations (default:\n",
" 10)\n",
" --topclass Whether to transform the problem to top-class problem.\n",
" --save_metrics SAVE_METRICS\n",
" Save the metrics to a csv file\n",
" --plot Plot reliability diagram (default: False)\n",
Expand Down
121 changes: 121 additions & 0 deletions docs/build/html/_sources/notebooks/topclass.ipynb.txt

Large diffs are not rendered by default.

18 changes: 18 additions & 0 deletions docs/build/html/calzone.html
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
<li class="toctree-l1"><a class="reference internal" href="notebooks/spiegelhalter_z.html">Spiegelhalter’s Z-test</a></li>
<li class="toctree-l1"><a class="reference internal" href="notebooks/prevalence_adjustment.html">Prevalence adjustment</a></li>
<li class="toctree-l1"><a class="reference internal" href="notebooks/subgroup.html">Subgroup analysis</a></li>
<li class="toctree-l1"><a class="reference internal" href="notebooks/topclass.html">Multiclass extension</a></li>
<li class="toctree-l1"><a class="reference internal" href="notebooks/GUI.html">Running GUI</a></li>
<li class="toctree-l1 current"><a class="reference internal" href="modules.html">calzone</a><ul class="current">
<li class="toctree-l2 current"><a class="current reference internal" href="#">calzone package</a><ul>
Expand Down Expand Up @@ -652,6 +653,12 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Link to this headi
<dd><p>Initializes the data_loader object and loads data from a CSV file.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="calzone.utils.data_loader.transform_topclass">
<span class="sig-name descname"><span class="pre">transform_topclass</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">self</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/calzone/utils.html#data_loader.transform_topclass"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#calzone.utils.data_loader.transform_topclass" title="Link to this definition"></a></dt>
<dd><p>Transforms the data to top class binary problem.</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="id0">
<span class="sig-name descname"><span class="pre">__init__</span></span><span class="sig-paren">(</span><em class="sig-param"><span class="n"><span class="pre">data_path</span></span></em><span class="sig-paren">)</span><a class="reference internal" href="_modules/calzone/utils.html#data_loader.__init__"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#id0" title="Link to this definition"></a></dt>
Expand All @@ -672,6 +679,17 @@ <h2>Submodules<a class="headerlink" href="#submodules" title="Link to this headi
- If there is no header, the columns must be in the order of proba_0,proba_1,…,label</p>
</dd></dl>

<dl class="py method">
<dt class="sig sig-object py" id="id1">
<span class="sig-name descname"><span class="pre">transform_topclass</span></span><span class="sig-paren">(</span><span class="sig-paren">)</span><a class="reference internal" href="_modules/calzone/utils.html#data_loader.transform_topclass"><span class="viewcode-link"><span class="pre">[source]</span></span></a><a class="headerlink" href="#id1" title="Link to this definition"></a></dt>
<dd><p>Transforms the data to top class binary problem</p>
<dl class="field-list simple">
<dt class="field-odd">Returns<span class="colon">:</span></dt>
<dd class="field-odd"><p><strong>data_loader</strong> – A new data_loader object with transformed data</p>
</dd>
</dl>
</dd></dl>

</dd></dl>

<dl class="py class">
Expand Down
Loading

0 comments on commit 0d4d8bd

Please sign in to comment.