Skip to content

Commit

Permalink
Merge pull request #3337 from hersle/fix_remake_dummy_derivative
Browse files Browse the repository at this point in the history
Fix dual type promotion in remake with dummy derivatives
  • Loading branch information
ChrisRackauckas authored Jan 18, 2025
2 parents 009d8b8 + 60e4723 commit 8668cde
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
15 changes: 15 additions & 0 deletions src/systems/diffeqs/abstractodesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1373,6 +1373,21 @@ function InitializationProblem{iip, specialize}(sys::AbstractSystem,
end

u0map = merge(ModelingToolkit.guesses(sys), todict(guesses), todict(u0map))

# Replace dummy derivatives in u0map: D(x) -> x_t etc.
if has_schedule(sys)
schedule = get_schedule(sys)
if !isnothing(schedule)
for (var, val) in u0map
dvar = get(schedule.dummy_sub, var, var) # with dummy derivatives
if dvar !== var # then replace it
delete!(u0map, var)
push!(u0map, dvar => val)
end
end
end
end

fullmap = merge(u0map, parammap)
u0T = Union{}
for sym in unknowns(isys)
Expand Down
21 changes: 21 additions & 0 deletions test/extensions/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using Zygote
using SymbolicIndexingInterface
using SciMLStructures
using OrdinaryDiffEq
using NonlinearSolve
using SciMLSensitivity
using ForwardDiff
using ChainRulesCore
Expand Down Expand Up @@ -103,3 +104,23 @@ vals = (1.0f0, 3ones(Float32, 3))
tangent = rand_tangent(ps)
fwd, back = ChainRulesCore.rrule(remake_buffer, sys, ps, idxs, vals)
@inferred back(tangent)

@testset "Dual type promotion in remake with dummy derivatives" begin # https://github.com/SciML/ModelingToolkit.jl/issues/3336
# Throw ball straight up into the air
@variables y(t)
eqs = [D(D(y)) ~ -9.81]
initialization_eqs = [y^2 ~ 0] # initialize y = 0 in a way that builds an initialization problem
@named sys = ODESystem(eqs, t; initialization_eqs)
sys = structural_simplify(sys)

# Find initial throw velocity that reaches exactly 10 m after 1 s
dprob0 = ODEProblem(sys, [D(y) => NaN], (0.0, 1.0), []; guesses = [y => 0.0])
function f(ics, _)
dprob = remake(dprob0, u0 = Dict(D(y) => ics[1]))
dsol = solve(dprob, Tsit5())
return [dsol[y][end] - 10.0]
end
nprob = NonlinearProblem(f, [1.0])
nsol = solve(nprob, NewtonRaphson())
@test nsol[1] 10.0 / 1.0 + 9.81 * 1.0 / 2 # anal free fall solution is y = v0*t - g*t^2/2 -> v0 = y/t + g*t/2
end

0 comments on commit 8668cde

Please sign in to comment.