Skip to content

Commit

Permalink
Add non-normalized confusion matrix option (#84)
Browse files Browse the repository at this point in the history
* add non-normalized confusion matrix option to lighthouse

* fix heatmap, docstring

* welp inverted

* clean up docs
  • Loading branch information
hannahilea authored Jul 1, 2022
1 parent 0540cdd commit bcb8a20
Show file tree
Hide file tree
Showing 3 changed files with 31 additions and 19 deletions.
4 changes: 2 additions & 2 deletions docs/src/plotting.md
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ground_truth = [1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4]
predicted_labels = [1, 1, 1, 1, 2, 2, 4, 4, 4, 4, 4, 3]
confusion = Lighthouse.confusion_matrix(length(classes), zip(predicted_labels, ground_truth))
fig, ax, p = plot_confusion_matrix(confusion, classes, :Row)
fig, ax, p = plot_confusion_matrix(confusion, classes)
```

```@example 1
Expand Down Expand Up @@ -278,4 +278,4 @@ Plots can also be generated directly from an `EvaluationRow`:
```@example 1
data_row = EvaluationRow(data)
evaluation_metrics_plot(data_row)
```
```
39 changes: 22 additions & 17 deletions src/plotting.jl
Original file line number Diff line number Diff line change
Expand Up @@ -180,22 +180,25 @@ function plot_binary_discrimination_calibration_curves!(subfig::FigurePosition,
end

function plot_confusion_matrix!(subfig::FigurePosition, confusion::NumberMatrix,
class_labels::AbstractVector{String}, normalize_by::Symbol;
class_labels::AbstractVector{String}, normalize_by::Union{Symbol,Nothing}=nothing;
annotation_text_size=20, colormap=:Blues)
normdim = get((Row=2, Column=1), normalize_by) do
return error("normalize_by must be either :Row or :Column, found: $(normalize_by)")
end

nclasses = length(class_labels)
if size(confusion) != (nclasses, nclasses)
error("Labels must match size of square confusion matrix. Found $(nclasses) labels for an $(size(confusion)) matrix")
throw(ArgumentError("Labels must match size of square confusion matrix. Found $(nclasses) labels for an $(size(confusion)) matrix"))
end
title = "Confusion Matrix"
if !isnothing(normalize_by)
normdim = get((Row=2, Column=1), normalize_by) do
throw(ArgumentError("normalize_by must be :Row, :Column, or `nothing`; found: $(normalize_by)"))
end
confusion = round.(confusion ./ sum(confusion; dims=normdim); digits=3)
title = "$(string(normalize_by))-Normalized Confusion"
end
confusion = round.(confusion ./ sum(confusion; dims=normdim); digits=3)
class_indices = 1:nclasses
text_theme = get_theme(subfig, :ConfusionMatrix, :Text; textsize=annotation_text_size)
heatmap_theme = get_theme(subfig, :ConfusionMatrix, :Heatmap; nan_color=(:black, 0.0))
axis_theme = get_theme(subfig, :ConfusionMatrix, :Axis; xticklabelrotation=pi / 4,
titlealign=:left, title="$(string(normalize_by))-Normalized Confusion",
titlealign=:left, title,
xlabel="Elected Class", ylabel="Predicted Class",
xticks=(class_indices, class_labels),
yticks=(class_indices, class_labels),
Expand All @@ -207,7 +210,7 @@ function plot_confusion_matrix!(subfig::FigurePosition, confusion::NumberMatrix,
ylims!(ax, nclasses + 0.5, 0.5)
tightlimits!(ax)
plot_bg_color = to_color(ax.backgroundcolor[])
crange = (0.0, 1.0)
crange = isnothing(normalize_by) ? (0.0, maximum(filter(!isnan, confusion))) : (0.0, 1.0)
nan_color = to_color(heatmap_theme.nan_color[])
cmap = to_colormap(to_value(pop!(heatmap_theme, :colormap, colormap)))
heatmap!(ax, confusion'; colorrange=crange, colormap=cmap, nan_color=nan_color, heatmap_theme...)
Expand Down Expand Up @@ -404,22 +407,24 @@ end
"""
plot_confusion_matrix!(subfig::FigurePosition, args...; kw...)
plot_confusion_matrix(confusion::AbstractMatrix{<: Number}, class_labels::AbstractVector{String}, normalize_by::Symbol;
resolution=(800,600),
annotation_text_size=20)
plot_confusion_matrix(confusion::AbstractMatrix{<: Number},
class_labels::AbstractVector{String},
normalize_by::Union{Symbol,Nothing}=nothing;
resolution=(800,600), annotation_text_size=20)
Lighthouse plots confusion matrices, which are simple tables
showing the empirical distribution of predicted class (the rows)
versus the elected class (the columns). These come in two variants:
versus the elected class (the columns). These can optionally be normalized:
* row-normalized: this means each row has been normalized to sum to 1. Thus, the row-normalized confusion matrix shows the empirical distribution of elected classes for a given predicted class. E.g. the first row of the row-normalized confusion matrix shows the empirical probabilities of the elected classes for a sample which was predicted to be in the first class.
* column-normalized: this means each column has been normalized to sum to 1. Thus, the column-normalized confusion matrix shows the empirical distribution of predicted classes for a given elected class. E.g. the first column of the column-normalized confusion matrix shows the empirical probabilities of the predicted classes for a sample which was elected to be in the first class.
* row-normalized (`:Row`): this means each row has been normalized to sum to 1. Thus, the row-normalized confusion matrix shows the empirical distribution of elected classes for a given predicted class. E.g. the first row of the row-normalized confusion matrix shows the empirical probabilities of the elected classes for a sample which was predicted to be in the first class.
* column-normalized (`:Column`): this means each column has been normalized to sum to 1. Thus, the column-normalized confusion matrix shows the empirical distribution of predicted classes for a given elected class. E.g. the first column of the column-normalized confusion matrix shows the empirical probabilities of the predicted classes for a sample which was elected to be in the first class.
```
fig, ax, p = plot_confusion_matrix(rand(2, 2), ["1", "2"], :Row)
fig, ax, p = plot_confusion_matrix(rand(2, 2), ["1", "2"])
fig = Figure()
ax = plot_confusion_matrix!(fig[1, 1], rand(2, 2), ["1", "2"], :Column)
ax = plot_confusion_matrix!(fig[1, 1], rand(2, 2), ["1", "2"], :Row)
ax = plot_confusion_matrix!(fig[1, 2], rand(2, 2), ["1", "2"], :Column)
```
"""
plot_confusion_matrix(args...; kw...) = axisplot(plot_confusion_matrix!, args; kw...)
Expand Down
7 changes: 7 additions & 0 deletions test/learn.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,13 @@ end
plot_data["class_labels"], :Column)
@testplot confusion_col

confusion_basic = plot_confusion_matrix(plot_data["confusion_matrix"],
plot_data["class_labels"])
@testplot confusion_basic

@test_throws ArgumentError plot_confusion_matrix(plot_data["confusion_matrix"],
plot_data["class_labels"], :norm)

all_together_2 = evaluation_metrics_plot(plot_data)
@testplot all_together_2

Expand Down

2 comments on commit bcb8a20

@hannahilea
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Error while trying to register: "Tag with name v0.14.10 already exists and points to a different commit"

Please sign in to comment.