Skip to content

Commit

Permalink
Add Base.get method for ModeResult (#2269)
Browse files Browse the repository at this point in the history
* Add Base.get method for ModeResult

* Make get(::ModeResult, itr) work for any iterator

* Fix array type in get(::ModeResult, ...)
  • Loading branch information
mhauru authored Jun 23, 2024
1 parent fcb4ca7 commit 7b2869f
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 1 deletion.
31 changes: 31 additions & 0 deletions src/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,37 @@ StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
StatsBase.loglikelihood(m::ModeResult) = m.lp

"""
Base.get(m::ModeResult, var_symbol::Symbol)
Base.get(m::ModeResult, var_symbols)
Return the values of all the variables with the symbol(s) `var_symbol` in the mode result
`m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The second
argument should be either a `Symbol` or an iterator of `Symbol`s.
"""
function Base.get(m::ModeResult, var_symbols)
log_density = m.f
# Get all the variable names in the model. This is the same as the list of keys in
# m.values, but they are more convenient to filter when they are VarNames rather than
# Symbols.
varnames = collect(
map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
)
# For each symbol s in var_symbols, pick all the values from m.values for which the
# variable name has that symbol.
et = eltype(m.values)
value_vectors = Vector{et}[]
for s in var_symbols
push!(
value_vectors,
[m.values[Symbol(vn)] for vn in varnames if DynamicPPL.getsym(vn) == s],
)
end
return (; zip(var_symbols, value_vectors)...)
end

Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))

"""
ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
Expand Down
27 changes: 26 additions & 1 deletion test/optimisation/Optimisation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,14 @@ using ..Models: gdemo, gdemo_default
using Distributions
using Distributions.FillArrays: Zeros
using DynamicPPL: DynamicPPL
using LinearAlgebra: I
using LinearAlgebra: Diagonal, I
using Random: Random
using Optimization
using Optimization: Optimization
using OptimizationBBO: OptimizationBBO
using OptimizationNLopt: OptimizationNLopt
using OptimizationOptimJL: OptimizationOptimJL
using ReverseDiff: ReverseDiff
using StatsBase: StatsBase
using StatsBase: coef, coefnames, coeftable, informationmatrix, stderror, vcov
using Test: @test, @testset, @test_throws
Expand Down Expand Up @@ -591,6 +592,30 @@ using Turing
@test result.values[:x] 0 atol = 1e-1
@test result.values[:y] 100 atol = 1e-1
end

@testset "get ModeResult" begin
@model function demo_model(N)
half_N = N ÷ 2
a ~ arraydist(LogNormal.(fill(0, half_N), 1))
b ~ arraydist(LogNormal.(fill(0, N - half_N), 1))
covariance_matrix = Diagonal(vcat(a, b))
x ~ MvNormal(covariance_matrix)
return nothing
end

N = 12
m = demo_model(N) | (x=randn(N),)
result = maximum_a_posteriori(m)
get_a = get(result, :a)
get_b = get(result, :b)
get_ab = get(result, [:a, :b])
@assert keys(get_a) == (:a,)
@assert keys(get_b) == (:b,)
@assert keys(get_ab) == (:a, :b)
@assert get_b[:b] == get_ab[:b]
@assert vcat(get_a[:a], get_b[:b]) == result.values.array
@assert get(result, :c) == (; :c => Array{Float64}[])
end
end

end

0 comments on commit 7b2869f

Please sign in to comment.