Skip to content

Commit

Permalink
Merge pull request #3169 from AayushSabharwal/as/dde-non-tunables
Browse files Browse the repository at this point in the history
fix: support non-tunable parameters with DDEs
  • Loading branch information
ChrisRackauckas authored Oct 30, 2024
2 parents 2c694b5 + d128997 commit 2a1024a
Show file tree
Hide file tree
Showing 3 changed files with 43 additions and 11 deletions.
20 changes: 10 additions & 10 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
wrap_code = identity,
kwargs...)
if isdde
eqs = delay_to_function(sys)
issplit = has_index_cache(sys) && get_index_cache(sys) !== nothing
eqs = delay_to_function(
sys; history_arg = issplit ? MTKPARAMETERS_ARG : DEFAULT_PARAMS_ARG)
else
eqs = [eq for eq in equations(sys)]
end
Expand All @@ -211,7 +213,10 @@ function generate_function(sys::AbstractODESystem, dvs = unknowns(sys),
t = get_iv(sys)

if isdde
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...)
build_function(rhss, u, DDE_HISTORY_FUN, p..., t; kwargs...,
wrap_code = wrap_code .∘ wrap_mtkparameters(sys, false, 3) .∘
wrap_array_vars(sys, rhss; dvs, ps, history = true) .∘
wrap_parameter_dependencies(sys, false))
else
pre, sol_states = get_substitutions_and_solved_unknowns(sys)

Expand Down Expand Up @@ -570,9 +575,7 @@ function DiffEqBase.DDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys)
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(u, h, p, t) = f_oop(u, h, p, t)
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
f(du, u, h, p::MTKParameters, t) = f_iip(du, u, h, p..., t)

DDEFunction{iip}(f, sys = sys)
end
Expand All @@ -595,17 +598,14 @@ function DiffEqBase.SDDEFunction{iip}(sys::AbstractODESystem, dvs = unknowns(sys
expression_module = eval_module, checkbounds = checkbounds,
kwargs...)
f_oop, f_iip = eval_or_rgf.(f_gen; eval_expression, eval_module)
f(u, h, p, t) = f_oop(u, h, p, t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)

g_gen = generate_diffusion_function(sys, dvs, ps; expression = Val{true},
isdde = true, kwargs...)
g_oop, g_iip = eval_or_rgf.(g_gen; eval_expression, eval_module)
f(u, h, p, t) = f_oop(u, h, p, t)
f(u, h, p::MTKParameters, t) = f_oop(u, h, p..., t)
f(du, u, h, p, t) = f_iip(du, u, h, p, t)
f(du, u, h, p::MTKParameters, t) = f_iip(du, u, h, p..., t)
g(u, h, p, t) = g_oop(u, h, p, t)
g(u, h, p::MTKParameters, t) = g_oop(u, h, p..., t)
g(du, u, h, p, t) = g_iip(du, u, h, p, t)
g(du, u, h, p::MTKParameters, t) = g_iip(du, u, h, p..., t)

SDDEFunction{iip}(f, g, sys = sys)
end
Expand Down
6 changes: 5 additions & 1 deletion src/systems/diffeqs/sdesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,11 @@ function generate_diffusion_function(sys::SDESystem, dvs = unknowns(sys),
(map(x -> time_varying_as_func(value(x), sys), ps),)
end
if isdde
return build_function(eqs, u, DDE_HISTORY_FUN, p..., get_iv(sys); kwargs...)
return build_function(eqs, u, DDE_HISTORY_FUN, p..., get_iv(sys); kwargs...,
wrap_code = get(kwargs, :wrap_code, identity) .∘
wrap_mtkparameters(sys, false, 3) .∘
wrap_array_vars(sys, eqs; dvs, ps, history = true) .∘
wrap_parameter_dependencies(sys, false))
else
return build_function(eqs, u, p..., get_iv(sys); kwargs...)
end
Expand Down
28 changes: 28 additions & 0 deletions test/dde.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,3 +170,31 @@ prob_sa = DDEProblem(sys, [], (0.0, 10.0); constant_lags = [sys.osc1.τ, sys.osc
sol(sol.t .- prob.ps[ssys.valve.τ]; idxs = ssys.valve.opening).u .+
sum.(sol[ssys.vvecs.x])
end

@testset "Issue#3165 DDEs with non-tunables" begin
@variables x(..) = 1.0
@parameters w=1.0 [tunable = false] τ=0.5
eqs = [D(x(t)) ~ -w * x(t - τ)]

@named sys = System(eqs, t)
sys = structural_simplify(sys)

prob = DDEProblem(sys,
[],
(0.0, 10.0),
constant_lags = [τ])

alg = MethodOfSteps(Vern7())
@test_nowarn solve(prob, alg)

@brownian r
eqs = [D(x(t)) ~ -w * x(t - τ) + r]
@named sys = System(eqs, t)
sys = structural_simplify(sys)
prob = SDDEProblem(sys,
[],
(0.0, 10.0),
constant_lags = [τ])

@test_nowarn solve(prob, RKMil())
end

0 comments on commit 2a1024a

Please sign in to comment.