From e829617c5b30654a0c9765390e541e749b7aae33 Mon Sep 17 00:00:00 2001 From: a-cakir <103209929+a-cakir@users.noreply.github.com> Date: Fri, 9 Dec 2022 18:56:15 -0500 Subject: [PATCH] Upgrade to Legolas 0.5 (#97) * Project file * update lighthouse eval * update lightouse observation * update lightouse class * update lightouse labeled metrics * update lightouse hardened metrics * update lightouse tradeoff metrics * formatting * comment for evaluation * comment for observation * upgrade comments * small changes based on huddle * export and legolas * Starting updating tests * Use Legolas v0.5 records * Prefer using term record to row * Fix `_values_or_missing` * Address MakieLayout deprecation * Fix mistake where rho was replace with p * Fix `test_evaluation_metrics_roundtrip` * Use `IOBuffer` for `roundtrip_row` * Documentation update * Documentation changes * Use modern Makie which deprecates MakieLayout * updates documentation failures updating Row to V1 * format * workflow * run format * un-mangle * Update docs/Project.toml * docstring format * deprecation (with error) * yasg demands return Co-authored-by: Curtis Vogt Co-authored-by: Nader Bagherzadeh Co-authored-by: Dave Kleinschmidt --- .github/workflows/format.yml | 37 ++++ Project.toml | 6 +- docs/Project.toml | 1 - docs/make.jl | 4 +- docs/src/index.md | 18 +- docs/src/plotting.md | 4 +- format/Manifest.toml | 226 +++++++++++++++++++ format/Project.toml | 5 + format/run.jl | 14 ++ src/LearnLogger.jl | 5 +- src/Lighthouse.jl | 8 +- src/deprecations.jl | 5 + src/learn.jl | 12 +- src/metrics.jl | 231 ++++++++++---------- src/plotting.jl | 82 ++++--- src/row.jl | 407 ++++++++++++++++------------------- test/deprecations.jl | 3 + test/learn.jl | 29 ++- test/metrics.jl | 3 +- test/plotting.jl | 65 +++--- test/row.jl | 55 ++--- test/runtests.jl | 24 +-- 22 files changed, 771 insertions(+), 473 deletions(-) create mode 100644 .github/workflows/format.yml create mode 100644 format/Manifest.toml create mode 100644 format/Project.toml create mode 100644 format/run.jl create mode 100644 src/deprecations.jl create mode 100644 test/deprecations.jl diff --git a/.github/workflows/format.yml b/.github/workflows/format.yml new file mode 100644 index 0000000..cac310d --- /dev/null +++ b/.github/workflows/format.yml @@ -0,0 +1,37 @@ +name: YASG-enforcer +on: + push: + branches: + - 'main' + tags: '*' + pull_request: + types: [opened, synchronize, reopened, ready_for_review] + # note: keep in sync with `format/run.jl` + paths: + - 'src/**' + - 'test/**' + - '.github/workflows/format.yml' + - 'format/**' +jobs: + format-check: + name: YASG Enforcement (Julia ${{ matrix.julia-version }} - ${{ github.event_name }}) + # Run on push's or non-draft PRs + if: (github.event_name == 'push') || (github.event.pull_request.draft == false) + runs-on: ubuntu-latest + strategy: + matrix: + julia-version: [1.8] + steps: + - uses: julia-actions/setup-julia@latest + with: + version: ${{ matrix.julia-version }} + - uses: actions/checkout@v1 + - name: Instantiate `format` environment and format + run: | + julia --project=format -e 'using Pkg; Pkg.instantiate()' + julia --project=format 'format/run.jl' + - uses: reviewdog/action-suggester@v1 + if: github.event_name == 'pull_request' + with: + tool_name: JuliaFormatter + fail_on_error: true diff --git a/Project.toml b/Project.toml index b071024..4cf44f5 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Lighthouse" uuid = "ac2c24cd-07f0-4848-96b2-1b82c3ea0e59" authors = ["Beacon Biosignals, Inc."] -version = 
"0.14.16" +version = "0.15.0" [deps] ArrowTypes = "31f734f8-188a-4ce0-8406-c8a06bd891cd" @@ -23,8 +23,8 @@ Arrow = "2.3" ArrowTypes = "1, 2" CairoMakie = "0.7, 0.8" DataFrames = "1.3" -Legolas = "0.3" -Makie = "0.16.5, 0.17, 0.18" +Legolas = "0.5" +Makie = "0.17.4, 0.18" StatsBase = "0.33" Tables = "1.7" TensorBoardLogger = "0.1" diff --git a/docs/Project.toml b/docs/Project.toml index fb76611..54f344c 100644 --- a/docs/Project.toml +++ b/docs/Project.toml @@ -4,5 +4,4 @@ Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4" StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3" [compat] -CairoMakie = "0.7.4" Documenter = "0.25" diff --git a/docs/make.jl b/docs/make.jl index 0d3fbe7..ebee288 100644 --- a/docs/make.jl +++ b/docs/make.jl @@ -7,7 +7,7 @@ makedocs(; modules=[Lighthouse], sitename="Lighthouse", "Plotting" => "plotting.md"], # makes docs fail hard if there is any error building the examples, # so we don't just miss a build failure! - strict = true) + strict=true) -deploydocs(repo="github.com/beacon-biosignals/Lighthouse.jl.git", +deploydocs(; repo="github.com/beacon-biosignals/Lighthouse.jl.git", devbranch="main", push_preview=true) diff --git a/docs/src/index.md b/docs/src/index.md index 256dae2..69cbe43 100644 --- a/docs/src/index.md +++ b/docs/src/index.md @@ -66,23 +66,23 @@ accuracy binary_statistics cohens_kappa calibration_curve -EvaluationRow -ObservationRow +EvaluationV1 +ObservationV1 Lighthouse.evaluation_metrics -Lighthouse._evaluation_row_dict -Lighthouse.evaluation_metrics_row -Lighthouse.ClassRow -TradeoffMetricsRow +Lighthouse._evaluation_dict +Lighthouse.evaluation_metrics_record +Lighthouse.ClassV1 +TradeoffMetricsV1 get_tradeoff_metrics get_tradeoff_metrics_binary_multirater -HardenedMetricsRow +HardenedMetricsV1 get_hardened_metrics get_hardened_metrics_multirater get_hardened_metrics_multiclass -LabelMetricsRow +LabelMetricsV1 get_label_metrics_multirater get_label_metrics_multirater_multiclass -Lighthouse._evaluation_row +Lighthouse._evaluation_record Lighthouse._calculate_ea_kappas Lighthouse._calculate_ira_kappas Lighthouse._calculate_spearman_correlation diff --git a/docs/src/plotting.md b/docs/src/plotting.md index 57f2c33..58523da 100644 --- a/docs/src/plotting.md +++ b/docs/src/plotting.md @@ -274,8 +274,8 @@ data["optimal_threshold"] = missing evaluation_metrics_plot(data) ``` -Plots can also be generated directly from an `EvaluationRow`: +Plots can also be generated directly from an `EvaluationV1`: ```@example 1 -data_row = EvaluationRow(data) +data_row = EvaluationV1(data) evaluation_metrics_plot(data_row) ``` diff --git a/format/Manifest.toml b/format/Manifest.toml new file mode 100644 index 0000000..9320da7 --- /dev/null +++ b/format/Manifest.toml @@ -0,0 +1,226 @@ +# This file is machine-generated - editing it directly is not advised + +julia_version = "1.8.3" +manifest_format = "2.0" +project_hash = "30b405be1c677184b7703a9bfb3d2100029ccad0" + +[[deps.ArgTools]] +uuid = "0dad84c5-d112-42e6-8d28-ef12dabb789f" +version = "1.1.1" + +[[deps.Artifacts]] +uuid = "56f22d72-fd6d-98f1-02f0-08ddc0907c33" + +[[deps.Base64]] +uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" + +[[deps.CSTParser]] +deps = ["Tokenize"] +git-tree-sha1 = "3ddd48d200eb8ddf9cb3e0189fc059fd49b97c1f" +uuid = "00ebfdb7-1f24-5e51-bd34-a7502290713f" +version = "3.3.6" + +[[deps.CommonMark]] +deps = ["Crayons", "JSON", "URIs"] +git-tree-sha1 = "86cce6fd164c26bad346cc51ca736e692c9f553c" +uuid = "a80b9123-70ca-4bc0-993e-6e3bcb318db6" +version = "0.8.7" + +[[deps.Compat]] +deps = ["Dates", 
"LinearAlgebra", "UUIDs"] +git-tree-sha1 = "00a2cccc7f098ff3b66806862d275ca3db9e6e5a" +uuid = "34da2185-b29b-5c13-b0c7-acf172513d20" +version = "4.5.0" + +[[deps.CompilerSupportLibraries_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "e66e0078-7015-5450-92f7-15fbd957f2ae" +version = "0.5.2+0" + +[[deps.Crayons]] +git-tree-sha1 = "249fe38abf76d48563e2f4556bebd215aa317e15" +uuid = "a8cc5b0e-0ffa-5ad4-8c14-923d3ee1735f" +version = "4.1.1" + +[[deps.DataStructures]] +deps = ["Compat", "InteractiveUtils", "OrderedCollections"] +git-tree-sha1 = "d1fff3a548102f48987a52a2e0d114fa97d730f0" +uuid = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8" +version = "0.18.13" + +[[deps.Dates]] +deps = ["Printf"] +uuid = "ade2ca70-3891-5945-98fb-dc099432e06a" + +[[deps.Downloads]] +deps = ["ArgTools", "FileWatching", "LibCURL", "NetworkOptions"] +uuid = "f43a241f-c20a-4ad4-852c-f6b1247861c6" +version = "1.6.0" + +[[deps.FileWatching]] +uuid = "7b1f6079-737a-58dc-b8bc-7a2ca5c1b5ee" + +[[deps.Glob]] +git-tree-sha1 = "4df9f7e06108728ebf00a0a11edee4b29a482bb2" +uuid = "c27321d9-0574-5035-807b-f59d2c89b15c" +version = "1.3.0" + +[[deps.InteractiveUtils]] +deps = ["Markdown"] +uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240" + +[[deps.JSON]] +deps = ["Dates", "Mmap", "Parsers", "Unicode"] +git-tree-sha1 = "3c837543ddb02250ef42f4738347454f95079d4e" +uuid = "682c06a0-de6a-54ab-a142-c8b1cf79cde6" +version = "0.21.3" + +[[deps.JuliaFormatter]] +deps = ["CSTParser", "CommonMark", "DataStructures", "Glob", "Pkg", "Tokenize"] +git-tree-sha1 = "76ee67858b65133b2460b0eebf52e49950bb90a3" +uuid = "98e50ef6-434e-11e9-1051-2b60c6c9e899" +version = "1.0.16" + +[[deps.LibCURL]] +deps = ["LibCURL_jll", "MozillaCACerts_jll"] +uuid = "b27032c2-a3e7-50c8-80cd-2d36dbcbfd21" +version = "0.6.3" + +[[deps.LibCURL_jll]] +deps = ["Artifacts", "LibSSH2_jll", "Libdl", "MbedTLS_jll", "Zlib_jll", "nghttp2_jll"] +uuid = "deac9b47-8bc7-5906-a0fe-35ac56dc84c0" +version = "7.84.0+0" + +[[deps.LibGit2]] +deps = ["Base64", "NetworkOptions", "Printf", "SHA"] +uuid = "76f85450-5226-5b5a-8eaa-529ad045b433" + +[[deps.LibSSH2_jll]] +deps = ["Artifacts", "Libdl", "MbedTLS_jll"] +uuid = "29816b5a-b9ab-546f-933c-edad1886dfa8" +version = "1.10.2+0" + +[[deps.Libdl]] +uuid = "8f399da3-3557-5675-b5ff-fb832c97cbdb" + +[[deps.LinearAlgebra]] +deps = ["Libdl", "libblastrampoline_jll"] +uuid = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" + +[[deps.Logging]] +uuid = "56ddb016-857b-54e1-b83d-db4d58db5568" + +[[deps.Markdown]] +deps = ["Base64"] +uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" + +[[deps.MbedTLS_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "c8ffd9c3-330d-5841-b78e-0817d7145fa1" +version = "2.28.0+0" + +[[deps.Mmap]] +uuid = "a63ad114-7e13-5084-954f-fe012c677804" + +[[deps.MozillaCACerts_jll]] +uuid = "14a3606d-f60d-562e-9121-12d972cd8159" +version = "2022.2.1" + +[[deps.NetworkOptions]] +uuid = "ca575930-c2e3-43a9-ace4-1e988b2c1908" +version = "1.2.0" + +[[deps.OpenBLAS_jll]] +deps = ["Artifacts", "CompilerSupportLibraries_jll", "Libdl"] +uuid = "4536629a-c528-5b80-bd46-f80d51c5b363" +version = "0.3.20+0" + +[[deps.OrderedCollections]] +git-tree-sha1 = "85f8e6578bf1f9ee0d11e7bb1b1456435479d47c" +uuid = "bac558e1-5e72-5ebc-8fee-abe8a469f55d" +version = "1.4.1" + +[[deps.Parsers]] +deps = ["Dates", "SnoopPrecompile"] +git-tree-sha1 = "b64719e8b4504983c7fca6cc9db3ebc8acc2a4d6" +uuid = "69de0a69-1ddd-5017-9359-2bf0b02dc9f0" +version = "2.5.1" + +[[deps.Pkg]] +deps = ["Artifacts", "Dates", "Downloads", "LibGit2", "Libdl", "Logging", "Markdown", "Printf", "REPL", "Random", "SHA", 
"Serialization", "TOML", "Tar", "UUIDs", "p7zip_jll"] +uuid = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f" +version = "1.8.0" + +[[deps.Printf]] +deps = ["Unicode"] +uuid = "de0858da-6303-5e67-8744-51eddeeeb8d7" + +[[deps.REPL]] +deps = ["InteractiveUtils", "Markdown", "Sockets", "Unicode"] +uuid = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb" + +[[deps.Random]] +deps = ["SHA", "Serialization"] +uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" + +[[deps.SHA]] +uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce" +version = "0.7.0" + +[[deps.Serialization]] +uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" + +[[deps.SnoopPrecompile]] +git-tree-sha1 = "f604441450a3c0569830946e5b33b78c928e1a85" +uuid = "66db9d55-30c0-4569-8b51-7e840670fc0c" +version = "1.0.1" + +[[deps.Sockets]] +uuid = "6462fe0b-24de-5631-8697-dd941f90decc" + +[[deps.TOML]] +deps = ["Dates"] +uuid = "fa267f1f-6049-4f14-aa54-33bafae1ed76" +version = "1.0.0" + +[[deps.Tar]] +deps = ["ArgTools", "SHA"] +uuid = "a4e569a6-e804-4fa4-b0f3-eef7a1d5b13e" +version = "1.10.1" + +[[deps.Tokenize]] +git-tree-sha1 = "2b3af135d85d7e70b863540160208fa612e736b9" +uuid = "0796e94c-ce3b-5d07-9a54-7f471281c624" +version = "0.5.24" + +[[deps.URIs]] +git-tree-sha1 = "ac00576f90d8a259f2c9d823e91d1de3fd44d348" +uuid = "5c2747f8-b7ea-4ff2-ba2e-563bfd36b1d4" +version = "1.4.1" + +[[deps.UUIDs]] +deps = ["Random", "SHA"] +uuid = "cf7118a7-6976-5b1a-9a39-7adc72f591a4" + +[[deps.Unicode]] +uuid = "4ec0a83e-493e-50e2-b9ac-8f72acf5a8f5" + +[[deps.Zlib_jll]] +deps = ["Libdl"] +uuid = "83775a58-1f1d-513f-b197-d71354ab007a" +version = "1.2.12+3" + +[[deps.libblastrampoline_jll]] +deps = ["Artifacts", "Libdl", "OpenBLAS_jll"] +uuid = "8e850b90-86db-534c-a0d3-1478176c7d93" +version = "5.1.1+0" + +[[deps.nghttp2_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "8e850ede-7688-5339-a07c-302acd2aaf8d" +version = "1.48.0+0" + +[[deps.p7zip_jll]] +deps = ["Artifacts", "Libdl"] +uuid = "3f19e933-33d8-53b3-aaab-bd5110c3b7a0" +version = "17.4.0+0" diff --git a/format/Project.toml b/format/Project.toml new file mode 100644 index 0000000..71708c8 --- /dev/null +++ b/format/Project.toml @@ -0,0 +1,5 @@ +[deps] +JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899" + +[compat] +JuliaFormatter = "1" diff --git a/format/run.jl b/format/run.jl new file mode 100644 index 0000000..7cb3490 --- /dev/null +++ b/format/run.jl @@ -0,0 +1,14 @@ +using JuliaFormatter + +function main() + perfect = format(joinpath(@__DIR__, ".."); style=YASStyle()) + if perfect + @info "Linting complete - no files altered" + else + @info "Linting complete - files altered" + run(`git status`) + end + return nothing +end + +main() diff --git a/src/LearnLogger.jl b/src/LearnLogger.jl index 4045008..947e413 100644 --- a/src/LearnLogger.jl +++ b/src/LearnLogger.jl @@ -45,8 +45,9 @@ function log_plot!(logger::LearnLogger, field::AbstractString, plot, plot_data) return plot end -function log_line_series!(logger::LearnLogger, field::AbstractString, curves, labels=1:length(curves)) - @warn "`log_line_series!` not implemented for `LearnLogger`" maxlog=1 +function log_line_series!(logger::LearnLogger, field::AbstractString, curves, + labels=1:length(curves)) + @warn "`log_line_series!` not implemented for `LearnLogger`" maxlog = 1 return nothing end diff --git a/src/Lighthouse.jl b/src/Lighthouse.jl index db5dcdd..9154d82 100644 --- a/src/Lighthouse.jl +++ b/src/Lighthouse.jl @@ -6,14 +6,14 @@ using StatsBase: StatsBase using TensorBoardLogger using Makie using Printf -using Legolas +using Legolas: Legolas, @schema, @version, lift using 
Tables using DataFrames using ArrowTypes include("row.jl") -export EvaluationRow, ObservationRow, Curve, TradeoffMetricsRow, HardenedMetricsRow, - LabelMetricsRow, Curve +export EvaluationV1, ObservationV1, Curve, TradeoffMetricsV1, HardenedMetricsV1, + LabelMetricsV1 include("plotting.jl") @@ -38,4 +38,6 @@ export learn!, upon, evaluate!, predict! export log_event!, log_line_series!, log_plot!, step_logger!, log_value!, log_values! export log_array!, log_arrays! +include("deprecations.jl") + end # module diff --git a/src/deprecations.jl b/src/deprecations.jl new file mode 100644 index 0000000..b1dc010 --- /dev/null +++ b/src/deprecations.jl @@ -0,0 +1,5 @@ +function evaluation_metrics_row(args...; kwargs...) + error("`Lighthouse.evaluation_metrics_row` has been removed in favor of " * + "`Lighthouse.evaluation_metrics_record`.") + return nothing +end diff --git a/src/learn.jl b/src/learn.jl index 16159df..4b98ae2 100644 --- a/src/learn.jl +++ b/src/learn.jl @@ -28,7 +28,7 @@ log_value!(logger, field::AbstractString, value) Logs a series plot to `logger` under `field`, where... -- `curves` is an iterable of the form `Tuple{Vector{Real},Vector{Real}}`, where each tuple contains `(x-values, y-values)`, as in the `Lighthouse.EvaluationRow` field `per_class_roc_curves` +- `curves` is an iterable of the form `Tuple{Vector{Real},Vector{Real}}`, where each tuple contains `(x-values, y-values)`, as in the `Lighthouse.EvaluationV1` field `per_class_roc_curves` - `labels` is the class label for each curve, which defaults to the numeric index of each curve. """ log_line_series!(logger, field::AbstractString, curves, labels=1:length(curves)) @@ -92,12 +92,12 @@ end """ log_evaluation_row!(logger, field::AbstractString, metrics) -From fields in [`EvaluationRow`](@ref), generate and plot the composite [`evaluation_metrics_plot`](@ref) +From fields in [`EvaluationV1`](@ref), generate and plot the composite [`evaluation_metrics_plot`](@ref) as well as `spearman_correlation` (if present). 
""" function log_evaluation_row!(logger, field::AbstractString, metrics) metrics_plot = evaluation_metrics_plot(metrics) - metrics_dict = _evaluation_row_dict(metrics) + metrics_dict = _evaluation_dict(metrics) log_plot!(logger, field, metrics_plot, metrics_dict) if haskey(metrics_dict, "spearman_correlation") sp_field = replace(field, "metrics" => "spearman_correlation") @@ -230,9 +230,9 @@ function evaluate!(predicted_hard_labels::AbstractVector, _validate_threshold_class(optimal_threshold_class, classes) log_resource_info!(logger, logger_prefix; suffix=logger_suffix) do - metrics = evaluation_metrics_row(predicted_hard_labels, predicted_soft_labels, - elected_hard_labels, classes, thresholds; - votes, optimal_threshold_class) + metrics = evaluation_metrics_record(predicted_hard_labels, predicted_soft_labels, + elected_hard_labels, classes, thresholds; + votes, optimal_threshold_class) log_evaluation_row!(logger, logger_prefix * "/metrics" * logger_suffix, metrics) return nothing diff --git a/src/metrics.jl b/src/metrics.jl index 0eeadf0..d838445 100644 --- a/src/metrics.jl +++ b/src/metrics.jl @@ -198,7 +198,7 @@ end get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_index; thresholds, binarize=binarize_by_threshold, class_labels=missing) -Return [`TradeoffMetricsRow`] calculated for the given `class_index`, with the following +Return [`TradeoffMetricsV1`] calculated for the given `class_index`, with the following fields guaranteed to be non-missing: `roc_curve`, `roc_auc`, pr_curve`, `reliability_calibration_curve`, `reliability_calibration_score`.` $(BINARIZE_NOTE) (`class_index`). @@ -219,16 +219,16 @@ function get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels, class_ reliability_calibration.fractions) reliability_calibration_score = reliability_calibration.mean_squared_error - return TradeoffMetricsRow(; class_index, class_labels, roc_curve, - roc_auc=area_under_curve(roc_curve...), pr_curve, - reliability_calibration_curve, reliability_calibration_score) + return TradeoffMetricsV1(; class_index, class_labels, roc_curve, + roc_auc=area_under_curve(roc_curve...), pr_curve, + reliability_calibration_curve, reliability_calibration_score) end """ get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_hard_labels, class_index; thresholds, binarize=binarize_by_threshold, class_labels=missing) -Return [`TradeoffMetricsRow`] calculated for the given `class_index`. In addition +Return [`TradeoffMetricsV1`] calculated for the given `class_index`. In addition to metrics calculated by [`get_tradeoff_metrics`](@ref), additionally calculates `spearman_correlation`-based metrics. $(BINARIZE_NOTE) (`class_index`). """ @@ -243,29 +243,29 @@ function get_tradeoff_metrics_binary_multirater(predicted_soft_labels, elected_h (; spearman_correlation=corr.ρ, spearman_correlation_ci_upper=corr.ci_upper, spearman_correlation_ci_lower=corr.ci_lower, n_samples=corr.n)) - return TradeoffMetricsRow(; row...) + return TradeoffMetricsV1(; row...) end """ get_hardened_metrics(predicted_hard_labels, elected_hard_labels, class_index; class_labels=missing) -Return [`HardenedMetricsRow`] calculated for the given `class_index`, with the following +Return [`HardenedMetricsV1`] calculated for the given `class_index`, with the following field guaranteed to be non-missing: expert-algorithm agreement (`ea_kappa`). 
""" function get_hardened_metrics(predicted_hard_labels, elected_hard_labels, class_index; class_labels=missing) - return HardenedMetricsRow(; class_index, class_labels, - ea_kappa=_calculate_ea_kappa(predicted_hard_labels, - elected_hard_labels, - class_index)) + return HardenedMetricsV1(; class_index, class_labels, + ea_kappa=_calculate_ea_kappa(predicted_hard_labels, + elected_hard_labels, + class_index)) end """ get_hardened_metrics_multirater(predicted_hard_labels, elected_hard_labels, class_index; class_labels=missing) -Return [`HardenedMetricsRow`] calculated for the given `class_index`. In addition +Return [`HardenedMetricsV1`] calculated for the given `class_index`. In addition to metrics calculated by [`get_hardened_metrics`](@ref), additionally calculates `discrimination_calibration_curve` and `discrimination_calibration_score`. """ @@ -278,14 +278,14 @@ function get_hardened_metrics_multirater(predicted_hard_labels, elected_hard_lab row = Tables.rowmerge(basic_row, (; discrimination_calibration_curve=cal.plot_curve_data, discrimination_calibration_score=cal.mse)) - return HardenedMetricsRow(; row...) + return HardenedMetricsV1(; row...) end """ get_hardened_metrics_multiclass(predicted_hard_labels, elected_hard_labels, class_count; class_labels=missing) -Return [`HardenedMetricsRow`] calculated over all `class_count` classes. Calculates +Return [`HardenedMetricsV1`] calculated over all `class_count` classes. Calculates expert-algorithm agreement (`ea_kappa`) over all classes, as well as the multiclass `confusion_matrix`. """ @@ -293,17 +293,17 @@ function get_hardened_metrics_multiclass(predicted_hard_labels, elected_hard_lab class_count; class_labels=missing) ea_kappa = first(cohens_kappa(class_count, zip(predicted_hard_labels, elected_hard_labels))) - return HardenedMetricsRow(; class_index=:multiclass, class_labels, - confusion_matrix=confusion_matrix(class_count, - zip(predicted_hard_labels, - elected_hard_labels)), - ea_kappa) + return HardenedMetricsV1(; class_index=:multiclass, class_labels, + confusion_matrix=confusion_matrix(class_count, + zip(predicted_hard_labels, + elected_hard_labels)), + ea_kappa) end """ get_label_metrics_multirater(votes, class_index; class_labels=missing) -Return [`LabelMetricsRow`] calculated for the given `class_index`, with the following +Return [`LabelMetricsV1`] calculated for the given `class_index`, with the following field guaranteed to be non-missing: `per_expert_discrimination_calibration_curves`, `per_expert_discrimination_calibration_scores`, interrater-agreement (`ira_kappa`). """ @@ -314,23 +314,23 @@ function get_label_metrics_multirater(votes, class_index; class_labels=missing) class_of_interest_index=class_index) per_expert_discrimination_calibration_curves = expert_cal.plot_curve_data per_expert_discrimination_calibration_scores = expert_cal.mse - return LabelMetricsRow(; class_index, class_labels, - per_expert_discrimination_calibration_curves, - per_expert_discrimination_calibration_scores, - ira_kappa=_calculate_ira_kappa(votes, class_index)) + return LabelMetricsV1(; class_index, class_labels, + per_expert_discrimination_calibration_curves, + per_expert_discrimination_calibration_scores, + ira_kappa=_calculate_ira_kappa(votes, class_index)) end """ get_label_metrics_multirater_multiclass(votes, class_count; class_labels=missing) -Return [`LabelMetricsRow`] calculated over all `class_count` classes. Calculates +Return [`LabelMetricsV1`] calculated over all `class_count` classes. 
Calculates the multiclass interrater agreement (`ira_kappa`). """ function get_label_metrics_multirater_multiclass(votes, class_count; class_labels=missing) size(votes, 2) > 1 || throw(ArgumentError("Input `votes` is not multirater (`size(votes) == $(size(votes))`)")) - return LabelMetricsRow(; class_index=:multiclass, class_labels, - ira_kappa=_calculate_ira_kappa_multiclass(votes, class_count)) + return LabelMetricsV1(; class_index=:multiclass, class_labels, + ira_kappa=_calculate_ira_kappa_multiclass(votes, class_count)) end ##### @@ -340,30 +340,30 @@ end """ evaluation_metrics(args...; optimal_threshold_class=nothing, kwargs...) -Return [`evaluation_metrics_row`](@ref) after converting output `EvaluationRow` -into a `Dict`. For argument details, see [`evaluation_metrics_row`](@ref). +Return [`evaluation_metrics_record`](@ref) after converting output `EvaluationV1` +into a `Dict`. For argument details, see [`evaluation_metrics_record`](@ref). """ function evaluation_metrics(args...; optimal_threshold_class=nothing, kwargs...) - row = evaluation_metrics_row(args...; - optimal_threshold_class=something(optimal_threshold_class, - missing), kwargs...) - return _evaluation_row_dict(row) -end - -""" - evaluation_metrics_row(observation_table, classes, thresholds=0.0:0.01:1.0; - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Missing,Nothing,Integer}=missing) - evaluation_metrics_row(predicted_hard_labels::AbstractVector, - predicted_soft_labels::AbstractMatrix, - elected_hard_labels::AbstractVector, - classes, - thresholds=0.0:0.01:1.0; - votes::Union{Nothing,Missing,AbstractMatrix}=nothing, - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Missing,Nothing,Integer}=missing) - -Returns `EvaluationRow` containing a battery of classifier performance + row = evaluation_metrics_record(args...; + optimal_threshold_class=something(optimal_threshold_class, + missing), kwargs...) + return _evaluation_dict(row) +end + +""" + evaluation_metrics_record(observation_table, classes, thresholds=0.0:0.01:1.0; + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Missing,Nothing,Integer}=missing) + evaluation_metrics_record(predicted_hard_labels::AbstractVector, + predicted_soft_labels::AbstractMatrix, + elected_hard_labels::AbstractVector, + classes, + thresholds=0.0:0.01:1.0; + votes::Union{Nothing,Missing,AbstractMatrix}=nothing, + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Missing,Nothing,Integer}=missing) + +Returns `EvaluationV1` containing a battery of classifier performance metrics that each compare `predicted_soft_labels` and/or `predicted_hard_labels` agaist `elected_hard_labels`. @@ -396,36 +396,37 @@ Where... ignored and new `predicted_hard_labels` will be recalculated from the new threshold. This is only a valid parameter when `length(classes) == 2` -Alternatively, an `observation_table` that consists of rows of type [`ObservationRow`](@ref) +Alternatively, an `observation_table` that consists of rows of type [`ObservationV1`](@ref) can be passed in in place of `predicted_soft_labels`,`predicted_hard_labels`,`elected_hard_labels`, and `votes`. $(BINARIZE_NOTE). See also [`evaluation_metrics_plot`](@ref). 
""" -function evaluation_metrics_row(observation_table, classes, thresholds=0.0:0.01:1.0; - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Missing,Nothing,Integer}=missing, - binarize=binarize_by_threshold) +function evaluation_metrics_record(observation_table, classes, thresholds=0.0:0.01:1.0; + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Missing,Nothing,Integer}=missing, + binarize=binarize_by_threshold) inputs = _observation_table_to_inputs(observation_table) - return evaluation_metrics_row(inputs.predicted_hard_labels, - inputs.predicted_soft_labels, inputs.elected_hard_labels, - classes, thresholds; inputs.votes, strata, - optimal_threshold_class, binarize) -end - -function evaluation_metrics_row(predicted_hard_labels::AbstractVector, - predicted_soft_labels::AbstractMatrix, - elected_hard_labels::AbstractVector, classes, - thresholds=0.0:0.01:1.0; - votes::Union{Nothing,Missing,AbstractMatrix}=nothing, - strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, - optimal_threshold_class::Union{Missing,Nothing,Integer}=missing, - binarize=binarize_by_threshold) + return evaluation_metrics_record(inputs.predicted_hard_labels, + inputs.predicted_soft_labels, + inputs.elected_hard_labels, + classes, thresholds; inputs.votes, strata, + optimal_threshold_class, binarize) +end + +function evaluation_metrics_record(predicted_hard_labels::AbstractVector, + predicted_soft_labels::AbstractMatrix, + elected_hard_labels::AbstractVector, classes, + thresholds=0.0:0.01:1.0; + votes::Union{Nothing,Missing,AbstractMatrix}=nothing, + strata::Union{Nothing,AbstractVector{Set{T}} where T}=nothing, + optimal_threshold_class::Union{Missing,Nothing,Integer}=missing, + binarize=binarize_by_threshold) class_labels = string.(collect(classes)) # Plots.jl expects this to be an `AbstractVector` class_indices = 1:length(classes) # Step 1: Calculate all metrics that do not require hardened predictions - # In our `evaluation_metrics_row` we special-case multirater binary classification, + # In our `evaluation_metrics_record` we special-case multirater binary classification, # so do that here as well. tradeoff_metrics_rows = if length(classes) == 2 && has_value(votes) map(ic -> get_tradeoff_metrics_binary_multirater(predicted_soft_labels, @@ -482,7 +483,7 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, # Step 4: Calculate all metrics derived directly from labels (does not depend on # predictions) - labels_metrics_table = LabelMetricsRow[] + labels_metrics_table = LabelMetricsV1[] if has_value(votes) && size(votes, 2) > 1 labels_metrics_table = map(c -> get_label_metrics_multirater(votes, c), class_indices) @@ -498,9 +499,9 @@ function evaluation_metrics_row(predicted_hard_labels::AbstractVector, elected_hard_labels, length(classes), strata) : missing - return _evaluation_row(tradeoff_metrics_rows, hardened_metrics_table, - labels_metrics_table; optimal_threshold_class, class_labels, - thresholds, optimal_threshold, stratified_kappas) + return _evaluation_record(tradeoff_metrics_rows, hardened_metrics_table, + labels_metrics_table; optimal_threshold_class, class_labels, + thresholds, optimal_threshold, stratified_kappas) end function _split_classes_from_multiclass(table) @@ -523,32 +524,39 @@ end function _values_or_missing(values) has_value(values) || return missing - return all(ismissing, values) ? 
missing : values + return if Base.IteratorSize(values) == Base.HasShape{0}() + values + else + T = nonmissingtype(eltype(values)) + all(!ismissing, values) ? convert(Array{T}, values) : missing + end end -_unpack_curves(curve::Union{Missing,Curve}) = ismissing(curve) ? missing : Tuple(curve) +_unpack_curves(curve::Missing) = missing +_unpack_curves(curve::Curve) = Tuple(curve) _unpack_curves(curves::AbstractVector{Curve}) = Tuple.(curves) """ - _evaluation_row(tradeoff_metrics_table, hardened_metrics_table, label_metrics_table; - optimal_threshold_class=missing, class_labels, thresholds, - optimal_threshold, stratified_kappas=missing) + _evaluation_record(tradeoff_metrics_table, hardened_metrics_table, label_metrics_table; + optimal_threshold_class=missing, class_labels, thresholds, + optimal_threshold, stratified_kappas=missing) -Helper function to create an `EvaluationRow` from tables of constituent Metrics schemas, -to support [`evaluation_metrics_row`](@ref): -- `tradeoff_metrics_table`: table of [`TradeoffMetricsRow`](@ref)s -- `hardened_metrics_table`: table of [`HardenedMetricsRow`](@ref)s -- `label_metrics_table`: table of [`LabelMetricsRow`](@ref)s +Helper function to create an `EvaluationV1` from tables of constituent Metrics schemas, +to support [`evaluation_metrics_record`](@ref): +- `tradeoff_metrics_table`: table of [`TradeoffMetricsV1`](@ref)s +- `hardened_metrics_table`: table of [`HardenedMetricsV1`](@ref)s +- `label_metrics_table`: table of [`LabelMetricsV1`](@ref)s """ -function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table, - label_metrics_table; optimal_threshold_class=missing, class_labels, - thresholds, optimal_threshold, stratified_kappas=missing) +function _evaluation_record(tradeoff_metrics_table, hardened_metrics_table, + label_metrics_table; optimal_threshold_class=missing, + class_labels, + thresholds, optimal_threshold, stratified_kappas=missing) tradeoff_rows, _ = _split_classes_from_multiclass(tradeoff_metrics_table) hardened_rows, hardened_multi = _split_classes_from_multiclass(hardened_metrics_table) label_rows, labels_multi = _split_classes_from_multiclass(label_metrics_table) # Due to special casing, the following metrics should only be present - # in the resultant `EvaluationRow` if `optimal_threshold_class` is present + # in the resultant `EvaluationV1` if `optimal_threshold_class` is present discrimination_calibration_curve = missing discrimination_calibration_score = missing if has_value(optimal_threshold_class) @@ -559,7 +567,7 @@ function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table, end # Similarly, the following metrics should only be present - # in the resultant `EvaluationRow` when doing multirater evaluation + # in the resultant `EvaluationV1` when doing multirater evaluation per_expert_discrimination_calibration_curves = missing per_expert_discrimination_calibration_scores = missing if has_value(label_rows) && has_value(optimal_threshold_class) @@ -585,30 +593,29 @@ function _evaluation_row(tradeoff_metrics_table, hardened_metrics_table, ci_lower=row.spearman_correlation_ci_lower, ci_upper=row.spearman_correlation_ci_upper) end - return EvaluationRow(; - # ...from hardened_metrics_table - confusion_matrix=_values_or_missing(hardened_multi.confusion_matrix), - multiclass_kappa=_values_or_missing(hardened_multi.ea_kappa), - per_class_kappas=_values_or_missing(hardened_rows.ea_kappa), - discrimination_calibration_curve=_unpack_curves(discrimination_calibration_curve), - discrimination_calibration_score, - - # 
...from tradeoff_metrics_table - per_class_roc_curves=_unpack_curves(_values_or_missing(tradeoff_rows.roc_curve)), - per_class_roc_aucs=_values_or_missing(tradeoff_rows.roc_auc), - per_class_pr_curves=_unpack_curves(_values_or_missing(tradeoff_rows.pr_curve)), - spearman_correlation, - per_class_reliability_calibration_curves=_unpack_curves(_values_or_missing(tradeoff_rows.reliability_calibration_curve)), - per_class_reliability_calibration_scores=_values_or_missing(tradeoff_rows.reliability_calibration_score), - - # from label_metrics_table - per_expert_discrimination_calibration_curves, - multiclass_IRA_kappas, per_class_IRA_kappas, - per_expert_discrimination_calibration_scores, - - # from kwargs: - optimal_threshold_class=_values_or_missing(optimal_threshold_class), - class_labels, thresholds, optimal_threshold, stratified_kappas) + return EvaluationV1(; # ...from hardened_metrics_table + confusion_matrix=_values_or_missing(hardened_multi.confusion_matrix), + multiclass_kappa=_values_or_missing(hardened_multi.ea_kappa), + per_class_kappas=_values_or_missing(hardened_rows.ea_kappa), + discrimination_calibration_curve=_unpack_curves(discrimination_calibration_curve), + discrimination_calibration_score, + + # ...from tradeoff_metrics_table + per_class_roc_curves=_unpack_curves(_values_or_missing(tradeoff_rows.roc_curve)), + per_class_roc_aucs=_values_or_missing(tradeoff_rows.roc_auc), + per_class_pr_curves=_unpack_curves(_values_or_missing(tradeoff_rows.pr_curve)), + spearman_correlation, + per_class_reliability_calibration_curves=_unpack_curves(_values_or_missing(tradeoff_rows.reliability_calibration_curve)), + per_class_reliability_calibration_scores=_values_or_missing(tradeoff_rows.reliability_calibration_score), + + # from label_metrics_table + per_expert_discrimination_calibration_curves, + multiclass_IRA_kappas, per_class_IRA_kappas, + per_expert_discrimination_calibration_scores, + + # from kwargs: + optimal_threshold_class=_values_or_missing(optimal_threshold_class), + class_labels, thresholds, optimal_threshold, stratified_kappas) end ##### diff --git a/src/plotting.jl b/src/plotting.jl index 2d66606..77c5874 100644 --- a/src/plotting.jl +++ b/src/plotting.jl @@ -44,7 +44,8 @@ function high_contrast(background_color::Colorant, target_color::Colorant; # chose from whole lightness spectrum lchoices=range(0; stop=100, length=15)) target = LCHab(target_color) - color = distinguishable_colors(1, [RGB(background_color)]; dropseed=true, lchoices=lchoices, + color = distinguishable_colors(1, [RGB(background_color)]; dropseed=true, + lchoices=lchoices, cchoices=[target.c], hchoices=[target.h]) return RGBAf(color[1], Makie.Colors.alpha(target_color)) end @@ -64,11 +65,13 @@ A series of XYVectors, or a single xyvector. 
const SeriesCurves = Union{XYVector,AbstractVector{<:XYVector}} function series_plot!(subfig::FigurePosition, per_class_pr_curves::SeriesCurves, - class_labels::Union{Nothing,AbstractVector{String}}; legend=:lt, title="No title", - xlabel="x label", ylabel="y label", solid_color=nothing, color=nothing, + class_labels::Union{Nothing,AbstractVector{String}}; legend=:lt, + title="No title", + xlabel="x label", ylabel="y label", solid_color=nothing, + color=nothing, linewidth=nothing, scatter=NamedTuple()) - - axis_theme = get_theme(subfig, :SeriesPlot, :Axis; title=title, titlealign=:left, xlabel=xlabel, + axis_theme = get_theme(subfig, :SeriesPlot, :Axis; title=title, titlealign=:left, + xlabel=xlabel, ylabel=ylabel, aspect=AxisAspect(1), xticks=0:0.2:1, yticks=0.2:0.2:1) @@ -79,7 +82,7 @@ function series_plot!(subfig::FigurePosition, per_class_pr_curves::SeriesCurves, isnothing(solid_color) || (series_theme[:solid_color] = solid_color) isnothing(color) || (series_theme[:color] = color) isnothing(linewidth) || (series_theme[:linewidth] = linewidth) - series_theme = merge(series_theme, Attributes(;scatter...)) + series_theme = merge(series_theme, Attributes(; scatter...)) hidedecorations!(ax; label=false, ticklabels=false, grid=false) limits!(ax, 0, 1, 0, 1) Makie.series!(ax, per_class_pr_curves; labels=class_labels, series_theme...) @@ -90,38 +93,47 @@ function series_plot!(subfig::FigurePosition, per_class_pr_curves::SeriesCurves, end function plot_pr_curves!(subfig::FigurePosition, per_class_pr_curves::SeriesCurves, - class_labels::Union{Nothing,AbstractVector{String}}; legend=:lt, title="PR curves", - xlabel="True positive rate", ylabel="Precision", scatter=NamedTuple(), + class_labels::Union{Nothing,AbstractVector{String}}; legend=:lt, + title="PR curves", + xlabel="True positive rate", ylabel="Precision", + scatter=NamedTuple(), solid_color=nothing) - return series_plot!(subfig, per_class_pr_curves, class_labels; legend=legend, title=title, xlabel=xlabel, + return series_plot!(subfig, per_class_pr_curves, class_labels; legend=legend, + title=title, xlabel=xlabel, ylabel=ylabel, scatter=scatter, solid_color=solid_color) end function plot_roc_curves!(subfig::FigurePosition, per_class_roc_curves::SeriesCurves, - per_class_roc_aucs::NumberVector, class_labels::AbstractVector{<:String}; + per_class_roc_aucs::NumberVector, + class_labels::AbstractVector{<:String}; legend=:rb, title="ROC curves", xlabel="False positive rate", ylabel="True positive rate") auc_labels = [@sprintf("%s (AUC: %.3f)", class, per_class_roc_aucs[i]) for (i, class) in enumerate(class_labels)] - return series_plot!(subfig, per_class_roc_curves, auc_labels; legend=legend, title=title, xlabel=xlabel, + return series_plot!(subfig, per_class_roc_curves, auc_labels; legend=legend, + title=title, xlabel=xlabel, ylabel=ylabel) end function plot_reliability_calibration_curves!(subfig::FigurePosition, per_class_reliability_calibration_curves::SeriesCurves, per_class_reliability_calibration_scores::NumberVector, - class_labels::AbstractVector{String}; legend=:rb) + class_labels::AbstractVector{String}; + legend=:rb) calibration_score_labels = map(enumerate(class_labels)) do (i, class) @sprintf("%s (MSE: %.3f)", class, per_class_reliability_calibration_scores[i]) end - scatter_theme = get_theme(subfig, :ReliabilityCalibrationCurves, :Scatter; marker=Circle, + scatter_theme = get_theme(subfig, :ReliabilityCalibrationCurves, :Scatter; + marker=Circle, markersize=5, strokewidth=0) - ideal_theme = get_theme(subfig, 
:ReliabilityCalibrationCurves, :Ideal; color=(:black, 0.5), + ideal_theme = get_theme(subfig, :ReliabilityCalibrationCurves, :Ideal; + color=(:black, 0.5), linestyle=:dash, linewidth=2) - ax = series_plot!(subfig, per_class_reliability_calibration_curves, calibration_score_labels; + ax = series_plot!(subfig, per_class_reliability_calibration_curves, + calibration_score_labels; legend=legend, title="Prediction reliability calibration", xlabel="Predicted probability bin", ylabel="Fraction of positives", scatter=scatter_theme) @@ -150,7 +162,8 @@ function plot_binary_discrimination_calibration_curves!(subfig::FigurePosition, discrimination_class::AbstractString; kw...) kw = values(kw) - scatter_theme = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, :Scatter; strokewidth=0) + scatter_theme = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, :Scatter; + strokewidth=0) # Hayaah, this theme merging is getting out of hand # but we want kw > BinaryDiscriminationCalibrationCurves > Scatter, so we need to somehow set things # after the theme merging above, especially, since we also pass those to series!, @@ -171,7 +184,8 @@ function plot_binary_discrimination_calibration_curves!(subfig::FigurePosition, per_expert...) end - calibration = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, :CalibrationCurve; + calibration = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, + :CalibrationCurve; solid_color=:navyblue, markerstrokewidth=0) set_from_kw!(calibration, :markersize, kw, 5) @@ -180,7 +194,8 @@ function plot_binary_discrimination_calibration_curves!(subfig::FigurePosition, Makie.series!(ax, calibration_curve; calibration...) - ideal_theme = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, :Ideal; color=(:black, 0.5), + ideal_theme = get_theme(subfig, :BinaryDiscriminationCalibrationCurves, :Ideal; + color=(:black, 0.5), linestyle=:dash) set_from_kw!(ideal_theme, :linewidth, kw, 2) linesegments!(ax, [0, 1], [0, 1]; label="Ideal", ideal_theme...) @@ -190,7 +205,8 @@ function plot_binary_discrimination_calibration_curves!(subfig::FigurePosition, end function plot_confusion_matrix!(subfig::FigurePosition, confusion::NumberMatrix, - class_labels::AbstractVector{String}, normalize_by::Union{Symbol,Nothing}=nothing; + class_labels::AbstractVector{String}, + normalize_by::Union{Symbol,Nothing}=nothing; annotation_text_size=20, colormap=:Blues) nclasses = length(class_labels) if size(confusion) != (nclasses, nclasses) @@ -220,10 +236,12 @@ function plot_confusion_matrix!(subfig::FigurePosition, confusion::NumberMatrix, ylims!(ax, nclasses + 0.5, 0.5) tightlimits!(ax) plot_bg_color = to_color(ax.backgroundcolor[]) - crange = isnothing(normalize_by) ? (0.0, maximum(filter(!isnan, confusion))) : (0.0, 1.0) + crange = isnothing(normalize_by) ? (0.0, maximum(filter(!isnan, confusion))) : + (0.0, 1.0) nan_color = to_color(heatmap_theme.nan_color[]) cmap = to_colormap(to_value(pop!(heatmap_theme, :colormap, colormap))) - heatmap!(ax, confusion'; colorrange=crange, colormap=cmap, nan_color=nan_color, heatmap_theme...) + heatmap!(ax, confusion'; colorrange=crange, colormap=cmap, nan_color=nan_color, + heatmap_theme...) 
text_color = to_color(to_value(pop!(text_theme, :color, :black))) function label_color(i, j) c = confusion[i, j] @@ -234,9 +252,11 @@ function plot_confusion_matrix!(subfig::FigurePosition, confusion::NumberMatrix, end return high_contrast(bg_color, text_color) end - annos = vec([(string(confusion[i, j]), Point2f(j, i)) for i in class_indices, j in class_indices]) + annos = vec([(string(confusion[i, j]), Point2f(j, i)) + for i in class_indices, j in class_indices]) colors = vec([label_color(i, j) for i in class_indices, j in class_indices]) - text!(ax, annos; align=(:center, :center), color=colors, textsize=annotation_text_size, text_theme...) + text!(ax, annos; align=(:center, :center), color=colors, textsize=annotation_text_size, + text_theme...) return ax end @@ -286,14 +306,16 @@ function plot_kappas!(subfig::FigurePosition, per_class_kappas::NumberVector, aligns, offsets, text_colors = text_attributes(per_class_kappas, 2, bar_colors, bg_color, text_color) barplot!(ax, per_class_kappas; direction=:x, color=bar_colors[2]) - text!(ax, annotations; align=aligns, offset=offsets, color=text_colors, text_theme...) + text!(ax, annotations; align=aligns, offset=offsets, color=text_colors, + text_theme...) else ax.title = "Inter-rater reliability" values = vcat(per_class_kappas, per_class_IRA_kappas) groups = vcat(fill(2, nclasses), fill(1, nclasses)) xvals = vcat(1:nclasses, 1:nclasses) cmap = bar_colors - bars = barplot!(ax, xvals, max.(0, values); dodge=groups, color=groups, direction=:x, + bars = barplot!(ax, xvals, max.(0, values); dodge=groups, color=groups, + direction=:x, colormap=cmap) # This is a bit hacky, but for now the easiest way to figure out the exact, dodged positions rectangles = bars.plots[][1][] @@ -319,16 +341,16 @@ end """ evaluation_metrics_plot(data::Dict; resolution=(1000, 1000), textsize=12) - evaluation_metrics_plot(row::EvaluationRow; kwargs...) + evaluation_metrics_plot(row::EvaluationV1; kwargs...) -Plot all evaluation metrics generated via [`evaluation_metrics_row`](@ref) and/or +Plot all evaluation metrics generated via [`evaluation_metrics_record`](@ref) and/or [`evaluation_metrics`](@ref) in a single image. """ function evaluation_metrics_plot(data::Dict; kwargs...) - return evaluation_metrics_plot(EvaluationRow(data); kwargs...) + return evaluation_metrics_plot(EvaluationV1(data); kwargs...) 
end -function evaluation_metrics_plot(row::EvaluationRow; resolution=(1000, 1000), +function evaluation_metrics_plot(row::EvaluationV1; resolution=(1000, 1000), textsize=12) fig = Figure(; resolution=resolution, Axis=(titlesize=17,)) @@ -379,7 +401,7 @@ function evaluation_metrics_plot(row::EvaluationRow; resolution=(1000, 1000), row.optimal_threshold, row.class_labels[row.optimal_threshold_class]) end - legend_plots = filter(Makie.MakieLayout.get_plots(ax)) do plot + legend_plots = filter(Makie.get_plots(ax)) do plot return haskey(plot, :label) end elements = map(legend_plots) do elem diff --git a/src/row.jl b/src/row.jl index 424b97d..a79e715 100644 --- a/src/row.jl +++ b/src/row.jl @@ -1,162 +1,133 @@ ##### -##### `EvaluationRow` +##### `EvaluationObject ##### # Arrow can't handle matrices---so when we write/read matrices, we have to pack and unpack them o_O # https://github.com/apache/arrow-julia/issues/125 vec_to_mat(mat::AbstractMatrix) = mat - function vec_to_mat(vec::AbstractVector) n = isqrt(length(vec)) return reshape(vec, n, n) end - vec_to_mat(x::Missing) = return missing -# Redefinition is workaround for https://github.com/beacon-biosignals/Legolas.jl/issues/9 -const EVALUATION_ROW_SCHEMA = Legolas.Schema("lighthouse.evaluation@1") +const GenericCurve = Tuple{Vector{Float64},Vector{Float64}} +@schema "lighthouse.evaluation" Evaluation +@version EvaluationV1 begin + class_labels::Union{Missing,Vector{String}} + confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix) + discrimination_calibration_curve::Union{Missing,GenericCurve} + discrimination_calibration_score::Union{Missing,Float64} + multiclass_IRA_kappas::Union{Missing,Float64} + multiclass_kappa::Union{Missing,Float64} + optimal_threshold::Union{Missing,Float64} + optimal_threshold_class::Union{Missing,Int64} + per_class_IRA_kappas::Union{Missing,Vector{Float64}} + per_class_kappas::Union{Missing,Vector{Float64}} + stratified_kappas::Union{Missing, + Vector{@NamedTuple{per_class::Vector{Float64}, + multiclass::Float64, + n::Int64}}} + per_class_pr_curves::Union{Missing,Vector{GenericCurve}} + per_class_reliability_calibration_curves::Union{Missing,Vector{GenericCurve}} + per_class_reliability_calibration_scores::Union{Missing,Vector{Float64}} + per_class_roc_aucs::Union{Missing,Vector{Float64}} + per_class_roc_curves::Union{Missing,Vector{GenericCurve}} + per_expert_discrimination_calibration_curves::Union{Missing,Vector{GenericCurve}} + per_expert_discrimination_calibration_scores::Union{Missing,Vector{Float64}} + spearman_correlation::Union{Missing, + @NamedTuple{ρ::Float64, # Note: is rho not 'p' 😢 + n::Int64, + ci_lower::Float64, + ci_upper::Float64}} + thresholds::Union{Missing,Vector{Float64}} +end + +""" + @version EvaluationV1 begin + class_labels::Union{Missing,Vector{String}} + confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix) + discrimination_calibration_curve::Union{Missing,GenericCurve} + discrimination_calibration_score::Union{Missing,Float64} + multiclass_IRA_kappas::Union{Missing,Float64} + multiclass_kappa::Union{Missing,Float64} + optimal_threshold::Union{Missing,Float64} + optimal_threshold_class::Union{Missing,Int64} + per_class_IRA_kappas::Union{Missing,Vector{Float64}} + per_class_kappas::Union{Missing,Vector{Float64}} + stratified_kappas::Union{Missing, + Vector{@NamedTuple{per_class::Vector{Float64}, + multiclass::Float64, + n::Int64}}} + per_class_pr_curves::Union{Missing,Vector{GenericCurve}} + 
per_class_reliability_calibration_curves::Union{Missing,Vector{GenericCurve}} + per_class_reliability_calibration_scores::Union{Missing,Vector{Float64}} + per_class_roc_aucs::Union{Missing,Vector{Float64}} + per_class_roc_curves::Union{Missing,Vector{GenericCurve}} + per_expert_discrimination_calibration_curves::Union{Missing,Vector{GenericCurve}} + per_expert_discrimination_calibration_scores::Union{Missing,Vector{Float64}} + spearman_correlation::Union{Missing, + @NamedTuple{ρ::Float64, # Note: is rho not 'p' 😢 + n::Int64, + ci_lower::Float64, + ci_upper::Float64}} + thresholds::Union{Missing,Vector{Float64}} + end + +A Legolas record representing the output metrics computed by +[`evaluation_metrics_record`](@ref) and [`evaluation_metrics`](@ref). +See [Legolas.jl](https://github.com/beacon-biosignals/Legolas.jl) for details regarding +Legolas record types. """ - const EvaluationRow = Legolas.@row("lighthouse.evaluation@1", - class_labels::Union{Missing,Vector{String}}, - confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix), - discrimination_calibration_curve::Union{Missing, - Tuple{Vector{Float64}, - Vector{Float64}}}, - discrimination_calibration_score::Union{Missing,Float64}, - multiclass_IRA_kappas::Union{Missing,Float64}, - multiclass_kappa::Union{Missing,Float64}, - optimal_threshold::Union{Missing,Float64}, - optimal_threshold_class::Union{Missing,Int64}, - per_class_IRA_kappas::Union{Missing,Vector{Float64}}, - per_class_kappas::Union{Missing,Vector{Float64}}, - stratified_kappas::Union{Missing, - Vector{NamedTuple{(:per_class, - :multiclass, - :n), - Tuple{Vector{Float64}, - Float64, - Int64}}}}, - per_class_pr_curves::Union{Missing, - Vector{Tuple{Vector{Float64}, - Vector{Float64}}}}, - per_class_reliability_calibration_curves::Union{Missing, - Vector{Tuple{Vector{Float64}, - Vector{Float64}}}}, - per_class_reliability_calibration_scores::Union{Missing, - Vector{Float64}}, - per_class_roc_aucs::Union{Missing,Vector{Float64}}, - per_class_roc_curves::Union{Missing, - Vector{Tuple{Vector{Float64}, - Vector{Float64}}}}, - per_expert_discrimination_calibration_curves::Union{Missing, - Vector{Tuple{Vector{Float64}, - Vector{Float64}}}}, - per_expert_discrimination_calibration_scores::Union{Missing, - Vector{Float64}}, - spearman_correlation::Union{Missing, - NamedTuple{(:ρ, :n, - :ci_lower, - :ci_upper), - Tuple{Float64, - Int64, - Float64, - Float64}}}, - thresholds::Union{Missing,Vector{Float64}}) - EvaluationRow(evaluation_row_dict::Dict{String, Any}) -> EvaluationRow - -A type alias for [`Legolas.Row{typeof(Legolas.Schema("lighthouse.evaluation@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing the output metrics computed by [`evaluation_metrics_row`](@ref) and -[`evaluation_metrics`](@ref). - -Constructor that takes `evaluation_row_dict` converts [`evaluation_metrics`](@ref) -`Dict` of metrics results (e.g. from Lighthouse EvaluationV1 + +`Dict` of metrics results (e.g. from Lighthouse v for (k, v) in pairs(evaluation_row_dict))...) - return EvaluationRow(row) +function EvaluationV1(d::Dict) + row = (; (Symbol(k) => v for (k, v) in pairs(d))...) 
+    return EvaluationV1(row) end """ -    _evaluation_row_dict(row::EvaluationRow) -> Dict{String,Any} +    _evaluation_row_dict(row::EvaluationV1) -> Dict{String,Any} -Convert [`EvaluationRow`](@ref) into `::Dict{String, Any}` results, as are +Convert [`EvaluationV1`](@ref) into `::Dict{String, Any}` results, as are -output by `[`evaluation_metrics`](@ref)` (and predated use of `EvaluationRow` in +output by `[`evaluation_metrics`](@ref)` (and predated use of `EvaluationV1` in Lighthouse <v0.14.0). """ -function _evaluation_row_dict(row::EvaluationRow) +function _evaluation_dict(row::EvaluationV1) return Dict(string(k) => v for (k, v) in pairs(NamedTuple(row)) if !ismissing(v)) end ##### -##### `ObservationRow` +##### `ObservationV1` ##### -# Redefinition is workaround for https://github.com/beacon-biosignals/Legolas.jl/issues/9 -const OBSERVATION_ROW_SCHEMA = Legolas.Schema("lighthouse.observation@1") +@schema "lighthouse.observation" Observation +@version ObservationV1 begin + predicted_hard_label::Int64 + predicted_soft_labels::Vector{Float32} + elected_hard_label::Int64 + votes::Union{Missing,Vector{Int64}} +end + """ - const ObservationRow = Legolas.@row("lighthouse.observation@1", - predicted_hard_label::Int64, - predicted_soft_labels::Vector{Float32}, - elected_hard_label::Int64, - votes::Union{Missing,Vector{Int64}}) - -A type alias for [`Legolas.Row{typeof(Legolas.Schema("lighthouse.observation@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing the per-observation input values required to compute [`evaluation_metrics_row`](@ref). + @version ObservationV1 begin + predicted_hard_label::Int64 + predicted_soft_labels::Vector{Float32} + elected_hard_label::Int64 + votes::Union{Missing,Vector{Int64}} + end + +A Legolas record representing the per-observation input values required to compute +[`evaluation_metrics_record`](@ref). """ -const ObservationRow = Legolas.@row("lighthouse.observation@1", predicted_hard_label::Int64, - predicted_soft_labels::Vector{Float32}, - elected_hard_label::Int64, - votes::Union{Missing,Vector{Int64}}) +ObservationV1 # Convert vector of per-class soft label vectors to expected matrix format, e.g., # [[0.1, .2, .7], [0.8, .1, .1]] for 2 observations of 3-class classification returns @@ -169,7 +140,7 @@ function _predicted_soft_to_matrix(per_observation_soft_labels) end function _observation_table_to_inputs(observation_table) - Legolas.validate(observation_table, OBSERVATION_ROW_SCHEMA) + Legolas.validate(Tables.schema(observation_table), ObservationV1SchemaVersion()) df_table = Tables.columns(observation_table) votes = missing if any(ismissing, df_table.votes) && !all(ismissing, df_table.votes) @@ -188,7 +159,7 @@ function _inputs_to_observation_table(; predicted_hard_labels::AbstractVector, elected_hard_labels::AbstractVector, votes::Union{Nothing,Missing,AbstractMatrix}=nothing) votes_itr = has_value(votes) ?
eachrow(votes) : - (missing for _ in 1:length(predicted_hard_labels)) + Iterators.repeated(missing, length(predicted_hard_labels)) predicted_soft_labels_itr = eachrow(predicted_soft_labels) if !(length(predicted_hard_labels) == length(predicted_soft_labels_itr) == @@ -200,15 +171,14 @@ function _inputs_to_observation_table(; predicted_hard_labels::AbstractVector, predicted_soft_labels_itr, votes_itr) do predicted_hard_label, elected_hard_label, predicted_soft_labels, votes - return ObservationRow(; predicted_hard_label, elected_hard_label, - predicted_soft_labels, votes) + return ObservationV1(; predicted_hard_label, elected_hard_label, + predicted_soft_labels, votes) end - Legolas.validate(observation_table, OBSERVATION_ROW_SCHEMA) return observation_table end ##### -##### Metrics rows +##### Metrics ##### """ @@ -252,111 +222,104 @@ const CURVE_ARROW_NAME = Symbol("JuliaLang.Lighthouse.Curve") ArrowTypes.arrowname(::Type{<:Curve}) = CURVE_ARROW_NAME ArrowTypes.JuliaType(::Val{CURVE_ARROW_NAME}) = Curve +@schema "lighthouse.class" Class +@version ClassV1 begin + class_index::Union{Int64,Symbol} = check_valid_class(class_index) + class_labels::Union{Missing,Vector{String}} +end + """ - const ClassRow = Legolas.@row("lighthouse.class@1", - class_labels::Union{Missing,Vector{String}}, - class_index::Union{Int64,Symbol}) + @version ClassV1 begin + class_index::Union{Int64,Symbol} = check_valid_class(class_index) + class_labels::Union{Missing,Vector{String}} + end -A type alias for [`Legolas.Row{typeof(Legolas.Schema("lighthouse.class@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing a single column `class_index` that holds either an integer or the value -`:multiclass`, and the class names associated to the integer class indices. +A Legolas record representing a single column `class_index` that holds either an integer or +the value `:multiclass`, and the class names associated to the integer class indices. """ -const ClassRow = Legolas.@row("lighthouse.class@1", - class_index::Union{Int64,Symbol} = check_valid_class(class_index), - class_labels::Union{Missing,Vector{String}} = coalesce(class_labels, - missing)) +ClassV1 check_valid_class(class_index::Integer) = Int64(class_index) function check_valid_class(class_index::Any) return class_index === :multiclass ? class_index : - throw(ArgumentError("Classes must be integers or the symbol `:multiclass`")) + throw(ArgumentError("Classes must be integer or the symbol `:multiclass`")) +end + +@schema "lighthouse.label-metrics" LabelMetrics +@version LabelMetricsV1 > ClassV1 begin + ira_kappa::Union{Missing,Float64} + per_expert_discrimination_calibration_curves::Union{Missing,Vector{Curve}} = lift(v -> Curve.(v), + per_expert_discrimination_calibration_curves) + per_expert_discrimination_calibration_scores::Union{Missing,Vector{Float64}} end """ - LabelMetricsRow = Legolas.@row("lighthouse.label-metrics@1" > "lighthouse.class@1", - ira_kappa::Union{Missing,Float64}, - per_expert_discrimination_calibration_curves::Union{Missing, - Vector{Curve}} = ismissing(per_expert_discrimination_calibration_curves) ? - missing : - Curve.(per_expert_discrimination_calibration_curves), - per_expert_discrimination_calibration_scores::Union{Missing, - Vector{Float64}}) - -A type alias for [`Legolas.Row{typeof(Legolas.Schema("label-metrics@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing metrics calculated over labels provided by multiple labelers. 
+ @version LabelMetricsV1 > ClassV1 begin + ira_kappa::Union{Missing,Float64} + per_expert_discrimination_calibration_curves::Union{Missing,Vector{Curve}} = lift(v -> Curve.(v), + per_expert_discrimination_calibration_curves) + per_expert_discrimination_calibration_scores::Union{Missing,Vector{Float64}} + end + +A Legolas record representing metrics calculated over labels provided by multiple labelers. See also [`get_label_metrics_multirater`](@ref) and [`get_label_metrics_multirater_multiclass`](@ref). """ -const LabelMetricsRow = Legolas.@row("lighthouse.label-metrics@1" > "lighthouse.class@1", - ira_kappa::Union{Missing,Float64}, - per_expert_discrimination_calibration_curves::Union{Missing,Vector{Curve}} = ismissing(per_expert_discrimination_calibration_curves) ? - missing : - Curve.(per_expert_discrimination_calibration_curves), - per_expert_discrimination_calibration_scores::Union{Missing, - Vector{Float64}}) +LabelMetricsV1 + +@schema "lighthouse.hardened-metrics" HardenedMetrics +@version HardenedMetricsV1 > ClassV1 begin + confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix) + discrimination_calibration_curve::Union{Missing,Curve} = lift(Curve, + discrimination_calibration_curve) + discrimination_calibration_score::Union{Missing,Float64} + ea_kappa::Union{Missing,Float64} +end """ - HardenedMetricsRow = Legolas.@row("lighthouse.hardened-metrics@1" > - "lighthouse.class@1", - confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix), - discrimination_calibration_curve::Union{Missing,Curve} = ismissing(discrimination_calibration_curve) ? - missing : - Curve(discrimination_calibration_curve), - discrimination_calibration_score::Union{Missing,Float64}, - ea_kappa::Union{Missing,Float64}) - -A type alias for [`Legolas.Row{typeof(Legolas.Schema("hardened-metrics@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing metrics calculated over predicted hard labels. + @version HardenedMetricsV1 > ClassV1 begin + confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix) + discrimination_calibration_curve::Union{Missing,Curve} = lift(Curve, + discrimination_calibration_curve) + discrimination_calibration_score::Union{Missing,Float64} + ea_kappa::Union{Missing,Float64} + end + +A Legolas record representing metrics calculated over predicted hard labels. See also [`get_hardened_metrics`](@ref), [`get_hardened_metrics_multirater`](@ref), and [`get_hardened_metrics_multiclass`](@ref). """ -const HardenedMetricsRow = Legolas.@row("lighthouse.hardened-metrics@1" > - "lighthouse.class@1", - confusion_matrix::Union{Missing,Array{Int64}} = vec_to_mat(confusion_matrix), - discrimination_calibration_curve::Union{Missing,Curve} = ismissing(discrimination_calibration_curve) ? 
- missing : - Curve(discrimination_calibration_curve), - discrimination_calibration_score::Union{Missing, - Float64}, - ea_kappa::Union{Missing,Float64}) +HardenedMetricsV1 + +@schema "lighthouse.tradeoff-metrics" TradeoffMetrics +@version TradeoffMetricsV1 > ClassV1 begin + roc_curve::Curve = lift(Curve, roc_curve) + roc_auc::Float64 + pr_curve::Curve = lift(Curve, pr_curve) + spearman_correlation::Union{Missing,Float64} + spearman_correlation_ci_upper::Union{Missing,Float64} + spearman_correlation_ci_lower::Union{Missing,Float64} + n_samples::Union{Missing,Int} + reliability_calibration_curve::Union{Missing,Curve} = lift(Curve, + reliability_calibration_curve) + reliability_calibration_score::Union{Missing,Float64} +end """ - TradeoffMetricsRow = Legolas.@row("lighthouse.tradeoff-metrics@1" > - "lighthouse.class@1", - roc_curve::Curve = ismissing(roc_curve) ? - missing : Curve(roc_curve), - roc_auc::Float64, - pr_curve::Curve = ismissing(pr_curve) ? - missing : Curve(pr_curve), - spearman_correlation::Union{Missing, Float64}, - spearman_correlation_ci_upper::Union{Missing, Float64}, - spearman_correlation_ci_lower::Union{Missing, Float64}, - n_samples::Union{Missing,Int}, - reliability_calibration_curve::Union{Missing, - Curve} = ismissing(reliability_calibration_curve) ? - missing : - Curve(reliability_calibration_curve), - reliability_calibration_score::Union{Missing, Float64}) - -A type alias for [`Legolas.Row{typeof(Legolas.Schema("tradeoff-metrics@1"))}`](https://beacon-biosignals.github.io/Legolas.jl/stable/#Legolas.@row) -representing metrics calculated over predicted soft labels. -See also [`get_tradeoff_metrics`](@ref) and [`get_tradeoff_metrics_binary_multirater`](@ref). + @version TradeoffMetricsV1 > ClassV1 begin + roc_curve::Curve = lift(Curve, roc_curve) + roc_auc::Float64 + pr_curve::Curve = lift(Curve, pr_curve) + spearman_correlation::Union{Missing,Float64} + spearman_correlation_ci_upper::Union{Missing,Float64} + spearman_correlation_ci_lower::Union{Missing,Float64} + n_samples::Union{Missing,Int} + reliability_calibration_curve::Union{Missing,Curve} = lift(Curve, + reliability_calibration_curve) + reliability_calibration_score::Union{Missing,Float64} + end + +A Legolas record representing metrics calculated over predicted soft labels. See also +[`get_tradeoff_metrics`](@ref) and [`get_tradeoff_metrics_binary_multirater`](@ref). """ -const TradeoffMetricsRow = Legolas.@row("lighthouse.tradeoff-metrics@1" > - "lighthouse.class@1", - roc_curve::Curve = ismissing(roc_curve) ? missing : - Curve(roc_curve), - roc_auc::Float64, - pr_curve::Curve = ismissing(pr_curve) ? missing : - Curve(pr_curve), - spearman_correlation::Union{Missing,Float64}, - spearman_correlation_ci_upper::Union{Missing, - Float64}, - spearman_correlation_ci_lower::Union{Missing, - Float64}, - n_samples::Union{Missing,Int}, - reliability_calibration_curve::Union{Missing,Curve} = ismissing(reliability_calibration_curve) ? 
-                                                                                                missing :
-                                                                                                Curve(reliability_calibration_curve),
-                                         reliability_calibration_score::Union{Missing,
-                                                                              Float64})
+TradeoffMetricsV1
diff --git a/test/deprecations.jl b/test/deprecations.jl
new file mode 100644
index 0000000..1b86991
--- /dev/null
+++ b/test/deprecations.jl
@@ -0,0 +1,3 @@
+@testset "deprecations" begin
+    @test_throws ErrorException Lighthouse.evaluation_metrics_row()
+end
diff --git a/test/learn.jl b/test/learn.jl
index 6cefc8a..b562ae5 100644
--- a/test/learn.jl
+++ b/test/learn.jl
@@ -22,6 +22,25 @@ function Lighthouse.loss_and_prediction(c::TestClassifier, dummy_input_batch)
     return c.dummy_loss, dummy_soft_label_batch
 end
 
+@testset "`_values_or_missing`" begin
+    @test Lighthouse._values_or_missing(nothing) === missing
+    @test Lighthouse._values_or_missing(missing) === missing
+    @test Lighthouse._values_or_missing(1) === 1
+    @test Lighthouse._values_or_missing([1, 2, 3]) == [1, 2, 3]
+    @test Lighthouse._values_or_missing([missing]) === missing
+    @test Lighthouse._values_or_missing([1, missing]) === missing
+
+    input = Union{Int,Missing}[1, 2, 3]
+    result = Lighthouse._values_or_missing(input)
+    @test result == input
+    @test result isa Vector{Int}
+
+    input = Union{Int,Missing}[1 2; 3 4]
+    result = Lighthouse._values_or_missing(input)
+    @test result == input
+    @test result isa Matrix{Int}
+end
+
 @testset "Multi-class learn!(::TestModel, ...)" begin
     mktempdir() do tmpdir
         model = TestClassifier(1000000.0, ["class_$i" for i in 1:5])
@@ -100,9 +119,9 @@ end
         @test length(logger.logged["wheeeeeee/metrics_for_all_time"]) == 1
 
         # Test plotting with no votes directly with eval row
-        eval_row = Lighthouse.evaluation_metrics_row(predicted_hard, predicted_soft,
-                                                     elected_hard, model.classes;
-                                                     votes=nothing)
+        eval_row = Lighthouse.evaluation_metrics_record(predicted_hard, predicted_soft,
+                                                        elected_hard, model.classes;
+                                                        votes=nothing)
         all_together_no_ira = evaluation_metrics_plot(eval_row)
         @testplot all_together_no_ira
 
@@ -180,7 +199,7 @@ end
         all_together_2 = evaluation_metrics_plot(plot_data)
         @testplot all_together_2
 
-        all_together_3 = evaluation_metrics_plot(EvaluationRow(plot_data))
+        all_together_3 = evaluation_metrics_plot(EvaluationV1(plot_data))
         @testplot all_together_3
         #savefig(all_together_2, "/tmp/multiclass.png")
 
@@ -314,7 +333,7 @@ end
 
         # Test binary discrimination with no multiclass votes
         plot_data_1["per_expert_discrimination_calibration_curves"] = missing
-        no_expert_calibration = evaluation_metrics_plot(EvaluationRow(plot_data_1))
+        no_expert_calibration = evaluation_metrics_plot(EvaluationV1(plot_data_1))
         @testplot no_expert_calibration
 
         # Test that plotting succeeds (no specialization relative to the multi-class tests)
diff --git a/test/metrics.jl b/test/metrics.jl
index 1195985..e118010 100644
--- a/test/metrics.jl
+++ b/test/metrics.jl
@@ -212,7 +212,8 @@ end
     scaled_metrics = get_tradeoff_metrics(predicted_soft_labels, elected_hard_labels,
                                           i_class; thresholds=scaled_thresholds,
-                                          binarize=scaled_binarize_by_threshold, class_labels)
+                                          binarize=scaled_binarize_by_threshold,
+                                          class_labels)
     @test isequal(default_metrics, scaled_metrics)
 
     # Discrim calibration
diff --git a/test/plotting.jl b/test/plotting.jl
index 00e5458..947deaa 100644
--- a/test/plotting.jl
+++ b/test/plotting.jl
@@ -2,60 +2,53 @@ using Makie.Colors: Gray
 
 @testset "plotting" begin
     @testset "NaN color" begin
-        confusion = [
-            NaN 0;
-            1.0 0.5
-        ]
+        confusion = [NaN 0;
+                     1.0 0.5]
         nan_confusion = plot_confusion_matrix(confusion, ["test1", "test2"], :Row)
        @testplot nan_confusion
 
-        nan_custom_confusion = 
with_theme(ConfusionMatrix = (Heatmap=(nan_color=:red,), Text=(color=:red,))) do - plot_confusion_matrix(confusion, ["test1", "test2"], :Row) + nan_custom_confusion = with_theme(; + ConfusionMatrix=(Heatmap=(nan_color=:red,), + Text=(color=:red,))) do + return plot_confusion_matrix(confusion, ["test1", "test2"], :Row) end @testplot nan_custom_confusion end @testset "Kappa placement" begin classes = ["class $i" for i in 1:5] - kappa_text_placement = with_theme(Kappas = (Text=(color=Gray(0.5),),)) do - plot_kappas((1:5) ./ 5 .- 0.1, classes, (1:5) ./ 5, color = [Gray(0.4), Gray(0.2)]) + kappa_text_placement = with_theme(; Kappas=(Text=(color=Gray(0.5),),)) do + return plot_kappas((1:5) ./ 5 .- 0.1, classes, (1:5) ./ 5; + color=[Gray(0.4), Gray(0.2)]) end @testplot kappa_text_placement - kappa_text_placement_single = with_theme(Kappas = (Text=(color=:red,),)) do - plot_kappas((1:5) ./ 5, classes, color = [Gray(0.4), Gray(0.2)]) + kappa_text_placement_single = with_theme(; Kappas=(Text=(color=:red,),)) do + return plot_kappas((1:5) ./ 5, classes; color=[Gray(0.4), Gray(0.2)]) end @testplot kappa_text_placement_single end @testset "binary discriminiation calibration curves" begin rng = StableRNG(22) - curves = [(LinRange(0, 1, 10), range(0, stop=i/2, length=10) .+ (randn(rng, 10) .* 0.1)) for i in -1:3] - binary_discrimination_calibration_curves_plot = with_theme( - BinaryDiscriminationCalibrationCurves = ( - Ideal = ( - linewidth = 3, - color = (:green, 0.5) - ), - CalibrationCurve = ( - solid_color = :green, - markersize = 50, # should be overwritten by user kw - linewidth = 5, - ), - PerExpert = ( - solid_color = :red, - linewidth=1 - ), - ) - ) do - Lighthouse.plot_binary_discrimination_calibration_curves( - curves[3], - rand(rng, 5), - curves[[1, 2, 4, 5]], - nothing, nothing, - "", - markersize=10 - ) + curves = [(LinRange(0, 1, 10), + range(0; stop=i / 2, length=10) .+ (randn(rng, 10) .* 0.1)) + for i in -1:3] + theme = (; Ideal=(linewidth=3, color=(:green, 0.5)), + CalibrationCurve=(solid_color=:green, + markersize=50, # should be overwritten by user kw + linewidth=5), + PerExpert=(solid_color=:red, linewidth=1)) + plot = with_theme(; BinaryDiscriminationCalibrationCurves=theme) do + return Lighthouse.plot_binary_discrimination_calibration_curves(curves[3], + rand(rng, 5), + curves[[1, 2, 4, + 5]], + nothing, + nothing, + ""; + markersize=10) end + binary_discrimination_calibration_curves_plot = plot @testplot binary_discrimination_calibration_curves_plot end end diff --git a/test/row.jl b/test/row.jl index c833d0e..94d9158 100644 --- a/test/row.jl +++ b/test/row.jl @@ -6,8 +6,8 @@ @test_throws DimensionMismatch Lighthouse.vec_to_mat(collect(1:6)) # Invalid dimensions end -@testset "`EvaluationRow` basics" begin - # Most EvaluationRow testing happens via the `test_evaluation_metrics_roundtrip` +@testset "`EvaluationV1` basics" begin + # Most Evaluation testing happens via the `test_evaluation_metrics_roundtrip` # in test/learn.jl # Roundtrip from dict @@ -29,14 +29,14 @@ function test_roundtrip_observation_table(; kwargs...) return table end -@testset "`ObservationRow`" begin +@testset "`ObservationV1`" begin # Multiclass num_observations = 100 classes = ["A", "B", "C", "D"] predicted_soft_labels = rand(StableRNG(22), Float32, num_observations, length(classes)) predicted_hard_labels = map(argmax, eachrow(predicted_soft_labels)) - # Single labeler: round-trip `ObservationRow``... + # Single labeler: round-trip `ObservationV1`... 
elected_hard_one_labeller = predicted_hard_labels[[1:50..., 1:50...]] # Force 50% TP overall votes = missing table = test_roundtrip_observation_table(; predicted_soft_labels, predicted_hard_labels, @@ -44,14 +44,14 @@ end votes) # ...and parity in evaluation_metrics calculation: - metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels, - predicted_soft_labels, - elected_hard_one_labeller, - classes; votes) - metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes) + metrics_from_inputs = Lighthouse.evaluation_metrics_record(predicted_hard_labels, + predicted_soft_labels, + elected_hard_one_labeller, + classes; votes) + metrics_from_table = Lighthouse.evaluation_metrics_record(table, classes) @test isequal(metrics_from_inputs, metrics_from_table) - # Multiple labelers: round-trip `ObservationRow``... + # Multiple labelers: round-trip `ObservationV1`... for num_voters in (1, 5) possible_vote_labels = collect(0:length(classes)) # vote 0 == "no vote" vote_rng = StableRNG(22) @@ -68,18 +68,19 @@ end votes) # ...is there parity in evaluation_metrics calculations? - metrics_from_inputs = Lighthouse.evaluation_metrics_row(predicted_hard_labels, - predicted_soft_labels, - elected_hard_multilabeller, - classes; votes) - metrics_from_table = Lighthouse.evaluation_metrics_row(table, classes) + metrics_from_inputs = Lighthouse.evaluation_metrics_record(predicted_hard_labels, + predicted_soft_labels, + elected_hard_multilabeller, + classes; votes) + metrics_from_table = Lighthouse.evaluation_metrics_record(table, classes) @test isequal(metrics_from_inputs, metrics_from_table) r_table = Lighthouse._inputs_to_observation_table(; predicted_soft_labels, predicted_hard_labels, elected_hard_labels=elected_hard_multilabeller, votes) - @test isnothing(Legolas.validate(r_table, Lighthouse.OBSERVATION_ROW_SCHEMA)) + @test Legolas.complies_with(Tables.schema(r_table), + Lighthouse.ObservationV1SchemaVersion()) # ...can we handle both dataframe input and more generic row iterators? 
df_table = DataFrame(r_table) @@ -104,18 +105,18 @@ end end end -@testset "`ClassRow" begin - @test isa(Lighthouse.ClassRow(; class_index=3, class_labels=missing).class_index, Int64) - @test isa(Lighthouse.ClassRow(; class_index=Int8(3), class_labels=missing).class_index, +@testset "`ClassV1" begin + @test isa(Lighthouse.ClassV1(; class_index=3, class_labels=missing).class_index, Int64) + @test isa(Lighthouse.ClassV1(; class_index=Int8(3), class_labels=missing).class_index, Int64) - @test Lighthouse.ClassRow(; class_index=:multiclass).class_index == :multiclass - @test Lighthouse.ClassRow(; class_index=:multiclass, - class_labels=["a", "b"]).class_labels == ["a", "b"] - - @test_throws ArgumentError Lighthouse.ClassRow(; class_index=3.0f0, - class_labels=missing) - @test_throws ArgumentError Lighthouse.ClassRow(; class_index=:mUlTiClAsS, - class_labels=missing) + @test Lighthouse.ClassV1(; class_index=:multiclass).class_index == :multiclass + @test Lighthouse.ClassV1(; class_index=:multiclass, + class_labels=["a", "b"]).class_labels == ["a", "b"] + + @test_throws ArgumentError Lighthouse.ClassV1(; class_index=3.0f0, + class_labels=missing) + @test_throws ArgumentError Lighthouse.ClassV1(; class_index=:mUlTiClAsS, + class_labels=missing) end @testset "class_labels" begin diff --git a/test/runtests.jl b/test/runtests.jl index f4e6c7e..c919808 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -28,20 +28,20 @@ macro testplot(fig_name) end end -const EVALUATION_ROW_KEYS = string.(keys(EvaluationRow())) +const EVALUATION_V1_KEYS = string.(fieldnames(EvaluationV1)) function test_evaluation_metrics_roundtrip(row_dict::Dict{String,S}) where {S} - # Make sure we're capturing all metrics keys in our Schema - keys_not_in_schema = setdiff(keys(row_dict), EVALUATION_ROW_KEYS) + # Make sure all metrics keys are captured in our schema and are not thrown away + keys_not_in_schema = setdiff(keys(row_dict), EVALUATION_V1_KEYS) @test isempty(keys_not_in_schema) # Do the roundtripping (will fail if schema types do not validate after roundtrip) - row = EvaluationRow(row_dict) - rt_row = roundtrip_row(row) + record = EvaluationV1(row_dict) + rt_row = roundtrip_row(record) # Make sure full row roundtrips correctly - @test issetequal(keys(row), keys(rt_row)) - for (k, v) in pairs(row) + @test issetequal(keys(record), keys(rt_row)) + for (k, v) in pairs(record) if ismissing(v) @test ismissing(rt_row[k]) else @@ -50,7 +50,7 @@ function test_evaluation_metrics_roundtrip(row_dict::Dict{String,S}) where {S} end # Make sure originating metrics dictionary roundtrips correctly - rt_dict = Lighthouse._evaluation_row_dict(rt_row) + rt_dict = Lighthouse._evaluation_dict(rt_row) for (k, v) in pairs(row_dict) if ismissing(v) @test ismissing(rt_dict[k]) @@ -61,11 +61,11 @@ function test_evaluation_metrics_roundtrip(row_dict::Dict{String,S}) where {S} return nothing end -function roundtrip_row(row::EvaluationRow) - p = mktempdir() * "rt_test.arrow" +function roundtrip_row(row::EvaluationV1) + io = IOBuffer() tbl = [row] - Legolas.write(p, tbl, Lighthouse.EVALUATION_ROW_SCHEMA) - return EvaluationRow(only(Tables.rows(Legolas.read(p)))) + Legolas.write(io, tbl, Lighthouse.EvaluationV1SchemaVersion()) + return EvaluationV1(only(Tables.rows(Legolas.read(seekstart(io))))) end include("plotting.jl")
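
For readers upgrading downstream code, a minimal sketch of how the record-based API introduced in this patch fits together. It uses only names that appear in the diff above (`ObservationV1`, `EvaluationV1`, `Lighthouse.evaluation_metrics_record`); the input data is illustrative, mirroring the shapes used in `test/row.jl`, and is not taken from any real evaluation.

    # Hedged sketch: assumes Lighthouse with this patch applied (Legolas 0.5 records).
    using Lighthouse
    using Lighthouse: ObservationV1

    num_observations = 100
    classes = ["A", "B", "C", "D"]
    predicted_soft_labels = rand(Float32, num_observations, length(classes))
    predicted_hard_labels = map(argmax, eachrow(predicted_soft_labels))
    elected_hard_labels = predicted_hard_labels[[1:50..., 1:50...]]  # synthetic elected labels

    # Build the per-observation records that replace the old `ObservationRow`:
    observations = [ObservationV1(; predicted_hard_label=predicted_hard_labels[i],
                                  predicted_soft_labels=predicted_soft_labels[i, :],
                                  elected_hard_label=elected_hard_labels[i],
                                  votes=missing) for i in 1:num_observations]

    # `evaluation_metrics_record` (the renamed `evaluation_metrics_row`) accepts either
    # an observation table or the raw inputs, and returns an `EvaluationV1` record:
    metrics_from_table = Lighthouse.evaluation_metrics_record(observations, classes)
    metrics_from_inputs = Lighthouse.evaluation_metrics_record(predicted_hard_labels,
                                                               predicted_soft_labels,
                                                               elected_hard_labels, classes;
                                                               votes=nothing)
    # The two results should compare equal with `isequal`, as asserted in `test/row.jl`.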
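
Similarly, a small sketch of the Legolas 0.5 write/read round-trip that `roundtrip_row` above relies on, shown here for `ObservationV1` because its required fields are declared in the schema hunk earlier in this patch. `Lighthouse.ObservationV1SchemaVersion()` is the same schema-version value the updated `test/row.jl` checks with `Legolas.complies_with`; passing it to `Legolas.write` validates the table before it is serialized to Arrow, replacing the old string-based `Legolas.Schema("lighthouse.observation@1")` workflow.

    # Hedged sketch: assumes Legolas 0.5 and the "lighthouse.observation" schema above.
    using Legolas, Tables
    using Lighthouse: ObservationV1

    records = [ObservationV1(; predicted_hard_label=1,
                             predicted_soft_labels=Float32[0.9, 0.1],
                             elected_hard_label=1,
                             votes=missing)]

    io = IOBuffer()
    # Writing against the schema version validates `records` before serialization:
    Legolas.write(io, records, Lighthouse.ObservationV1SchemaVersion())
    roundtripped = ObservationV1(only(Tables.rows(Legolas.read(seekstart(io)))))
    # `roundtripped` should be `isequal` to `records[1]`, field for field.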