Skip to content

Commit

Permalink
Return NaN for negative ModeResult variance estimates
Browse files Browse the repository at this point in the history
  • Loading branch information
frankier committed Feb 10, 2025
1 parent 24d5556 commit 3ad099c
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 3 deletions.
47 changes: 44 additions & 3 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ using Printf: Printf
using ForwardDiff: ForwardDiff
using StatsAPI: StatsAPI
using Statistics: Statistics
using LinearAlgebra: LinearAlgebra

export maximum_a_posteriori, maximum_likelihood
# The MAP and MLE exports are only needed for the Optim.jl interface.
Expand Down Expand Up @@ -228,11 +229,47 @@ end

# Various StatsBase methods for ModeResult

function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
function StatsBase.coeftable(m::ModeResult; level::Real=0.95, numerrors_warnonly::Bool=true)
# Get columns for coeftable.
terms = string.(StatsBase.coefnames(m))
estimates = m.values.array[:, 1]
stderrors = StatsBase.stderror(m)
notes = nothing
local stderrors
if numerrors_warnonly
infmat = StatsBase.informationmatrix(m)
local vcov
try
vcov = inv(infmat)
catch e
if isa(e, LinearAlgebra.SingularException)
stderrors = fill(NaN, length(m.values))
notes = ["Info. matrix is singular" for _ in 1:length(m.values)]
else
rethrow(e)
end
else
stderrors = []
vars = LinearAlgebra.diag(vcov)
if any(x -> x < 0, vars)
notes = []
end
for var in vars
if var >= 0
push!(stderrors, sqrt(var))
if notes !== missing
push!(notes, "")
end
else
push!(stderrors, NaN)
if notes !== missing
push!(notes, "Negative variance")
end
end
end
end
else
stderrors = StatsBase.stderror(m)
end
zscore = estimates ./ stderrors
p = map(z -> StatsAPI.pvalue(Distributions.Normal(), z; tail=:both), zscore)

Expand All @@ -244,7 +281,7 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
level_ = 100 * level
level_percentage = isinteger(level_) ? Int(level_) : level_

cols = [estimates, stderrors, zscore, p, ci_low, ci_high]
cols = Vector[estimates, stderrors, zscore, p, ci_low, ci_high]
colnms = [
"Coef.",
"Std. Error",
Expand All @@ -253,6 +290,10 @@ function StatsBase.coeftable(m::ModeResult; level::Real=0.95)
"Lower $(level_percentage)%",
"Upper $(level_percentage)%",
]
if notes !== nothing
push!(cols, notes)
push!(colnms, "Error notes")
end
return StatsBase.CoefTable(cols, colnms, terms)
end

Expand Down
18 changes: 18 additions & 0 deletions test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -635,6 +635,24 @@ using Turing
maximum_a_posteriori(m; adtype=adbackend)
end
end

@testset "Collinear coeftable"
xs = [-1.0, 0.0, 1.0]
ys = [0.0, 0.0, 0.0]

@model function collinear(x, y)
a ~ Normal(0, 1)
b ~ Normal(0, 1)
y ~ MvNormal(a .* x .+ b .* x, 1)
end

model = collinear(xs, ys)
mle_estimate = Turing.Optimisation.estimate_mode(model, MLE())
tab = coeftable(mle_estimate)
@assert isnan(tab.cols[2][1])
@assert tab.colnms[end] == "Error notes"
@assert occursin("singular", tab.cols[end][1])
end
end

end

0 comments on commit 3ad099c

Please sign in to comment.