diff --git a/src/optimisation/Optimisation.jl b/src/optimisation/Optimisation.jl
index 682e664a6..82bda7da7 100644
--- a/src/optimisation/Optimisation.jl
+++ b/src/optimisation/Optimisation.jl
@@ -273,6 +273,37 @@ StatsBase.params(m::ModeResult) = StatsBase.coefnames(m)
 StatsBase.vcov(m::ModeResult) = inv(StatsBase.informationmatrix(m))
 StatsBase.loglikelihood(m::ModeResult) = m.lp
 
+"""
+    Base.get(m::ModeResult, var_symbol::Symbol)
+    Base.get(m::ModeResult, var_symbols)
+
+Return the values of all the variables with the symbol(s) `var_symbols` in the mode
+result `m`. The return value is a `NamedTuple` with `var_symbols` as the key(s). The
+second argument should be either a `Symbol` or an iterator of `Symbol`s.
+"""
+function Base.get(m::ModeResult, var_symbols)
+    log_density = m.f
+    # Get all the variable names in the model. This is the same as the list of keys in
+    # m.values, but they are more convenient to filter when they are VarNames rather than
+    # Symbols.
+    varnames = collect(
+        map(first, Turing.Inference.getparams(log_density.model, log_density.varinfo))
+    )
+    # For each symbol s in var_symbols, pick all the values from m.values for which the
+    # variable name has that symbol.
+    et = eltype(m.values)
+    value_vectors = Vector{et}[]
+    for s in var_symbols
+        push!(
+            value_vectors,
+            [m.values[Symbol(vn)] for vn in varnames if DynamicPPL.getsym(vn) == s],
+        )
+    end
+    return (; zip(var_symbols, value_vectors)...)
+end
+
+Base.get(m::ModeResult, var_symbol::Symbol) = get(m, (var_symbol,))
+
 """
     ModeResult(log_density::OptimLogDensity, solution::SciMLBase.OptimizationSolution)
 
diff --git a/test/optimisation/Optimisation.jl b/test/optimisation/Optimisation.jl
index 5e6144e57..76d3a940d 100644
--- a/test/optimisation/Optimisation.jl
+++ b/test/optimisation/Optimisation.jl
@@ -4,13 +4,14 @@ using ..Models: gdemo, gdemo_default
 using Distributions
 using Distributions.FillArrays: Zeros
 using DynamicPPL: DynamicPPL
-using LinearAlgebra: I
+using LinearAlgebra: Diagonal, I
 using Random: Random
 using Optimization
 using Optimization: Optimization
 using OptimizationBBO: OptimizationBBO
 using OptimizationNLopt: OptimizationNLopt
 using OptimizationOptimJL: OptimizationOptimJL
+using ReverseDiff: ReverseDiff
 using StatsBase: StatsBase
 using StatsBase: coef, coefnames, coeftable, informationmatrix, stderror, vcov
 using Test: @test, @testset, @test_throws
@@ -591,6 +592,30 @@ using Turing
         @test result.values[:x] ≈ 0 atol = 1e-1
         @test result.values[:y] ≈ 100 atol = 1e-1
     end
+
+    @testset "get ModeResult" begin
+        @model function demo_model(N)
+            half_N = N ÷ 2
+            a ~ arraydist(LogNormal.(fill(0, half_N), 1))
+            b ~ arraydist(LogNormal.(fill(0, N - half_N), 1))
+            covariance_matrix = Diagonal(vcat(a, b))
+            x ~ MvNormal(covariance_matrix)
+            return nothing
+        end
+
+        N = 12
+        m = demo_model(N) | (x=randn(N),)
+        result = maximum_a_posteriori(m)
+        get_a = get(result, :a)
+        get_b = get(result, :b)
+        get_ab = get(result, [:a, :b])
+        @test keys(get_a) == (:a,)
+        @test keys(get_b) == (:b,)
+        @test keys(get_ab) == (:a, :b)
+        @test get_b[:b] == get_ab[:b]
+        @test vcat(get_a[:a], get_b[:b]) == result.values.array
+        @test get(result, :c) == (; :c => Array{Float64}[])
+    end
 end
 
 end
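
Usage sketch (not part of the patch): a minimal illustration of the new `Base.get`
overload, assuming this diff is applied. `toy_model` and its variable names are
hypothetical, invented only for this example.

    using Turing

    @model function toy_model()
        s ~ arraydist(LogNormal.(zeros(2), 1))
        m ~ Normal(0, 1)
        return nothing
    end

    result = maximum_a_posteriori(toy_model())

    # A single Symbol collects every component of that variable into one vector:
    get(result, :s)        # (s = [s_mode_1, s_mode_2],)
    # An iterator of Symbols returns one vector per symbol, keyed by symbol:
    get(result, (:s, :m))  # (s = [s_mode_1, s_mode_2], m = [m_mode])

Because the implementation filters `varnames` in model order, each per-symbol vector
preserves the ordering of `result.values`, which is what the
`vcat(get_a[:a], get_b[:b]) == result.values.array` test above checks.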