diff --git a/src/initdt.jl b/src/initdt.jl index 28a367bfe5..071b50f506 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -1,23 +1,56 @@ @muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, - prob::DiffEqBase.AbstractODEProblem{uType, tType, true - }, - integrator) where {tType, uType} + prob, integrator) + + iscomposite = !(typeof(integrator.alg) <: CompositeAlgorithm) + sk = if !(typeof(integrator.alg) <: CompositeAlgorithm) + tmpcache = get_tmp_cache(integrator) + if tmpcache === nothing + nothing + else + first(tmpcache) + end + else + nothing + end + + current_fsal = get_current_isfsal(integrator.alg, integrator.cache) + is_odeintegrator = typeof(integrator) <: ODEIntegrator + verbose = integrator.opts.verbose + alg_order = get_current_alg_order(integrator.alg, integrator.cache) + + linsolve = if hasproperty(integrator.alg, :linsolve) + integrator.alg.linsolve + else + nothing + end + + fsallast = if current_fsal + integrator.fsallast + else + nothing + end + + _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, + prob, integrator.p, integrator.opts.dtmin, integrator.isdae, iscomposite, + sk, fsallast, current_fsal, is_odeintegrator, verbose, alg_order, linsolve) +end + +@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, + prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {tType, uType} _tType = eltype(tType) f = prob.f - p = integrator.p oneunit_tType = oneunit(_tType) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(integrator.opts.dtmin) + dtmin = nextfloat(dtmin) smalldt = convert(_tType, oneunit_tType * 1 // 10^(6)) - if integrator.isdae + if isdae return tdir * max(smalldt, dtmin) end - if eltype(u0) <: Number && !(typeof(integrator.alg) <: CompositeAlgorithm) - cache = get_tmp_cache(integrator) - sk = first(cache) + if eltype(u0) <: Number && iscomposite if u0 isa Array && abstol isa Number && reltol isa Number @inbounds @simd ivdep for i in eachindex(u0) sk[i] = abstol + internalnorm(u0[i], t) * reltol @@ -36,10 +69,9 @@ end end - if get_current_isfsal(integrator.alg, integrator.cache) && - typeof(integrator) <: ODEIntegrator + if current_fsal && is_odeintegrator # Right now DelayDiffEq has issues with fsallast not being initialized - f₀ = integrator.fsallast + f₀ = fsallast f(f₀, u0, p, t) else # TODO: use more caches @@ -107,7 +139,7 @@ any(mm != I for mm in prob.f.mass_matrix)) ftmp = zero(f₀) try - integrator.alg.linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true) + linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true) copyto!(f₀, ftmp) catch return tdir * max(smalldt, dtmin) @@ -127,7 +159,7 @@ # Better than checking any(x->any(isnan, x), f₀) # because it also checks if partials are NaN # https://discourse.julialang.org/t/incorporating-forcing-functions-in-the-ode-model/70133/26 - if integrator.opts.verbose && isnan(d₁) + if verbose && isnan(d₁) @warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.") return tdir * dtmin end @@ -166,7 +198,7 @@ if prob.f.mass_matrix != I && (!(typeof(prob.f) <: DynamicalODEFunction) || any(mm != I for mm in prob.f.mass_matrix)) - integrator.alg.linsolve(ftmp, prob.f.mass_matrix, f₁, false) + linsolve(ftmp, prob.f.mass_matrix, f₁, false) copyto!(f₁, ftmp) end @@ -192,8 +224,7 @@ else dt₁ = convert(_tType, oneunit_tType * - 10.0^(-(2 + log10(max_d₁d₂)) / - get_current_alg_order(integrator.alg, integrator.cache))) + 10.0^(-(2 + log10(max_d₁d₂)) / alg_order)) end return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir)) end @@ -224,20 +255,20 @@ function Base.showerror(io::IO, e::TypeNotConstantError) println(io, e.f₀) end -@muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, +@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, prob::DiffEqBase.AbstractODEProblem{uType, tType, - false}, - integrator) where {uType, tType} + false}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {uType, tType} _tType = eltype(tType) f = prob.f p = prob.p oneunit_tType = oneunit(_tType) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(integrator.opts.dtmin) + dtmin = nextfloat(dtmin) smalldt = convert(_tType, oneunit_tType * 1 // 10^(6)) - if integrator.isdae + if isdae return tdir * max(smalldt, dtmin) end @@ -245,7 +276,7 @@ end d₀ = internalnorm(u0 ./ sk, t) f₀ = f(u0, p, t) - if integrator.opts.verbose && any(x -> any(isnan, x), f₀) + if verbose && any(x -> any(isnan, x), f₀) @warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.") end @@ -279,19 +310,18 @@ end dt₁ = max(smalldt, dt₀ * 1 // 10^(3)) else dt₁ = _tType(oneunit_tType * - 10^(-(2 + log10(max_d₁d₂)) / - get_current_alg_order(integrator.alg, integrator.cache))) + 10^(-(2 + log10(max_d₁d₂)) / alg_order)) end return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir)) end -@inline function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, +@inline function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, prob::DiffEqBase.AbstractDAEProblem{duType, uType, - tType}, - integrator) where {duType, uType, tType} + tType}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {duType, uType, tType} _tType = eltype(tType) tspan = prob.tspan init_dt = abs(tspan[2] - tspan[1]) init_dt = isfinite(init_dt) ? init_dt : oneunit(_tType) return convert(_tType, init_dt * 1 // 10^(6)) -end +end \ No newline at end of file diff --git a/src/integrators/integrator_interface.jl b/src/integrators/integrator_interface.jl index 699ff67244..12d2ffbeb7 100644 --- a/src/integrators/integrator_interface.jl +++ b/src/integrators/integrator_interface.jl @@ -108,7 +108,8 @@ end for typ in (OrdinaryDiffEqAlgorithm, Union{RadauIIA3, RadauIIA5}, OrdinaryDiffEqNewtonAdaptiveAlgorithm, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, - Union{SSPRK22, SSPRK33, SSPRK53_2N1, SSPRK53_2N2, SSPRK43, SSPRK432, SSPRK932}) + Union{SSPRK22, SSPRK33, SSPRK53_2N1, SSPRK53_2N2, SSPRK43, SSPRK432, SSPRK932}, + DAEAlgorithm) @eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::$typ, cache::OrdinaryDiffEqConstantCache) nothing