diff --git a/Project.toml b/Project.toml
index e6c8b17..61c0464 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,7 +1,7 @@
 name = "Lighthouse"
 uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59"
 authors = ["Beacon Biosignals, Inc."]
-version = "0.14.7"
+version = "0.14.8"
 
 [deps]
 ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
diff --git a/src/Lighthouse.jl b/src/Lighthouse.jl
index 25390d1..db5dcdd 100644
--- a/src/Lighthouse.jl
+++ b/src/Lighthouse.jl
@@ -24,7 +24,8 @@ include("metrics.jl")
 export confusion_matrix, accuracy, binary_statistics, cohens_kappa, calibration_curve,
        get_tradeoff_metrics, get_tradeoff_metrics_binary_multirater, get_hardened_metrics,
        get_hardened_metrics_multirater, get_hardened_metrics_multiclass,
-       get_label_metrics_multirater, get_label_metrics_multirater_multiclass
+       get_label_metrics_multirater, get_label_metrics_multirater_multiclass,
+       harden_by_threshold
 
 include("classifier.jl")
 export AbstractClassifier
diff --git a/src/metrics.jl b/src/metrics.jl
index 6e71f67..63e1c20 100644
--- a/src/metrics.jl
+++ b/src/metrics.jl
@@ -1,3 +1,9 @@
+const BINARIZE_NOTE = string("Supply a function to the keyword argument `binarize` ",
+                             "which takes as input `(soft_label, threshold)` and ",
+                             "outputs a `Bool` indicating whether or not the prediction should be considered positive for the class of interest")
+
+binarize_by_threshold(soft, threshold) = soft >= threshold
+
 #####
 ##### confusion matrices
 #####
@@ -190,17 +196,18 @@ end
 
 """
     get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
-                         thresholds)
+                         thresholds, binarize=binarize_by_threshold)
 
 Return [`TradeoffMetricsRow`] calculated for the given `class_index`, with the
 following fields guaranteed to be non-missing: `roc_curve`, `roc_auc`, pr_curve`,
-`reliability_calibration_curve`, `reliability_calibration_score`.`
+`reliability_calibration_curve`, `reliability_calibration_score`. $(BINARIZE_NOTE)
+(`class_index`).
 """
 function get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index;
-                              thresholds)
+                              thresholds, binarize=binarize_by_threshold)
     stats = per_threshold_confusion_statistics(predicted_soft_labels,
                                                elected_hard_labels, thresholds,
-                                               class_index)
+                                               class_index; binarize)
     roc_curve = (map(t -> t.false_positive_rate, stats),
                  map(t -> t.true_positive_rate, stats))
     pr_curve = (map(t -> t.true_positive_rate, stats),
@@ -221,16 +228,17 @@ end
 
 """
     get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels, class_index;
-                                           thresholds)
+                                           thresholds, binarize=binarize_by_threshold)
 
 Return [`TradeoffMetricsRow`] calculated for the given `class_index`. In addition
 to metrics calculated by [`get_tradeoff_metrics`](@ref), additionally calculates
-`spearman_correlation`-based metrics.
+`spearman_correlation`-based metrics. $(BINARIZE_NOTE) (`class_index`).
 """
 function get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels,
-                                                votes, class_index; thresholds)
+                                                votes, class_index; thresholds,
+                                                binarize=binarize_by_threshold)
     basic_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
-                                     class_index; thresholds)
+                                     class_index; thresholds, binarize)
     corr = _calculate_spearman_correlation(predicted_soft_labels, votes)
     row = Tables.rowmerge(basic_row,
                           (;
@@ -392,18 +400,19 @@ Where...
 Alternatively, an `observation_table` that consists of rows of type
 [`ObservationRow`](@ref) can be passed in in place of
 `predicted_soft_labels`,`predicted_hard_labels`,`elected_hard_labels`,
-and `votes`.
+and `votes`. $(BINARIZE_NOTE).
 
 See also [`evaluation_metrics_plot`](@ref).
""" function evaluation_metrics_row(observation_table, classes, thresholds=0.0:0.01:1.0; strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Missing,Nothing,Integer}=missing) + optimal_threshold_class::Union{Missing,Nothing,Integer}=missing, + binarize=binarize_by_threshold) inputs = _observation_table_to_inputs(observation_table) return evaluation_metrics_row(inputs.predicted_hard_labels, inputs.predicted_soft_labels, inputs.elected_hard_labels, classes, thresholds; inputs.votes, strata, - optimal_threshold_class) + optimal_threshold_class, binarize) end function evaluation_metrics_row(predicted_hard_labels::AbstractVector, @@ -414,7 +423,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, strata::Union{Nothing, AbstractVector{Set{T}} where T}=nothing, optimal_threshold_class::Union{Missing,Nothing, - Integer}=missing) + Integer}=missing, + binarize=binarize_by_threshold) class_labels = string.(collect(classes)) # Plots.jl expects this to be an `AbstractVector` class_indices = 1:length(classes) @@ -425,12 +435,12 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, map(ic -> get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels, votes, ic; - thresholds), + thresholds, binarize), class_indices) else map(ic -> get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, ic; - thresholds), + thresholds, binarize), class_indices) end @@ -440,7 +450,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, cal = _calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels, votes; thresholds, - class_of_interest_index=optimal_threshold_class) + class_of_interest_index=optimal_threshold_class, + binarize) optimal_threshold = cal.threshold elseif has_value(optimal_threshold_class) roc_curve = tradeoff_metrics_rows[findfirst(==(optimal_threshold_class), @@ -457,7 +468,8 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, if !ismissing(optimal_threshold) other_class = optimal_threshold_class == 1 ? 2 : 1 for (i, row) in enumerate(eachrow(predicted_soft_labels)) - predicted_hard_labels[i] = row[optimal_threshold_class] .>= optimal_threshold ? + predicted_hard_labels[i] = binarize(row[optimal_threshold_class], + optimal_threshold) ? 
                                        optimal_threshold_class : other_class
         end
     end
@@ -607,7 +619,7 @@ function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table,
                          per_expert_discrimination_calibration_scores,
                          # from kwargs:
-                         optimal_threshold_class = _values_or_missing(optimal_threshold_class),
+                         optimal_threshold_class=_values_or_missing(optimal_threshold_class),
                          class_labels, thresholds, optimal_threshold,
                          stratified_kappas)
 end
@@ -800,13 +812,14 @@ end
 
 function _calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
                                                                       votes; thresholds,
-                                                                      class_of_interest_index)
+                                                                      class_of_interest_index,
+                                                                      binarize=binarize_by_threshold)
     elected_probabilities = _elected_probabilities(votes, class_of_interest_index)
     bin_count = min(size(votes, 2) + 1, 10)
     per_threshold_curves = map(thresholds) do thresh
+        pred_soft = view(predicted_soft_labels, :, class_of_interest_index)
         return calibration_curve(elected_probabilities,
-                                 predicted_soft_labels[:, class_of_interest_index] .>=
-                                 thresh; bin_count=bin_count)
+                                 binarize.(pred_soft, thresh); bin_count=bin_count)
     end
     i_min = argmin([c.mean_squared_error for c in per_threshold_curves])
     curve = per_threshold_curves[i_min]
@@ -882,24 +895,26 @@ function _validate_threshold_class(optimal_threshold_class, classes)
 end
 
 function per_class_confusion_statistics(predicted_soft_labels::AbstractMatrix,
-                                        elected_hard_labels::AbstractVector, thresholds)
+                                        elected_hard_labels::AbstractVector, thresholds;
+                                        binarize=binarize_by_threshold)
     class_count = size(predicted_soft_labels, 2)
     return map(1:class_count) do i
         return per_threshold_confusion_statistics(predicted_soft_labels,
                                                   elected_hard_labels,
-                                                  thresholds, i)
+                                                  thresholds, i; binarize)
     end
 end
 
 function per_threshold_confusion_statistics(predicted_soft_labels::AbstractMatrix,
                                             elected_hard_labels::AbstractVector, thresholds,
-                                            class_index)
+                                            class_index; binarize=binarize_by_threshold)
     confusions = [confusion_matrix(2) for _ in 1:length(thresholds)]
     for label_index in 1:length(elected_hard_labels)
         predicted_soft_label = predicted_soft_labels[label_index, class_index]
         elected = (elected_hard_labels[label_index] == class_index) + 1
         for (threshold_index, threshold) in enumerate(thresholds)
-            predicted = (predicted_soft_label >= threshold) + 1
+            # Convert from binarized output to 2-class labels (1, 2)
+            predicted = binarize(predicted_soft_label, threshold) + 1
             confusions[threshold_index][predicted, elected] += 1
         end
     end
diff --git a/src/row.jl b/src/row.jl
index 58a29ca..85e100c 100644
--- a/src/row.jl
+++ b/src/row.jl
@@ -159,6 +159,16 @@ const ObservationRow = Legolas.@row("lighthouse.observation@1",
                                     elected_hard_label::Int64,
                                     votes::Union{Missing,Vector{Int64}})
 
+# Convert vector of per-class soft label vectors to expected matrix format, e.g.,
+# [[0.1, .2, .7], [0.8, .1, .1]] for 2 observations of 3-class classification returns
+# ```
+# [0.1 0.2 0.7;
+#  0.8 0.1 0.1]
+# ```
+function _predicted_soft_to_matrix(per_observation_soft_labels)
+    return transpose(reduce(hcat, per_observation_soft_labels))
+end
+
 function _observation_table_to_inputs(observation_table)
     Legolas.validate(observation_table, OBSERVATION_ROW_SCHEMA)
     df_table = Tables.columns(observation_table)
@@ -169,7 +179,7 @@ function _observation_table_to_inputs(observation_table)
     votes = any(ismissing, df_table.votes) ?
             missing : transpose(reduce(hcat, df_table.votes))
-    predicted_soft_labels = transpose(reduce(hcat, df_table.predicted_soft_labels))
+    predicted_soft_labels = _predicted_soft_to_matrix(df_table.predicted_soft_labels)
     return (; predicted_hard_labels=df_table.predicted_hard_label, predicted_soft_labels,
             elected_hard_labels=df_table.elected_hard_label, votes)
 end
 
diff --git a/test/metrics.jl b/test/metrics.jl
index 6f72ff7..68ffe0a 100644
--- a/test/metrics.jl
+++ b/test/metrics.jl
@@ -156,3 +156,66 @@ end
     @test all(iszero, totals)
     @test isnan(mean_squared_error)
 end
+
+@testset "`binarize_by_threshold`" begin
+    @test binarize_by_threshold(0.2, 0.8) == false
+    @test binarize_by_threshold(0.2, 0.2) == true
+    @test binarize_by_threshold(0.3, 0.2) == true
+    @test binarize_by_threshold.([0, 0, 0], 0.2) == [0, 0, 0]
+end
+
+@testset "Metrics hardening/binarization" begin
+    predicted_soft_labels = [0.51 0.49
+                             0.49 0.51
+                             0.1 0.9
+                             0.9 0.1
+                             0.0 1.0]
+    elected_hard_labels = [1, 2, 2, 2, 1]
+    thresholds = [0.25, 0.5, 0.75]
+    i_class = 2
+    default_metrics = get_tradeoff_metrics(predicted_soft_labels,
+                                           elected_hard_labels,
+                                           i_class; thresholds)
+
+    # Use bogus threshold/hardening function to prove that hardening function is
+    # used internally
+    scaled_binarize_by_threshold = (soft, threshold) -> soft >= threshold / 10
+    scaled_thresholds = 10 .* thresholds
+    scaled_metrics = get_tradeoff_metrics(predicted_soft_labels,
+                                          elected_hard_labels,
+                                          i_class; thresholds=scaled_thresholds,
+                                          binarize=scaled_binarize_by_threshold)
+    @test isequal(default_metrics, scaled_metrics)
+
+    # Discrim calibration
+    votes = [1 1 1
+             0 2 2
+             1 2 2
+             1 1 2
+             0 1 1]
+    cal = Lighthouse._calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
+                                                                                  votes;
+                                                                                  thresholds,
+                                                                                  class_of_interest_index=i_class)
+
+    scaled_cal = Lighthouse._calculate_optimal_threshold_from_discrimination_calibration(predicted_soft_labels,
+                                                                                         votes;
+                                                                                         class_of_interest_index=i_class,
+                                                                                         thresholds=scaled_thresholds,
+                                                                                         binarize=scaled_binarize_by_threshold)
+    for k in keys(cal)
+        if k == :threshold
+            @test cal[k] * 10 == scaled_cal[k] # Should be the same _relative_ threshold
+        else
+            @test isequal(cal[k], scaled_cal[k])
+        end
+    end
+
+    conf = Lighthouse.per_class_confusion_statistics(predicted_soft_labels,
+                                                     elected_hard_labels, thresholds)
+    scaled_conf = Lighthouse.per_class_confusion_statistics(predicted_soft_labels,
+                                                            elected_hard_labels,
+                                                            scaled_thresholds;
+                                                            binarize=scaled_binarize_by_threshold)
+    @test isequal(conf, scaled_conf)
+end
diff --git a/test/runtests.jl b/test/runtests.jl
index aa0b7f7..f4e6c7e 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -3,7 +3,7 @@ using StableRNGs
 using Lighthouse
 using Lighthouse: plot_reliability_calibration_curves, plot_pr_curves, plot_roc_curves,
                   plot_kappas, plot_confusion_matrix,
-                  evaluation_metrics_plot, evaluation_metrics
+                  evaluation_metrics_plot, evaluation_metrics, binarize_by_threshold
 using Base.Threads
 using CairoMakie
 using Legolas, Tables
@@ -13,7 +13,7 @@ using Arrow
 # Needs to be set for figures
 # returning true for showable("image/png", obj)
 # which TensorBoardLogger.jl uses to determine output
-CairoMakie.activate!(type="png")
+CairoMakie.activate!(; type="png")
 plot_results = joinpath(@__DIR__, "plot_results")
 # Remove any old plots
 isdir(plot_results) && rm(plot_results; force=true, recursive=true)
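
Not part of the patch itself: below is a minimal usage sketch of the new `binarize` keyword introduced by this diff. It reuses the toy soft labels from the new tests above; the `strict_binarize` closure is hypothetical and only illustrates the expected `(soft_label, threshold) -> Bool` signature.

```julia
using Lighthouse

# Toy inputs mirroring the new tests: 5 observations, 2 classes.
predicted_soft_labels = [0.51 0.49
                         0.49 0.51
                         0.1  0.9
                         0.9  0.1
                         0.0  1.0]
elected_hard_labels = [1, 2, 2, 2, 1]
thresholds = [0.25, 0.5, 0.75]

# Default behavior: `binarize_by_threshold(soft, threshold) = soft >= threshold`.
default_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, 2; thresholds)

# Hypothetical custom binarization: strict inequality instead of `>=`.
strict_binarize = (soft, threshold) -> soft > threshold
strict_row = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, 2;
                                  thresholds, binarize=strict_binarize)
```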