Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.

Already on GitHub? Sign in to your account

Getting Dists from Distributions.jl #54

Merged
merged 10 commits into from
Apr 15, 2024
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -147,3 +147,5 @@ scratch/separable/path.jl

test/ABP.png
ABP.png

output
Empty file added output/.gitignore
Empty file.
2 changes: 1 addition & 1 deletion src/algorithms.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module Algorithms

using Distributions: Categorical
import Distributions

using ..Scruff
using ..Utils
Expand Down
3 changes: 1 addition & 2 deletions src/algorithms/sample_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import Base.length

using Distributions: Categorical
using StatsFuns: logsumexp
using ..SFuncs: Cat

Expand Down Expand Up @@ -92,7 +91,7 @@ end
function resample(ps::Particles, target_num_particles::Int = length(ps.samples))
lnws = normalize_weights(ps.log_weights)
weights = exp.(lnws)
selections = rand(Categorical(weights/sum(weights)), target_num_particles)
selections = rand(Distributions.Categorical(weights/sum(weights)), target_num_particles)
samples = map(selections) do ind
ps.samples[ind]
end
Expand Down
10 changes: 10 additions & 0 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,17 @@ module Operators

using ..Scruff

export
VectorOption,
Option

"""VectorOption{T} = Union{Vector{Union{}}, Vector{T}}"""
VectorOption{T} = Union{Vector{Union{}}, Vector{T}}
"""Option{T} = Union{Nothing, T}"""
Option{T} = Union{Nothing, T}

include("operators/op_performance.jl")
include("operators/op_defs.jl")
include("operators/op_impls.jl")

end
70 changes: 11 additions & 59 deletions src/operators/op_defs.jl
Original file line number Diff line number Diff line change
@@ -1,86 +1,38 @@
using ..MultiInterface

export
importance_sample,
support_quality_rank,
support_quality_from_rank,
VectorOption,
Option

# These are the new operator definitions where the type signature is specified

"""VectorOption{T} = Union{Vector{Union{}}, Vector{T}}"""
VectorOption{T} = Union{Vector{Union{}}, Vector{T}}
"""Option{T} = Union{Nothing, T}"""
Option{T} = Union{Nothing, T}

# to support
MultiInterface.get_imp(::Nothing, args...) = nothing

@interface forward(sf::SFunc{I,O}, i::I)::Dist{O} where {I,O}
@interface inverse(sf::SFunc{I,O}, o::O)::Score{I} where {I,O}
@interface is_deterministic(sf::SFunc)::Bool
@interface sample(sf::SFunc{I,O}, i::I)::O where {I,O}
@interface sample_logcpdf(sf::SFunc{I,O}, i::I)::Tuple{O, AbstractFloat} where {I,O}
@interface sample_logcpdf(sf::SFunc{I,O}, i::I)::Tuple{O, <:AbstractFloat} where {I,O}
# @interface invert(sf::SFunc{I,O}, o::O)::I where {I,O}
@interface lambda_msg(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, O})::SFunc{<:Option{Tuple{}}, I} where {I,O}
@interface marginalize(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, I})::SFunc{<:Option{Tuple{}}, O} where {I,O}
@interface marginalize(sfb::SFunc{X, Y}, sfa::SFunc{Y, Z})::SFunc{X, Z} where {X, Y, Z}
@interface logcpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
@interface cpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
@interface log_cond_prob_plus_c(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
@interface f_expectation(sf::SFunc{I,O}, i::I, fn::Function) where {I,O}
@interface expectation(sf::SFunc{I,O}, i::I)::O where {I,O}
# Expectation (and others) should either return some continuous relaxation of O (e.g. Ints -> Float) or there should be another op that does
@interface expectation(sf::SFunc{I,O}, i::I) where {I,O}
@interface variance(sf::SFunc{I,O}, i::I)::O where {I,O}
@interface get_score(sf::SFunc{Tuple{I},O}, i::I)::AbstractFloat where {I,O}
@interface get_log_score(sf::SFunc{Tuple{I},O}, i::I)::AbstractFloat where {I,O}

@impl begin
struct SFuncExpectation end

function expectation(sf::SFunc{I,O}, i::I) where {I,O}
return f_expectation(sf, i, x -> x)
end
end
# Return a new SFunc that is the result of summing samples from each constituent SFunc
@interface sumsfs(fs::NTuple{N, <:SFunc{I, O}})::SFunc{I, O} where {N, I, O}
@interface fit_mle(t::Type{S}, dat::SFunc{I, O})::S where {I, O, S <: SFunc{I, O}}
@interface support_minimum(sf::SFunc{I, O}, i::I)::O where {I, O}
@interface support_maximum(sf::SFunc{I, O}, i::I)::O where {I, O}

@interface support(sf::SFunc{I,O},
parranges::NTuple{N,Vector},
size::Integer,
curr::Vector{<:O}) where {I,O,N}

function importance_sample end

"""
    support_quality_rank(sq::Symbol)

Map a support quality symbol to an integer rank so qualities can be
compared: `:CompleteSupport` => 3, `:IncrementalSupport` => 2, any
other symbol => 1.
"""
function support_quality_rank(sq::Symbol)
    sq == :CompleteSupport && return 3
    sq == :IncrementalSupport && return 2
    return 1
end

"""
    support_quality_from_rank(rank::Int)

Convert an integer rank back into the corresponding support quality
symbol (inverse of `support_quality_rank`): 3 => `:CompleteSupport`,
2 => `:IncrementalSupport`, anything else => `:BestEffortSupport`.
"""
function support_quality_from_rank(rank::Int)
    if rank == 3
        return :CompleteSupport
    elseif rank == 2
        return :IncrementalSupport
    else
        # Bug fix: was `:BestEffortSupport()`, which parses as a call on a
        # Symbol and throws a MethodError at runtime. Return the bare
        # symbol, consistent with the other branches.
        return :BestEffortSupport
    end
end

@interface support_quality(sf::SFunc, parranges)

@impl begin
struct SFuncSupportQuality end

function support_quality(s::SFunc, parranges)
:BestEffortSupport
end
end

@interface bounded_probs(sf::SFunc{I,O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector})::Tuple{Vector{<:AbstractFloat}, Vector{<:AbstractFloat}} where {I,O,N}
Expand All @@ -91,7 +43,7 @@ end
id,
parids::Tuple)::Tuple{Vector{<:Scruff.Utils.Factor}, Vector{<:Scruff.Utils.Factor}} where {I,O,N}

#= Statistics computation not included in the release
#= Statistics computation not finished
@interface initial_stats(sf::SFunc)

# TODO create an abstract type Stats{I,O}
Expand All @@ -105,6 +57,7 @@ end
@interface accumulate_stats(sf::SFunc, existing_stats, new_stats)
@interface maximize_stats(sf::SFunc, stats)
=#

@interface compute_bel(sf::SFunc{I,O},
range::VectorOption{<:O},
pi::Dist{<:O},
Expand Down Expand Up @@ -142,4 +95,3 @@ end
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_idx::Integer)::Score where {N,I,O}

42 changes: 42 additions & 0 deletions src/operators/op_impls.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
export
support_quality_rank,
support_quality_from_rank

# Default implementation of the `expectation` operator: reduce it to
# `f_expectation` with the identity function, so any SFunc implementing
# `f_expectation` gets `expectation` for free.
@impl begin
    struct SFuncExpectation end

    function expectation(sf::SFunc{I,O}, i::I) where {I,O}
        return f_expectation(sf, i, x -> x)
    end
end

"""
    support_quality_rank(sq::Symbol)

Convert a support quality symbol into an integer for comparison:
`:CompleteSupport` ranks highest (3), then `:IncrementalSupport` (2);
any other symbol ranks lowest (1).
"""
function support_quality_rank(sq::Symbol)
    return sq == :CompleteSupport ? 3 :
           sq == :IncrementalSupport ? 2 : 1
end

"""
    support_quality_from_rank(rank::Int)

Convert an integer rank back into the corresponding support quality
symbol (inverse of `support_quality_rank`): 3 => `:CompleteSupport`,
2 => `:IncrementalSupport`, anything else => `:BestEffortSupport`.
"""
function support_quality_from_rank(rank::Int)
    if rank == 3
        return :CompleteSupport
    elseif rank == 2
        return :IncrementalSupport
    else
        # Bug fix: was `:BestEffortSupport()`, which parses as a call on a
        # Symbol and throws a MethodError at runtime. Return the bare
        # symbol, consistent with `support_quality_rank`'s else case.
        return :BestEffortSupport
    end
end

# Fallback implementation of the `support_quality` operator: report the
# weakest guarantee, `:BestEffortSupport`, for any SFunc that does not
# provide a more specific answer.
@impl begin
    struct SFuncSupportQuality end

    function support_quality(s::SFunc, parranges)
        :BestEffortSupport
    end
end

32 changes: 3 additions & 29 deletions src/sfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,40 +14,14 @@ macro impl(expr)
end

include("sfuncs/dist/dist.jl")
include("sfuncs/dist/cat.jl")
include("sfuncs/dist/constant.jl")
include("sfuncs/dist/flip.jl")
include("sfuncs/dist/normal.jl")
include("sfuncs/dist/uniform.jl")

include("sfuncs/score/score.jl")
include("sfuncs/score/hardscore.jl")
include("sfuncs/score/softscore.jl")
include("sfuncs/score/multiplescore.jl")
include("sfuncs/score/logscore.jl")
include("sfuncs/score/functionalscore.jl")
include("sfuncs/score/normalscore.jl")
include("sfuncs/score/parzen.jl")

include("sfuncs/util/extend.jl")

include("sfuncs/conddist/conditional.jl")
include("sfuncs/conddist/det.jl")
include("sfuncs/conddist/invertible.jl")
include("sfuncs/conddist/table.jl")
include("sfuncs/conddist/discretecpt.jl")
include("sfuncs/conddist/lineargaussian.jl")
include("sfuncs/conddist/CLG.jl")
include("sfuncs/conddist/separable.jl")
include("sfuncs/conddist/switch.jl")

include("sfuncs/compound/generate.jl")
include("sfuncs/compound/apply.jl")
include("sfuncs/compound/chain.jl")
include("sfuncs/compound/mixture.jl")
include("sfuncs/compound/network.jl")
include("sfuncs/compound/serial.jl")
include("sfuncs/compound/expander.jl")
include("sfuncs/conddist/conddist.jl")

include("sfuncs/compound/compound.jl")

include("sfuncs/op_impls/bp_ops.jl")
include("sfuncs/op_impls/basic_ops.jl")
Expand Down
8 changes: 8 additions & 0 deletions src/sfuncs/compound/compound.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
include("generate.jl")
include("apply.jl")
include("chain.jl")
include("mixture.jl")
include("network.jl")
include("serial.jl")
include("expander.jl")
include("sum.jl")
4 changes: 2 additions & 2 deletions src/sfuncs/compound/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,7 @@ end
struct MixtureSample end
function sample(sf::Mixture{I,O}, x::I)::O where {I,O}
probs = sf.probabilities/sum(sf.probabilities)
cat = Categorical(probs)
cat = Distributions.Categorical(probs)
which_component = rand(cat)
component = sf.components[which_component]
return sample(component, x)
Expand All @@ -222,7 +222,7 @@ end

function expectation(sf::Mixture{I,O}, x::I)::O where {I,O}
probs = sf.probabilities/sum(sf.probabilities)
cat = Categorical(probs)
cat = Distributions.Categorical(probs)
which_component = rand(cat)
component = sf.components[which_component]
return expectation(component, x)
Expand Down
17 changes: 17 additions & 0 deletions src/sfuncs/compound/sum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""
    SumSF{I, O, SFs}

SFunc representing the sum of a tuple of constituent SFuncs over the same
input/output types: sampling a `SumSF` draws one sample from each
constituent and adds them, so its distribution is the convolution of the
constituents' distributions.
"""
# Fix: the tuple-length parameter of `NTuple{N, T}` is an `Int`, so the
# original bound `NTuple{<:Number, <:SFunc{I, O}}` was misleading;
# `Tuple{Vararg{SFunc{I, O}}}` is the idiomatic equivalent (any length,
# every element an `SFunc{I, O}`).
struct SumSF{I, O, SFs <: Tuple{Vararg{SFunc{I, O}}}} <: SFunc{I, O}
    sfs::SFs
end

# Implementation of the `sumsfs` operator: combine a tuple of SFuncs into
# a single `SumSF` whose samples are the sums of samples from each
# constituent.
@impl begin
    function sumsfs(fs::NTuple{N, <:SFunc}) where {N}
        # Return an SFunc representing g(x) = f1(x) + f2(x) + ...
        # I.e. convolution of the respective densities
        return SumSF(fs)
    end
end

# Sample a `SumSF` by drawing one sample from every constituent SFunc
# (each given the same input `x`) and summing the results.
@impl begin
    function sample(sf::SumSF, x)
        return sum(sample(sub_sf, x) for sub_sf in sf.sfs)
    end
end
9 changes: 9 additions & 0 deletions src/sfuncs/conddist/conddist.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
include("conditional.jl")
include("det.jl")
include("invertible.jl")
include("table.jl")
include("discretecpt.jl")
include("lineargaussian.jl")
include("CLG.jl")
include("separable.jl")
include("switch.jl")
7 changes: 3 additions & 4 deletions src/sfuncs/conddist/lineargaussian.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,8 @@ export LinearGaussian

See also: [`Conditional`](@ref), [`Normal`](@ref)
"""
mutable struct LinearGaussian{I <: Tuple{Vararg{Float64}}} <:
Conditional{I, Tuple{}, I, Float64, Normal}
sf :: Normal
mutable struct LinearGaussian{I <: Tuple{Vararg{Float64}}} <: Conditional{I, Tuple{}, I, Float64, Normal{Float64}}
sf :: Normal{Float64}
params :: Tuple{Tuple{Vararg{Float64}}, Float64, Float64}
"""
function LinearGaussian(weights :: Tuple{Vararg{Float64}}, bias :: Float64, sd :: Float64)
Expand Down Expand Up @@ -45,7 +44,7 @@ end
end
=#

function gensf(lg::LinearGaussian, inputs::Tuple{Vararg{Float64}})::Normal
function gensf(lg::LinearGaussian, inputs::Tuple{Vararg{Float64}})::Normal{Float64}
(weights, bias, sd) = lg.params
to_sum = inputs .* weights
mean = isempty(to_sum) ? bias : sum(to_sum) + bias
Expand Down
39 changes: 32 additions & 7 deletions src/sfuncs/dist/cat.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,33 @@
export Cat
export
Cat,
Categorical,
Discrete

import ..Utils.normalize
import Distributions.Categorical
import Distributions

# SFunc wrapper around `Distributions.Categorical` (integer outcomes 1..K).
# Parameter order matches Distributions: `P` is the probability eltype,
# `Ps` the probability-vector type.
const Categorical{P, Ps} = DistributionsSF{Distributions.Categorical{P, Ps}, Int}
# Construct from a probability vector.
# Bug fix: the original invoked `Categorical{Ps, P}(p)` with the type
# parameters swapped; `Distributions.Categorical{P, Ps}` requires
# `P <: Real`, so the swapped instantiation always threw a TypeError.
Categorical(p::Ps) where {P, Ps <: AbstractVector{P}} = Categorical{P, Ps}(p)

const Discrete{T, P, Ts, Ps} = DistributionsSF{Distributions.DiscreteNonParametric{T, P, Ts, Ps}, T}
"""
    Discrete(xs, ps)

Construct a `Discrete` SFunc over values `xs` with probabilities `ps`.
Values are sorted and duplicate values are merged (their probabilities
summed); entries with zero probability are dropped. The input vectors are
not mutated (sorting indexes into fresh copies).
"""
function Discrete(xs::Xs, ps::Ps) where {X, Xs <: AbstractVector{X}, P, Ps <: AbstractVector{P}}
    # Sort so equal values are adjacent; indexing makes copies, so the
    # caller's vectors are left untouched.
    sort_order = sortperm(xs)
    xs = xs[sort_order]
    ps = ps[sort_order]

    # Merge duplicates by pushing each duplicate's mass FORWARD onto the
    # next entry; the last element of a run of any length accumulates the
    # run's total mass. (Bug fix: the original pushed mass backwards,
    # `ps[i] += ps[i+1]`, which left duplicates behind for runs of
    # length > 2, e.g. three equal values.)
    for i = 1:(length(xs) - 1)
        if xs[i] == xs[i + 1]
            ps[i + 1] += ps[i]
            ps[i] = 0
        end
    end
    # Drop zeroed-out duplicates (and any genuinely zero-probability entries).
    keep = ps .> 0
    xs = xs[keep]
    ps = ps[keep]

    return Discrete{X, P, Xs, Ps}(xs, ps)
end

@doc """
mutable struct Cat{O} <: Dist{O, Vector{Real}}
Expand Down Expand Up @@ -67,13 +93,12 @@ mutable struct Cat{O} <: Dist{O}
end
end


@impl begin
struct CatSupport end
function support(sf::Cat{O},
parranges::NTuple{N,Vector},
size::Integer,
curr::Vector{<:O}) where {O,N}
parranges::NTuple{N,Vector},
size::Integer,
curr::Vector{<:O}) where {O,N}
sf.range
end
end
Expand All @@ -88,7 +113,7 @@ end
# Implementation of the `sample` operator for `Cat`: draw an index from a
# categorical distribution over `sf.params`, then return the element of
# `sf.range` at that index.
@impl begin
    struct CatSample end
    function sample(sf::Cat{O}, i::Tuple{})::O where {O}
        # NOTE(review): the empty-input parameter `i::Tuple{}` is shadowed
        # here by the sampled index — confusing but harmless.
        i = rand(Distributions.Categorical(sf.params))
        return sf.range[i]
    end
end
Expand Down
Loading
Loading