Commit ea55ce0

evaluation_metrics_row refactor (#70)
* Let the refactor commence

* wip

* wip

* wip

* Uncomment

* fix

* reshuffle

* wip

* parity!

* fix

* fix calibration calc

* unnest spearman

* add Curve struct

* Add Curve Arrow type

* support curve type

* unpack curves for refactor test

* update project

* fix tests

* clean up

* ugh

* Test fixes

* use ArrowTypes not Arrow

* Apply suggestions from code review

Co-authored-by: Eric Hanson <[email protected]>

* Add extra docstrings and update docs

* no version bump

* whoops

* fix test

* last? test fix

Co-authored-by: Eric Hanson <[email protected]>
hannahilea and ericphanson authored May 13, 2022
1 parent 253b5bf commit ea55ce0
Showing 9 changed files with 711 additions and 45 deletions.
8 changes: 6 additions & 2 deletions Project.toml
@@ -4,6 +4,8 @@ authors = ["Beacon Biosignals, Inc."]
version = "0.14.6"

[deps]
ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
Legolas = "741b9549-f6ed-4911-9fbf-4a1c0c97f0cd"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -17,6 +19,8 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
TensorBoardLogger = "899adc3e-224a-11e9-021f-63837185c80f"

[compat]
Arrow = "2.3"
ArrowTypes = "1, 2"
CairoMakie = "0.7"
DataFrames = "1.3"
Legolas = "0.3"
@@ -27,10 +31,10 @@ TensorBoardLogger = "0.1"
julia = "1.6"

[extras]
Arrow = "69666777-d1a9-59fb-9406-91d4454c9d45"
CairoMakie = "13f3f980-e62b-5c42-98c6-ff1f3baf88f0"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "CairoMakie", "DataFrames", "StableRNGs"]
test = ["Test", "Arrow", "CairoMakie", "StableRNGs"]
16 changes: 15 additions & 1 deletion docs/src/index.md
@@ -58,6 +58,7 @@ log_arrays!
LearnLogger
upon
Lighthouse.forward_logs
flush(::LearnLogger)
```

## Performance Metrics
@@ -73,6 +74,19 @@ ObservationRow
Lighthouse.evaluation_metrics
Lighthouse._evaluation_row_dict
Lighthouse.evaluation_metrics_row
Lighthouse.ClassRow
TradeoffMetricsRow
get_tradeoff_metrics
get_tradeoff_metrics_binary_multirater
HardenedMetricsRow
get_hardened_metrics
get_hardened_metrics_multirater
get_hardened_metrics_multiclass
LabelMetricsRow
get_label_metrics_multirater
get_label_metrics_multirater_multiclass
Lighthouse.refactored_evaluation_metrics_row
Lighthouse._evaluation_row
```

## Utilities
@@ -81,5 +95,5 @@ Lighthouse.evaluation_metrics_row
majority
Lighthouse.area_under_curve
Lighthouse.area_under_curve_unit_square
flush(::LearnLogger)
Curve
```
10 changes: 8 additions & 2 deletions src/Lighthouse.jl
@@ -8,17 +8,23 @@ using Makie
using Printf
using Legolas
using Tables
using DataFrames
using ArrowTypes

include("row.jl")
export EvaluationRow, ObservationRow
export EvaluationRow, ObservationRow, Curve, TradeoffMetricsRow, HardenedMetricsRow,
       LabelMetricsRow

include("plotting.jl")

include("utilities.jl")
export majority

include("metrics.jl")
export confusion_matrix, accuracy, binary_statistics, cohens_kappa, calibration_curve
export confusion_matrix, accuracy, binary_statistics, cohens_kappa, calibration_curve,
get_tradeoff_metrics, get_tradeoff_metrics_binary_multirater, get_hardened_metrics,
get_hardened_metrics_multirater, get_hardened_metrics_multiclass,
get_label_metrics_multirater, get_label_metrics_multirater_multiclass

include("classifier.jl")
export AbstractClassifier
120 changes: 83 additions & 37 deletions src/learn.jl
@@ -277,16 +277,20 @@ Where...
"""
function _calculate_ea_kappas(predicted_hard_labels, elected_hard_labels, class_count)
multiclass = first(cohens_kappa(class_count,
zip(predicted_hard_labels, elected_hard_labels)))
multiclass_kappa = first(cohens_kappa(class_count,
zip(predicted_hard_labels, elected_hard_labels)))

CLASS_VS_ALL_CLASS_COUNT = 2
per_class = map(1:class_count) do class_index
predicted = ((label == class_index) + 1 for label in predicted_hard_labels)
elected = ((label == class_index) + 1 for label in elected_hard_labels)
return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, zip(predicted, elected)))
per_class_kappas = map(1:class_count) do class_index
return _calculate_ea_kappa(predicted_hard_labels, elected_hard_labels, class_index)
end
return (per_class_kappas=per_class, multiclass_kappa=multiclass)
return (; per_class_kappas, multiclass_kappa)
end

function _calculate_ea_kappa(predicted_hard_labels, elected_hard_labels, class_index)
CLASS_VS_ALL_CLASS_COUNT = 2
predicted = ((label == class_index) + 1 for label in predicted_hard_labels)
elected = ((label == class_index) + 1 for label in elected_hard_labels)
return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, zip(predicted, elected)))
end
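The kappa computation is now split in two: `_calculate_ea_kappas` maps over classes, and the extracted `_calculate_ea_kappa` recodes labels one-vs-all before calling `cohens_kappa`. A minimal usage sketch with hypothetical labels (these helpers are internal, so they are qualified with the module name; assumes only that Lighthouse is loaded):

```julia
using Lighthouse

# Hypothetical 3-class labels: algorithm predictions vs. elected expert labels.
predicted_hard_labels = [1, 2, 3, 2, 1, 3]
elected_hard_labels = [1, 2, 2, 2, 1, 3]

# Labels equal to `class_index` are recoded to 2 and all others to 1, so each
# per-class kappa is an ordinary two-class Cohen's kappa ("class vs. rest").
kappa_class2 = Lighthouse._calculate_ea_kappa(predicted_hard_labels,
                                              elected_hard_labels, 2)

# The outer function returns both granularities as a named tuple.
kappas = Lighthouse._calculate_ea_kappas(predicted_hard_labels,
                                         elected_hard_labels, 3)
kappas.multiclass_kappa, kappas.per_class_kappas
```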

has_value(x) = !isnothing(x) && !ismissing(x)
@@ -313,23 +317,11 @@ no two voters rated the same sample. Note that vote entries of `0` are taken to
mean that the voter did not rate that sample.
"""
function _calculate_ira_kappas(votes, classes)
# no votes given or only one expert:
if !has_value(votes) || size(votes, 2) < 2
return (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing)
end

all_hard_label_pairs = Array{Int}(undef, 0, 2)
num_voters = size(votes, 2)
for i_voter in 1:(num_voters - 1)
for j_voter in (i_voter + 1):num_voters
all_hard_label_pairs = vcat(all_hard_label_pairs, votes[:, [i_voter, j_voter]])
end
end
hard_label_pairs = filter(row -> all(row .!= 0), collect(eachrow(all_hard_label_pairs)))
hard_label_pairs = _prep_hard_label_pairs(votes)
length(hard_label_pairs) > 0 ||
return (; per_class_IRA_kappas=missing, multiclass_IRA_kappas=missing) # No common observations voted on
length(hard_label_pairs) < 10 &&
@warn "...only $(length(hard_label_pairs)) in common, potentially questionable IRA results"
@warn "...only $(length(hard_label_pairs)) in common, potentially questionable IRA results"

multiclass_ira = first(cohens_kappa(length(classes), hard_label_pairs))

@@ -342,6 +334,37 @@ function _calculate_ira_kappas(votes, classes)
return (; per_class_IRA_kappas=per_class_ira, multiclass_IRA_kappas=multiclass_ira)
end

function _prep_hard_label_pairs(votes)
if !has_value(votes) || size(votes, 2) < 2
# no votes given or only one expert
return Tuple{Int64,Int64}[]
end
all_hard_label_pairs = Array{Int}(undef, 0, 2)
num_voters = size(votes, 2)
for i_voter in 1:(num_voters - 1)
for j_voter in (i_voter + 1):num_voters
all_hard_label_pairs = vcat(all_hard_label_pairs, votes[:, [i_voter, j_voter]])
end
end
hard_label_pairs = filter(row -> all(row .!= 0), collect(eachrow(all_hard_label_pairs)))
return hard_label_pairs
end
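`_prep_hard_label_pairs` isolates the pairwise-vote extraction that the IRA kappa helpers below share. A small behavioral sketch (hypothetical votes matrix; per the docstring above, `0` means the voter did not rate that sample):

```julia
using Lighthouse

# Rows are samples, columns are voters; a 0 entry marks a missing vote.
votes = [1 1 0;
         2 2 2;
         1 0 1]

# Voter columns are stacked pairwise -- (1,2), (1,3), (2,3) -- and any row
# containing a 0 is filtered out, keeping only samples both voters rated.
pairs = Lighthouse._prep_hard_label_pairs(votes)
```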

function _calculate_ira_kappa_multiclass(votes, class_count)
hard_label_pairs = _prep_hard_label_pairs(votes)
length(hard_label_pairs) == 0 && return missing
return first(cohens_kappa(class_count, hard_label_pairs))
end

function _calculate_ira_kappa(votes, class_index)
hard_label_pairs = _prep_hard_label_pairs(votes)
length(hard_label_pairs) == 0 && return missing
CLASS_VS_ALL_CLASS_COUNT = 2
class_v_other_hard_label_pair = map(row -> 1 .+ (row .== class_index),
hard_label_pairs)
return first(cohens_kappa(CLASS_VS_ALL_CLASS_COUNT, class_v_other_hard_label_pair))
end
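These two helpers unbundle what `_calculate_ira_kappas` previously computed in one pass: a multiclass inter-rater kappa and a per-class one-vs-all variant, each returning `missing` when no voter pair rated a common sample. A sketch under the same toy setup:

```julia
using Lighthouse

votes = [1 1 0;
         2 2 2;
         1 0 1]

# Multiclass IRA kappa over all voter pairs (2 classes in this toy matrix).
ira_multi = Lighthouse._calculate_ira_kappa_multiclass(votes, 2)

# One-vs-all IRA kappa for class 1: votes equal to the class of interest are
# recoded to 2 and all others to 1 before scoring as a 2-class problem.
ira_class1 = Lighthouse._calculate_ira_kappa(votes, 1)
```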

function _spearman_corr(predicted_soft_labels, elected_soft_labels)
n = length(predicted_soft_labels)
ρ = StatsBase.corspearman(predicted_soft_labels, elected_soft_labels)
@@ -385,8 +408,10 @@ Where...
- `classes` are the two classes voted on.
"""
function _calculate_spearman_correlation(predicted_soft_labels, votes, classes)
length(classes) > 2 && throw(ArgumentError("Only valid for 2-class problems"))
function _calculate_spearman_correlation(predicted_soft_labels, votes, classes=missing)
if !ismissing(classes)
length(classes) > 2 && throw(ArgumentError("Only valid for 2-class problems"))
end
if !all(x -> x ≈ 1, sum(predicted_soft_labels; dims=2))
throw(ArgumentError("Input probabilities fail softmax assumption"))
end
@@ -416,6 +441,16 @@ function _calculate_optimal_threshold_from_discrimination_calibration(predicted_
plot_curve_data=(mean.(curve.bins), curve.fractions))
end

function _calculate_discrimination_calibration(predicted_hard_labels, votes;
class_of_interest_index)
elected_probabilities = _elected_probabilities(votes, class_of_interest_index)
bin_count = min(size(votes, 2) + 1, 10)
curve = calibration_curve(elected_probabilities,
predicted_hard_labels .== class_of_interest_index; bin_count)
return (mse=curve.mean_squared_error,
plot_curve_data=(mean.(curve.bins), curve.fractions))
end
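The new `_calculate_discrimination_calibration` mirrors the optimal-threshold variant above but scores calibration directly at the supplied hard labels, with the bin count capped at `min(size(votes, 2) + 1, 10)`. A hedged sketch with hypothetical inputs:

```julia
using Lighthouse

# Hypothetical 2-class setup: 4 samples, 2 voters.
predicted_hard_labels = [1, 2, 1, 2]
votes = [1 1;
         1 2;
         2 2;
         2 2]

cal = Lighthouse._calculate_discrimination_calibration(predicted_hard_labels, votes;
                                                       class_of_interest_index=1)
cal.mse              # mean squared calibration error
cal.plot_curve_data  # (bin midpoints, observed fractions), ready for plotting
```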

function _elected_probabilities(votes, class_of_interest_index)
elected_probabilities = Vector{Float64}()
for sample_votes in eachrow(votes)
Expand All @@ -442,13 +477,17 @@ end

function _get_optimal_threshold_from_ROC(per_class_roc_curves; thresholds,
class_of_interest_index)
return _get_optimal_threshold_from_ROC(per_class_roc_curves[class_of_interest_index],
thresholds)
end
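The threshold search is likewise split: the keyword method above merely selects the class-of-interest curve, while the two-argument method below scans the ROC points for the one closest to the perfect corner (0, 1). The geometric rule as a standalone sketch (illustrative numbers, not Lighthouse's internals):

```julia
# Each ROC curve is a (false_positive_rates, true_positive_rates) pair.
fprs = [0.0, 0.1, 0.3, 1.0]
tprs = [0.0, 0.6, 0.9, 1.0]

dist_to_corner(p) = sqrt(p[1]^2 + (p[2] - 1)^2)  # Euclidean distance to (0, 1)
best_index = argmin([dist_to_corner(p) for p in zip(fprs, tprs)])
# The optimal threshold is the one whose operating point sits nearest (0, 1).
```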

function _get_optimal_threshold_from_ROC(roc_curve, thresholds)
dist = (p1, p2) -> sqrt((p1[1] - p2[1])^2 + (p1[2] - p2[2])^2)
min = Inf
curr_counter = 1
opt_point = nothing
threshold_idx = 1
for point in zip(per_class_roc_curves[class_of_interest_index][1],
per_class_roc_curves[class_of_interest_index][2])
for point in zip(roc_curve[1], roc_curve[2])
d = dist((0, 1), point)
if d < min
min = d
Expand Down Expand Up @@ -703,19 +742,26 @@ end
function per_class_confusion_statistics(predicted_soft_labels::AbstractMatrix,
elected_hard_labels::AbstractVector, thresholds)
class_count = size(predicted_soft_labels, 2)
confusions = [[confusion_matrix(2) for _ in 1:length(thresholds)]
for _ in 1:class_count]
for class_index in 1:class_count
for label_index in 1:length(elected_hard_labels)
predicted_soft_label = predicted_soft_labels[label_index, class_index]
elected = (elected_hard_labels[label_index] == class_index) + 1
for (threshold_index, threshold) in enumerate(thresholds)
predicted = (predicted_soft_label >= threshold) + 1
confusions[class_index][threshold_index][predicted, elected] += 1
end
return map(1:class_count) do i
return per_threshold_confusion_statistics(predicted_soft_labels,
elected_hard_labels,
thresholds, i)
end
end

function per_threshold_confusion_statistics(predicted_soft_labels::AbstractMatrix,
elected_hard_labels::AbstractVector, thresholds,
class_index)
confusions = [confusion_matrix(2) for _ in 1:length(thresholds)]
for label_index in 1:length(elected_hard_labels)
predicted_soft_label = predicted_soft_labels[label_index, class_index]
elected = (elected_hard_labels[label_index] == class_index) + 1
for (threshold_index, threshold) in enumerate(thresholds)
predicted = (predicted_soft_label >= threshold) + 1
confusions[threshold_index][predicted, elected] += 1
end
end
return [binary_statistics.(confusions[i], 2) for i in 1:class_count]
return binary_statistics.(confusions, 2)
end
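`per_class_confusion_statistics` is now a thin map over the extracted `per_threshold_confusion_statistics`, which sweeps the thresholds for a single class. A usage sketch (hypothetical soft labels whose rows sum to 1, matching the softmax check above):

```julia
using Lighthouse

predicted_soft_labels = [0.9 0.1;
                         0.4 0.6;
                         0.2 0.8]
elected_hard_labels = [1, 2, 2]
thresholds = [0.25, 0.5, 0.75]

stats = Lighthouse.per_class_confusion_statistics(predicted_soft_labels,
                                                  elected_hard_labels, thresholds)
# stats[class][t] holds the binary statistics for `class` at thresholds[t].
```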

#####
(Diffs for the remaining 5 of the 9 changed files are not rendered here.)
