Skip to content

Commit

Permalink
Replace Expronicon with Moshi
Browse files Browse the repository at this point in the history
  • Loading branch information
visr committed Jan 24, 2025
1 parent 5cdb4ba commit 42f5119
Show file tree
Hide file tree
Showing 7 changed files with 38 additions and 44 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ DomainSets = "5b8099bc-c8ec-5219-889f-1d9e522a28bf"
DynamicQuantities = "06fc5a27-2a28-4c7c-a15d-362465fb6821"
EnumX = "4e289a0a-7415-4d19-859d-a7e5c4648b56"
ExprTools = "e2ba6199-217a-4e67-a87a-7c52f15ade04"
Expronicon = "6b7a57c9-7cc1-4fdf-b7f5-e857abae3636"
FindFirstFunctions = "64ca27bc-2ba2-4a57-88aa-44e436879224"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"
Expand All @@ -36,6 +35,7 @@ Latexify = "23fbe1c1-3f47-55db-b15f-69d7ec21a316"
Libdl = "8f399da3-3557-5675-b5ff-fb832c97cbdb"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
MLStyle = "d8e11817-5142-5d16-987a-aa16d5891078"
Moshi = "2e0e35c7-a2e4-4343-998d-7ef72827ed2d"
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
NonlinearSolve = "8913a72c-1f9b-4ce2-8d82-65094dcecaec"
OffsetArrays = "6fe1bfb0-de20-5000-8ca7-80f57d26f881"
Expand Down Expand Up @@ -101,7 +101,6 @@ DomainSets = "0.6, 0.7"
DynamicQuantities = "^0.11.2, 0.12, 0.13, 1"
EnumX = "1.0.4"
ExprTools = "0.1.10"
Expronicon = "0.8"
FindFirstFunctions = "1"
ForwardDiff = "0.10.3"
FunctionWrappers = "1.1"
Expand All @@ -118,6 +117,7 @@ Libdl = "1"
LinearAlgebra = "1"
MLStyle = "0.4.17"
ModelingToolkitStandardLibrary = "2.19"
Moshi = "0.3"
NaNMath = "0.3, 1"
NonlinearSolve = "4.3"
OffsetArrays = "1"
Expand Down
2 changes: 2 additions & 0 deletions src/ModelingToolkit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ using SciMLBase: StandardODEProblem, StandardNonlinearProblem, handle_varmap, Ti
using Distributed
import JuliaFormatter
using MLStyle
import Moshi
using Moshi.Data: @data
using NonlinearSolve
import SCCNonlinearSolve
using Reexport
Expand Down
26 changes: 9 additions & 17 deletions src/clock.jl
Original file line number Diff line number Diff line change
@@ -1,20 +1,12 @@
module InferredClock

export InferredTimeDomain

using Expronicon.ADT: @adt, @match
using SciMLBase: TimeDomain

@adt InferredTimeDomain begin
@data InferredClock begin
Inferred
InferredDiscrete
end

Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)
const InferredTimeDomain = InferredClock.Type
using .InferredClock: Inferred, InferredDiscrete

end

using .InferredClock
Base.Broadcast.broadcastable(x::InferredTimeDomain) = Ref(x)

struct VariableTimeDomain end
Symbolics.option_to_metadata_type(::Val{:timedomain}) = VariableTimeDomain
Expand All @@ -29,7 +21,7 @@ true if `x` contains only continuous-domain signals.
See also [`has_continuous_domain`](@ref)
"""
function is_continuous_domain(x)
issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous
issym(x) && return getmetadata(x, VariableTimeDomain, false) == Continuous()
!has_discrete_domain(x) && has_continuous_domain(x)
end

Expand Down Expand Up @@ -58,8 +50,8 @@ has_time_domain(x::Num) = has_time_domain(value(x))
has_time_domain(x) = false

for op in [Differential]
@eval input_timedomain(::$op, arg = nothing) = Continuous
@eval output_timedomain(::$op, arg = nothing) = Continuous
@eval input_timedomain(::$op, arg = nothing) = Continuous()
@eval output_timedomain(::$op, arg = nothing) = Continuous()
end

"""
Expand Down Expand Up @@ -104,8 +96,8 @@ function is_discrete_domain(x)
!has_discrete_domain(x) && has_continuous_domain(x)
end

sampletime(c) = @match c begin
PeriodicClock(dt, _...) => dt
sampletime(c) = Moshi.Match.@match c begin
PeriodicClock(dt) => dt
_ => nothing
end

Expand Down
12 changes: 6 additions & 6 deletions src/discretedomain.jl
Original file line number Diff line number Diff line change
Expand Up @@ -226,28 +226,28 @@ Base.:-(k::ShiftIndex, i::Int) = k + (-i)
"""
input_timedomain(op::Operator)
Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` operates on.
Return the time-domain type (`Continuous()` or `InferredDiscrete()`) that `op` operates on.
"""
function input_timedomain(s::Shift, arg = nothing)
if has_time_domain(arg)
return get_time_domain(arg)
end
InferredDiscrete
InferredDiscrete()
end

"""
output_timedomain(op::Operator)
Return the time-domain type (`Continuous` or `InferredDiscrete`) that `op` results in.
Return the time-domain type (`Continuous()` or `InferredDiscrete()`) that `op` results in.
"""
function output_timedomain(s::Shift, arg = nothing)
if has_time_domain(t, arg)
return get_time_domain(t, arg)
end
InferredDiscrete
InferredDiscrete()
end

input_timedomain(::Sample, _ = nothing) = Continuous
input_timedomain(::Sample, _ = nothing) = Continuous()
output_timedomain(s::Sample, _ = nothing) = s.clock

function input_timedomain(h::Hold, arg = nothing)
Expand All @@ -256,7 +256,7 @@ function input_timedomain(h::Hold, arg = nothing)
end
InferredDiscrete # the Hold accepts any discrete
end
output_timedomain(::Hold, _ = nothing) = Continuous
output_timedomain(::Hold, _ = nothing) = Continuous()

sampletime(op::Sample, _ = nothing) = sampletime(op.clock)
sampletime(op::ShiftIndex, _ = nothing) = sampletime(op.clock)
Expand Down
6 changes: 3 additions & 3 deletions src/systems/clock_inference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ end
function ClockInference(ts::TransformationState)
@unpack structure = ts
@unpack graph = structure
eq_domain = TimeDomain[Continuous for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous for _ in 1:ndsts(graph)]
eq_domain = TimeDomain[Continuous() for _ in 1:nsrcs(graph)]
var_domain = TimeDomain[Continuous() for _ in 1:ndsts(graph)]
inferred = BitSet()
for (i, v) in enumerate(get_fullvars(ts))
d = get_time_domain(ts, v)
Expand Down Expand Up @@ -151,7 +151,7 @@ function split_system(ci::ClockInference{S}) where {S}
get!(clock_to_id, d) do
cid = (cid_counter[] += 1)
push!(id_to_clock, d)
if d == Continuous
if d == Continuous()
continuous_id[] = cid
end
cid
Expand Down
2 changes: 1 addition & 1 deletion src/systems/systemstructure.jl
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ function structural_simplify!(state::TearingState, io = nothing; simplify = fals
Dict(v => 0.0 for v in Iterators.flatten(inputs)))
end
ps = [sym isa CallWithMetadata ? sym :
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous))
setmetadata(sym, VariableTimeDomain, get(time_domains, sym, Continuous()))
for sym in get_ps(sys)]
@set! sys.ps = ps
else
Expand Down
30 changes: 15 additions & 15 deletions test/clock.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,19 +77,19 @@ k = ShiftIndex(d)

d = Clock(dt)
# Note that TearingState reorders the equations
@test eqmap[1] == Continuous
@test eqmap[1] == Continuous()
@test eqmap[2] == d
@test eqmap[3] == d
@test eqmap[4] == d
@test eqmap[5] == Continuous
@test eqmap[6] == Continuous
@test eqmap[5] == Continuous()
@test eqmap[6] == Continuous()

@test varmap[yd] == d
@test varmap[ud] == d
@test varmap[r] == d
@test varmap[x] == Continuous
@test varmap[y] == Continuous
@test varmap[u] == Continuous
@test varmap[x] == Continuous()
@test varmap[y] == Continuous()
@test varmap[u] == Continuous()

@info "Testing shift normalization"
dt = 0.1
Expand Down Expand Up @@ -192,10 +192,10 @@ eqs = [yd ~ Sample(dt)(y)
@test varmap[ud1] == d
@test varmap[yd2] == d2
@test varmap[ud2] == d2
@test varmap[r] == Continuous
@test varmap[x] == Continuous
@test varmap[y] == Continuous
@test varmap[u] == Continuous
@test varmap[r] == Continuous()
@test varmap[x] == Continuous()
@test varmap[y] == Continuous()
@test varmap[u] == Continuous()

@info "test composed systems"

Expand Down Expand Up @@ -241,14 +241,14 @@ eqs = [yd ~ Sample(dt)(y)
ci, varmap = infer_clocks(cl)

@test varmap[f.x] == Clock(0.5)
@test varmap[p.x] == Continuous
@test varmap[p.y] == Continuous
@test varmap[p.x] == Continuous()
@test varmap[p.y] == Continuous()
@test varmap[c.ud] == Clock(0.5)
@test varmap[c.yd] == Clock(0.5)
@test varmap[c.y] == Continuous
@test varmap[c.y] == Continuous()
@test varmap[f.y] == Clock(0.5)
@test varmap[f.u] == Clock(0.5)
@test varmap[p.u] == Continuous
@test varmap[p.u] == Continuous()
@test varmap[c.r] == Clock(0.5)

## Multiple clock rates
Expand Down Expand Up @@ -474,7 +474,7 @@ eqs = [yd ~ Sample(dt)(y)

## Test continuous clock

c = ModelingToolkit.SolverStepClock
c = ModelingToolkit.SolverStepClock()
k = ShiftIndex(c)

@mtkmodel CounterSys begin
Expand Down

0 comments on commit 42f5119

Please sign in to comment.