Parameterize binarization function for tradeoff metrics (#74)
* Replace evaluation_metrics_row with refactored_evaluation_metrics_row

* Move metrics to correct file

* remove tests

* update docs

* unlint

* fix missing nothing

* no defaults

* rely on default

* add tests

* throw out const func

* helpers

* fix

* add docstring

* update docs

* go go gadget julia 1.6 slash

* docs giveth and they taketh away
hannahilea authored May 16, 2022
1 parent 4a6390a commit e4c22c4
Showing 6 changed files with 118 additions and 29 deletions.
2 changes: 1 addition & 1 deletion Project.toml
@@ -1,7 +1,7 @@
name = "Lighthouse"
uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
authors = ["Beacon Biosignals, Inc."]
version = "0.14.7"
version = "0.14.8"

[deps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
3 changes: 2 additions & 1 deletion src/Lighthouse.jl
@@ -24,7 +24,8 @@ include("metrics.jl")
export confusion_matrix, accuracy, binary_statistics, cohens_kappa, calibration_curve,
get_tradeoff_metrics, get_tradeoff_metrics_binary_multirater, get_hardened_metrics,
get_hardened_metrics_multirater, get_hardened_metrics_multiclass,
get_label_metrics_multirater, get_label_metrics_multirater_multiclass
get_label_metrics_multirater, get_label_metrics_multirater_multiclass,
harden_by_threshold

include("classifier.jl")
export AbstractClassifier
63 changes: 39 additions & 24 deletions src/metrics.jl
@@ -1,3 +1,9 @@
const BINARIZE_NOTE = string("Supply a function to the keyword argument `binarize` ",
"which takes as input `(soft_label, threshold)` and ",
"outputs a `Bool` indicating whether or not the class of interest")

binarize_by_threshold(soft, threshold) = soft >= threshold

#####
##### confusion matrices
#####
@@ -190,17 +196,18 @@ end

"""
get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
thresholds)
thresholds, binarize=binarize_by_threshold)
Return [`TradeoffMetricsRow`] calculated for the given `class_index`, with the following
fields guaranteed to be non-missing: `roc_curve`, `roc_auc`, `pr_curve`,
`reliability_calibration_curve`, `reliability_calibration_score`.
`reliability_calibration_curve`, `reliability_calibration_score`. $(BINARIZE_NOTE)
(`class_index`).
"""
function get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
thresholds)
thresholds, binarize=binarize_by_threshold)
stats = per_threshold_confusion_statistics(predicted_soft_labels,
elected_hard_labels, thresholds,
class_index)
class_index; binarize)
roc_curve = (map(t -> t.false_positive_rate, stats),
map(t -> t.true_positive_rate, stats))
pr_curve = (map(t -> t.true_positive_rate, stats),
@@ -221,16 +228,17 @@ end

"""
get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels, class_index;
thresholds)
thresholds, binarize=binarize_by_threshold)
Return [`TradeoffMetricsRow`] calculated for the given `class_index`. In addition
to the metrics calculated by [`get_tradeoff_metrics`](@ref), this also calculates
`spearman_correlation`-based metrics.
`spearman_correlation`-based metrics. $(BINARIZE_NOTE) (`class_index`).
"""
function get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels,
votes, class_index; thresholds)
votes, class_index; thresholds,
binarize=binarize_by_threshold)
basic_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
class_index; thresholds)
class_index; thresholds, binarize)
corr = _calculate_spearman_correlation(predicted_soft_labels, votes)
row = Tables.rowmerge(basic_row,
(;
@@ -392,18 +400,19 @@ Where...
Alternatively, an `observation_table` that consists of rows of type [`ObservationRow`](@ref)
can be passed in place of `predicted_soft_labels`, `predicted_hard_labels`, `elected_hard_labels`,
and `votes`.
and `votes`. $(BINARIZE_NOTE).
See also [`evaluation_metrics_plot`](@ref).
"""
function evaluation_metrics_row(observation_table, classes, thresholds=0.0:0.01:1.0;
strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing,
optimal_threshold_class::Union{Missing,Nothing,Integer}=missing)
optimal_threshold_class::Union{Missing,Nothing,Integer}=missing,
binarize=binarize_by_threshold)
inputs = _observation_table_to_inputs(observation_table)
return evaluation_metrics_row(inputs.predicted_hard_labels,
inputs.predicted_soft_labels, inputs.elected_hard_labels,
classes, thresholds; inputs.votes, strata,
optimal_threshold_class)
optimal_threshold_class, binarize)
end

function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
@@ -414,7 +423,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
strata::Union{Nothing,
AbstractVector{Set{T}} where T}=nothing,
optimal_threshold_class::Union{Missing,Nothing,
Integer}=missing)
Integer}=missing,
binarize=binarize_by_threshold)
class_labels = string.(collect(classes)) # Plots.jl expects this to be an `AbstractVector`
class_indices = 1:length(classes)

@@ -425,12 +435,12 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
map(ic -> get_tradeoff_metrics_binary_multirater(predicted_soft_labels,
elected_hard_labels, votes,
ic;
thresholds),
thresholds, binarize),
class_indices)
else
map(ic -> get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
ic;
thresholds),
thresholds, binarize),
class_indices)
end

@@ -440,7 +450,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
cal = _calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
votes;
thresholds,
class_of_interest_index=optimal_threshold_class)
class_of_interest_index=optimal_threshold_class,
binarize)
optimal_threshold = cal.threshold
elseif has_value(optimal_threshold_class)
roc_curve = tradeoff_metrics_rows[findfirst(==(optimal_threshold_class),
@@ -457,7 +468,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector,
if !ismissing(optimal_threshold)
other_class = optimal_threshold_class == 1 ? 2 : 1
for (i, row) in enumerate(eachrow(predicted_soft_labels))
predicted_hard_labels[i] = row[optimal_threshold_class] .>= optimal_threshold ?
predicted_hard_labels[i] = binarize(row[optimal_threshold_class],
optimal_threshold) ?
optimal_threshold_class : other_class
end
end
@@ -607,7 +619,7 @@ function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table,
per_expert_discrimination_calibration_scores,

# from kwargs:
optimal_threshold_class = _values_or_missing(optimal_threshold_class),
optimal_threshold_class=_values_or_missing(optimal_threshold_class),
class_labels, thresholds, optimal_threshold, stratified_kappas)
end

@@ -800,13 +812,14 @@ end

function _calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
votes; thresholds,
class_of_interest_index)
class_of_interest_index,
binarize=binarize_by_threshold)
elected_probabilities = _elected_probabilities(votes, class_of_interest_index)
bin_count = min(size(votes, 2) + 1, 10)
per_threshold_curves = map(thresholds) do thresh
pred_soft = view(predicted_soft_labels, :, class_of_interest_index)
return calibration_curve(elected_probabilities,
predicted_soft_labels[:, class_of_interest_index] .>=
thresh; bin_count=bin_count)
binarize.(pred_soft, thresh); bin_count=bin_count)
end
i_min = argmin([c.mean_squared_error for c in per_threshold_curves])
curve = per_threshold_curves[i_min]
@@ -882,24 +895,26 @@ function _validate_threshold_class(optimal_threshold_class, classes)
end

function per_class_confusion_statistics(predicted_soft_labels::AbstractMatrix,
elected_hard_labels::AbstractVector, thresholds)
elected_hard_labels::AbstractVector, thresholds;
binarize=binarize_by_threshold)
class_count = size(predicted_soft_labels, 2)
return map(1:class_count) do i
return per_threshold_confusion_statistics(predicted_soft_labels,
elected_hard_labels,
thresholds, i)
thresholds, i; binarize)
end
end

function per_threshold_confusion_statistics(predicted_soft_labels::AbstractMatrix,
elected_hard_labels::AbstractVector, thresholds,
class_index)
class_index; binarize=binarize_by_threshold)
confusions = [confusion_matrix(2) for _ in 1:length(thresholds)]
for label_index in 1:length(elected_hard_labels)
predicted_soft_label = predicted_soft_labels[label_index, class_index]
elected = (elected_hard_labels[label_index] == class_index) + 1
for (threshold_index, threshold) in enumerate(thresholds)
predicted = (predicted_soft_label >= threshold) + 1
# Convert from binarized output to 2-class labels (1, 2)
predicted = binarize(predicted_soft_label, threshold) + 1
confusions[threshold_index][predicted, elected] += 1
end
end
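As an aside for readers of this diff: a minimal usage sketch of the new `binarize` keyword, based on the `get_tradeoff_metrics` signature added above. The soft labels, thresholds, and the `strict_binarize` helper below are hypothetical illustrations, not part of this commit.

```julia
using Lighthouse

# Hypothetical per-class soft labels for four observations of a 2-class problem
predicted_soft_labels = [0.3 0.7;
                         0.9 0.1;
                         0.4 0.6;
                         0.2 0.8]
elected_hard_labels = [2, 1, 2, 2]
thresholds = 0.0:0.1:1.0

# Default hardening: `binarize_by_threshold(soft, threshold) = soft >= threshold`
row_default = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, 2;
                                   thresholds)

# Custom binarization, e.g. a strict inequality instead of `>=`
strict_binarize(soft, threshold) = soft > threshold
row_strict = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, 2;
                                  thresholds, binarize=strict_binarize)
```

The same keyword is threaded through `evaluation_metrics_row`, `per_class_confusion_statistics`, and the calibration helpers, as the hunks above show.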
12 changes: 11 additions & 1 deletion src/row.jl
@@ -159,6 +159,16 @@ const ObservationRow = Legolas.@row("lighthouse.observation@1",
elected_hard_label::Int64,
votes::Union{Missing,Vector{Int64}})

# Convert vector of per-class soft label vectors to expected matrix format, e.g.,
# [[0.1, .2, .7], [0.8, .1, .1]] for 2 observations of 3-class classification returns
# ```
# [0.1 0.2 0.7;
# 0.8 0.1 0.1]
# ```
function _predicted_soft_to_matrix(per_observation_soft_labels)
return transpose(reduce(hcat, per_observation_soft_labels))
end

function _observation_table_to_inputs(observation_table)
Legolas.validate(observation_table, OBSERVATION_ROW_SCHEMA)
df_table = Tables.columns(observation_table)
@@ -169,7 +179,7 @@ function _observation_table_to_inputs(observation_table)
votes = any(ismissing, df_table.votes) ? missing :
transpose(reduce(hcat, df_table.votes))

predicted_soft_labels = transpose(reduce(hcat, df_table.predicted_soft_labels))
predicted_soft_labels = _predicted_soft_to_matrix(df_table.predicted_soft_labels)
return (; predicted_hard_labels=df_table.predicted_hard_label, predicted_soft_labels,
elected_hard_labels=df_table.elected_hard_label, votes)
end
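For reference, a quick sketch of what the new (internal, unexported) helper computes; the values are illustrative only:

```julia
# Two observations of a 3-class problem, one soft-label vector per observation
per_observation_soft_labels = [[0.1, 0.2, 0.7], [0.8, 0.1, 0.1]]

# The same operation `_predicted_soft_to_matrix` performs:
soft_label_matrix = transpose(reduce(hcat, per_observation_soft_labels))
# 2×3 matrix with one row per observation:
#  0.1  0.2  0.7
#  0.8  0.1  0.1
```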
63 changes: 63 additions & 0 deletions test/metrics.jl
@@ -156,3 +156,66 @@ end
@test all(iszero, totals)
@test isnan(mean_squared_error)
end

@testset "`binarize_by_threshold`" begin
@test binarize_by_threshold(0.2, 0.8) == false
@test binarize_by_threshold(0.2, 0.2) == true
@test binarize_by_threshold(0.3, 0.2) == true
@test binarize_by_threshold.([0, 0, 0], 0.2) == [0, 0, 0]
end

@testset "Metrics hardening/binarization" begin
predicted_soft_labels = [0.51 0.49
0.49 0.51
0.1 0.9
0.9 0.1
0.0 1.0]
elected_hard_labels = [1, 2, 2, 2, 1]
thresholds = [0.25, 0.5, 0.75]
i_class = 2
default_metrics = get_tradeoff_metrics(predicted_soft_labels,
elected_hard_labels,
i_class; thresholds)

# Use bogus threshold/hardening function to prove that hardening function is
# used internally
scaled_binarize_by_threshold = (soft, threshold) -> soft >= threshold / 10
scaled_thresholds = 10 .* thresholds
scaled_metrics = get_tradeoff_metrics(predicted_soft_labels,
elected_hard_labels,
i_class; thresholds=scaled_thresholds,
binarize=scaled_binarize_by_threshold)
@test isequal(default_metrics, scaled_metrics)

# Discrim calibration
votes = [1 1 1
0 2 2
1 2 2
1 1 2
0 1 1]
cal = Lighthouse._calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
votes;
thresholds,
class_of_interest_index=i_class)

scaled_cal = Lighthouse._calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
votes;
class_of_interest_index=i_class,
thresholds=scaled_thresholds,
binarize=scaled_binarize_by_threshold)
for k in keys(cal)
if k == :threshold
@test cal[k] * 10 == scaled_cal[k] # Should be the same _relative_ threshold
else
@test isequal(cal[k], scaled_cal[k])
end
end

conf = Lighthouse.per_class_confusion_statistics(predicted_soft_labels,
elected_hard_labels, thresholds)
scaled_conf = Lighthouse.per_class_confusion_statistics(predicted_soft_labels,
elected_hard_labels,
scaled_thresholds;
binarize=scaled_binarize_by_threshold)
@test isequal(conf, scaled_conf)
end
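To tie the pieces together, here is a hedged end-to-end sketch of passing a custom `binarize` through the top-level row builder. The labels and class names are made up, and the positional argument order is inferred from the `evaluation_metrics_row` hunks above, so treat this as illustrative rather than canonical.

```julia
using Lighthouse

# Hypothetical 2-class inputs (five observations)
predicted_soft_labels = [0.51 0.49
                         0.49 0.51
                         0.1  0.9
                         0.9  0.1
                         0.0  1.0]
predicted_hard_labels = [1, 2, 2, 1, 2]
elected_hard_labels = [1, 2, 2, 2, 1]
classes = ["not-of-interest", "of-interest"]
thresholds = 0.0:0.01:1.0

row = Lighthouse.evaluation_metrics_row(predicted_hard_labels, predicted_soft_labels,
                                        elected_hard_labels, classes, thresholds;
                                        optimal_threshold_class=2,
                                        binarize=Lighthouse.binarize_by_threshold)
```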
4 changes: 2 additions & 2 deletions test/runtests.jl
@@ -3,7 +3,7 @@ using StableRNGs
using Lighthouse
using Lighthouse: plot_reliability_calibration_curves, plot_pr_curves,
plot_roc_curves, plot_kappas, plot_confusion_matrix,
evaluation_metrics_plot, evaluation_metrics
evaluation_metrics_plot, evaluation_metrics, binarize_by_threshold
using Base.Threads
using CairoMakie
using Legolas, Tables
@@ -13,7 +13,7 @@ using Arrow
# Needs to be set for figures
# returning true for showable("image/png", obj)
# which TensorBoardLogger.jl uses to determine output
CairoMakie.activate!(type="png")
CairoMakie.activate!(; type="png")
plot_results = joinpath(@__DIR__, "plot_results")
# Remove any old plots
isdir(plot_results) && rm(plot_results; force=true, recursive=true)

2 comments on commit e4c22c4

@hannahilea (Contributor, Author)

@JuliaRegistrator

Registration pull request created: JuliaRegistries/General/60370

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.14.8 -m "<description of version>" e4c22c42eb991ff895ad3e58fd2e0096b45627fd
git push origin v0.14.8
