Skip to content

Commit

Permalink
Merge pull request #53 from charles-river-analytics/MultiInterfaceRej…
Browse files Browse the repository at this point in the history
…uvenatePart2

MultiInterface Rejuvenate Part 2
  • Loading branch information
mharradon authored Apr 12, 2024
2 parents bed4ae8 + 1811deb commit 4989958
Show file tree
Hide file tree
Showing 17 changed files with 78 additions and 128 deletions.
77 changes: 14 additions & 63 deletions src/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ The `Operators` module defines the following interfaces for the following operat
- `sample(sf::SFunc{I,O}, i::I)::O where {I,O}`
- `sample_logcpdf(sf::SFunc{I,O}, i::I)::Tuple{O, AbstractFloat} where {I,O}`
- `invert(sf::SFunc{I,O}, o::O)::I where {I,O}`
- `lambda_msg(sf::SFunc{I,O}, i::SFunc{<:__Opt{Tuple{}}, O})::SFunc{<:__Opt{Tuple{}}, I} where {I,O}`
- `marginalize(sf::SFunc{I,O}, i::SFunc{<:__Opt{Tuple{}}, I})::SFunc{<:__Opt{Tuple{}}, O} where {I,O}`
- `lambda_msg(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, O})::SFunc{<:Option{Tuple{}}, I} where {I,O}`
- `marginalize(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, I})::SFunc{<:Option{Tuple{}}, O} where {I,O}`
- `logcpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}`
- `cpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}`
- `log_cond_prob_plus_c(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}`
Expand All @@ -23,104 +23,55 @@ The `Operators` module defines the following interfaces for the following operat
curr::Vector{<:O}) where {I,O,N}```
- `support_quality(sf::SFunc, parranges)`
- ```bounded_probs(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector})::Tuple{Vector{<:AbstractFloat}, Vector{<:AbstractFloat}} where {I,O,N}```
- ```make_factors(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
id,
parids::Tuple)::Tuple{Vector{<:Scruff.Utils.Factor}, Vector{<:Scruff.Utils.Factor}} where {I,O,N}```
- `initial_stats(sf::SFunc)`
- ```expected_stats(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
pis::NTuple{M,Dist},
child_lambda::Score{<:O}) where {I,O,N,M}```
- `accumulate_stats(sf::SFunc, existing_stats, new_stats)`
- `maximize_stats(sf::SFunc, stats)`
- ```compute_bel(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
pi::Dist{<:O},
lambda::Score{<:O})::Dist{<:O} where {I,O}```
- `compute_lambda(sf::SFunc, range::__OptVec, lambda_msgs::Vector{<:Score})::Score`
- `compute_lambda(sf::SFunc, range::VectorOption, lambda_msgs::Vector{<:Score})::Score`
- ```send_pi(sf::SFunc{I,O},
range::__OptVec{O},
range::VectorOption{O},
bel::Dist{O},
lambda_msg::Score{O})::Dist{<:O} where {I,O}```
- ```outgoing_pis(sf::SFunc,
range::__OptVec,
range::VectorOption,
bel::Dist,
incoming_lambdas::__OptVec{<:Score})::Vector{<:Dist}```
incoming_lambdas::VectorOption{<:Score})::Vector{<:Dist}```
- ```outgoing_lambdas(sf::SFunc{I,O},
lambda::Score{O},
range::__OptVec{O},
range::VectorOption{O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Vector{<:Score} where {N,I,O}```
- ```compute_pi(sf::SFunc{I,O},
range::__OptVec{O},
range::VectorOption{O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,I,O}```
- ```send_lambda(sf::SFunc{I,O},
lambda::Score{O},
range::__OptVec{O},
range::VectorOption{O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_idx::Integer)::Score where {N,I,O}```
"""
module Operators

using ...Scruff
using ..Scruff

include("operators/op_performance.jl")
include("operators/op_defs.jl")



"""
module_functions(mod)
Returns the name of all the functions in the given module.
"""
function module_functions(mod)
list = Symbol[]
for nm in names(mod; all=true)
if !startswith(string(nm), r"@|#") && match(r"^(?:eval|include)$", string(nm)) === nothing
typeof(eval(nm)) <: Function && push!(list,nm)
end
end
return list
end

Op = @__MODULE__

"""
export_operators()
Exports all the functions defined in Operators.
"""
function export_operators()
is = "export " * join(module_functions(Op), ", ")
eval(Meta.parse(is))
end

function module_name_string(fullname)
strs = [string(x) * "." for x in fullname[1:length(fullname)-1]]
join(strs) * string(fullname[length(fullname)])
end


"""
import_operators()
Imports all the functions defined in Operators
"""
macro import_operators()
is = "import " * module_name_string(fullname(Op)) * ": " * string(join(module_functions(Op), ", "))
quote
Base.eval(@__MODULE__, Meta.parse($(is)))
end
end

export_operators()

end
39 changes: 19 additions & 20 deletions src/operators/op_defs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,17 @@ using ..MultiInterface

export
importance_sample,
support_quality,
support_quality_rank,
support_quality_from_rank,
__OptVec,
__Opt
VectorOption,
Option

# These are the new operator definitions where the type signature is specified

"""__OptVec{T} = Union{Vector{Union{}}, Vector{T}}"""
__OptVec{T} = Union{Vector{Union{}}, Vector{T}}
"""__Opt{T} = Union{Nothing, T}"""
__Opt{T} = Union{Nothing, T}
"""VectorOption{T} = Union{Vector{Union{}}, Vector{T}}"""
VectorOption{T} = Union{Vector{Union{}}, Vector{T}}
"""Option{T} = Union{Nothing, T}"""
Option{T} = Union{Nothing, T}

# to support
MultiInterface.get_imp(::Nothing, args...) = nothing
Expand All @@ -24,8 +23,8 @@ MultiInterface.get_imp(::Nothing, args...) = nothing
@interface sample(sf::SFunc{I,O}, i::I)::O where {I,O}
@interface sample_logcpdf(sf::SFunc{I,O}, i::I)::Tuple{O, AbstractFloat} where {I,O}
# @interface invert(sf::SFunc{I,O}, o::O)::I where {I,O}
@interface lambda_msg(sf::SFunc{I,O}, i::SFunc{<:__Opt{Tuple{}}, O})::SFunc{<:__Opt{Tuple{}}, I} where {I,O}
@interface marginalize(sf::SFunc{I,O}, i::SFunc{<:__Opt{Tuple{}}, I})::SFunc{<:__Opt{Tuple{}}, O} where {I,O}
@interface lambda_msg(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, O})::SFunc{<:Option{Tuple{}}, I} where {I,O}
@interface marginalize(sf::SFunc{I,O}, i::SFunc{<:Option{Tuple{}}, I})::SFunc{<:Option{Tuple{}}, O} where {I,O}
@interface logcpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
@interface cpdf(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
@interface log_cond_prob_plus_c(sf::SFunc{I,O}, i::I, o::O)::AbstractFloat where {I,O}
Expand Down Expand Up @@ -83,11 +82,11 @@ end
end

@interface bounded_probs(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector})::Tuple{Vector{<:AbstractFloat}, Vector{<:AbstractFloat}} where {I,O,N}

@interface make_factors(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
id,
parids::Tuple)::Tuple{Vector{<:Scruff.Utils.Factor}, Vector{<:Scruff.Utils.Factor}} where {I,O,N}
Expand All @@ -98,7 +97,7 @@ end
# TODO create an abstract type Stats{I,O}
# (range, parranges, pi's, lambda's)
@interface expected_stats(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
pis::NTuple{M,Dist},
child_lambda::Score{<:O}) where {I,O,N,M}
Expand All @@ -107,39 +106,39 @@ end
@interface maximize_stats(sf::SFunc, stats)
=#
@interface compute_bel(sf::SFunc{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
pi::Dist{<:O},
lambda::Score)::Dist{<:O} where {I,O}

@interface compute_lambda(sf::SFunc,
range::__OptVec,
range::VectorOption,
lambda_msgs::Vector{<:Score})::Score

@interface send_pi(sf::SFunc{I,O},
range::__OptVec{O},
range::VectorOption{O},
bel::Dist{O},
lambda_msg::Score)::Dist{<:O} where {I,O}

@interface outgoing_pis(sf::SFunc,
range::__OptVec,
range::VectorOption,
bel::Dist,
incoming_lambdas::__OptVec{<:Score})::Vector{<:Dist}
incoming_lambdas::VectorOption{<:Score})::Vector{<:Dist}

@interface outgoing_lambdas(sf::SFunc{I,O},
lambda::Score,
range::__OptVec,
range::VectorOption,
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Vector{<:Score} where {N,I,O}

@interface compute_pi(sf::SFunc{I,O},
range::__OptVec{O},
range::VectorOption{O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,I,O}


@interface send_lambda(sf::SFunc{I,O},
lambda::Score,
range::__OptVec,
range::VectorOption,
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_idx::Integer)::Score where {N,I,O}
Expand Down
2 changes: 1 addition & 1 deletion src/runtime_utils.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
module RTUtils

using ...Scruff
using ..Scruff
using ..Utils
using ..SFuncs

Expand Down
5 changes: 2 additions & 3 deletions src/sfuncs.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,11 @@ module SFuncs
using Base: reinit_stdio
using ..MultiInterface

using ...Scruff
using ..Scruff
using ..Utils
using ..Operators
import ..Operators
Operators.@import_operators()
import ..Operators: __OptVec, Support, SupportQuality
import ..Operators: VectorOption, Support, SupportQuality

macro impl(expr)
return esc(MultiInterface.impl(__module__, __source__, expr, Operators))
Expand Down
4 changes: 2 additions & 2 deletions src/sfuncs/compound/apply.jl
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ end
@impl begin
struct ApplyComputePi end
function compute_pi(::Apply{J,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,J<:Tuple,O}

Expand All @@ -92,7 +92,7 @@ end

function send_lambda(::Apply{J,O},
lambda::Score{<:O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_idx::Integer)::Score where {N,J<:Tuple,O}
Expand Down
4 changes: 2 additions & 2 deletions src/sfuncs/compound/generate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ end
@impl begin
struct GenerateComputePi end
function compute_pi(::Generate{O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,O}

Expand All @@ -91,7 +91,7 @@ end

function send_lambda(::Generate{O},
lambda::Score{<:O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_idx::Integer)::Score where {N,O}
Expand Down
8 changes: 4 additions & 4 deletions src/sfuncs/compound/mixture.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ end
@impl begin
struct MixtureExpectedStats end
function expected_stats(sf::Mixture{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
pis::NTuple{M,Dist},
child_lambda::Score{<:O}) where {I,O,N,M}
Expand Down Expand Up @@ -115,7 +115,7 @@ end
end

function make_factors(sf::Mixture{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
id,
parids::Tuple)::Tuple{Vector{<:Scruff.Utils.Factor}, Vector{<:Scruff.Utils.Factor}} where {I,O,N}
Expand Down Expand Up @@ -159,7 +159,7 @@ end
@impl begin
struct MixtureComputePi end
function compute_pi(sf::Mixture{I,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,I,O}
function f(i)
Expand All @@ -176,7 +176,7 @@ end
struct MixtureSendLambda end
function send_lambda(sf::Mixture{I,O},
lambda::Score{<:O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_ix::Integer)::Score where {N,I,O}
Expand Down
6 changes: 3 additions & 3 deletions src/sfuncs/conddist/conditional.jl
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ end
end

function make_factors(sf::Conditional{I,J,K,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
id,
parids::Tuple)::Tuple{Vector{<:Scruff.Utils.Factor}, Vector{<:Scruff.Utils.Factor}} where {I,J,K,O,N}
Expand Down Expand Up @@ -240,7 +240,7 @@ end
struct ConditionalComputePi end

function compute_pi(sf::Conditional{I,J,K,O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple)::Dist{<:O} where {N,I,J,K,O}

Expand Down Expand Up @@ -275,7 +275,7 @@ end

function send_lambda(sf::Conditional{I,J,K,O},
lambda::Score{<:O},
range::__OptVec{<:O},
range::VectorOption{<:O},
parranges::NTuple{N,Vector},
incoming_pis::Tuple,
parent_ix::Integer)::Score where {N,I,J,K,O}
Expand Down
Loading

0 comments on commit 4989958

Please sign in to comment.