From 71c8184740debff351404475ca6e7bdf63e52faf Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 16:53:26 -0400 Subject: [PATCH 1/5] Split the core of initdt to be algorithm independent This should reduce precompilation --- src/initdt.jl | 82 +++++++++++++++++++++++++++++++++------------------ 1 file changed, 53 insertions(+), 29 deletions(-) diff --git a/src/initdt.jl b/src/initdt.jl index 28a367bfe5..57f473f978 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -1,23 +1,50 @@ @muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, - prob::DiffEqBase.AbstractODEProblem{uType, tType, true - }, - integrator) where {tType, uType} + prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, integrator) where {tType, uType} + + sk = if !(typeof(integrator.alg) <: CompositeAlgorithm) + first(get_tmp_cache(integrator) + else + nothing + end + + current_fsal = get_current_isfsal(integrator.alg, integrator.cache) + is_odeintegrator = typeof(integrator) <: ODEIntegrator + verbose = integrator.opts.verbose + alg_order = get_current_alg_order(integrator.alg, integrator.cache) + + linsolve = if haskey(integrator.alg, :linsolve) + integrator.alg.linsolve + else + nothing + end + + fsallast = if current_fsal + integrator.fsallast + else + nothing + end + + _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, + prob, integrator.p, integrator.opts.dtmin, integrator.isdae, iscomposite, + sk, fsallast, current_fsal, is_odeintegrator, verbose, alg_order, linsolve) +end + +@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, + prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {tType, uType} _tType = eltype(tType) f = prob.f - p = integrator.p oneunit_tType = oneunit(_tType) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(integrator.opts.dtmin) + dtmin = nextfloat(dtmin) smalldt = convert(_tType, oneunit_tType * 1 // 10^(6)) - if integrator.isdae + if isdae return tdir * max(smalldt, dtmin) end - if eltype(u0) <: Number && !(typeof(integrator.alg) <: CompositeAlgorithm) - cache = get_tmp_cache(integrator) - sk = first(cache) + if eltype(u0) <: Number && iscomposite if u0 isa Array && abstol isa Number && reltol isa Number @inbounds @simd ivdep for i in eachindex(u0) sk[i] = abstol + internalnorm(u0[i], t) * reltol @@ -36,10 +63,9 @@ end end - if get_current_isfsal(integrator.alg, integrator.cache) && - typeof(integrator) <: ODEIntegrator + if current_fsal && is_odeintegrator # Right now DelayDiffEq has issues with fsallast not being initialized - f₀ = integrator.fsallast + f₀ = fsallast f(f₀, u0, p, t) else # TODO: use more caches @@ -107,7 +133,7 @@ any(mm != I for mm in prob.f.mass_matrix)) ftmp = zero(f₀) try - integrator.alg.linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true) + linsolve(ftmp, copy(prob.f.mass_matrix), f₀, true) copyto!(f₀, ftmp) catch return tdir * max(smalldt, dtmin) @@ -127,7 +153,7 @@ # Better than checking any(x->any(isnan, x), f₀) # because it also checks if partials are NaN # https://discourse.julialang.org/t/incorporating-forcing-functions-in-the-ode-model/70133/26 - if integrator.opts.verbose && isnan(d₁) + if verbose && isnan(d₁) @warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.") return tdir * dtmin end @@ -166,7 +192,7 @@ if prob.f.mass_matrix != I && (!(typeof(prob.f) <: DynamicalODEFunction) || any(mm != I for mm in prob.f.mass_matrix)) - integrator.alg.linsolve(ftmp, prob.f.mass_matrix, f₁, false) + linsolve(ftmp, prob.f.mass_matrix, f₁, false) copyto!(f₁, ftmp) end @@ -192,8 +218,7 @@ else dt₁ = convert(_tType, oneunit_tType * - 10.0^(-(2 + log10(max_d₁d₂)) / - get_current_alg_order(integrator.alg, integrator.cache))) + 10.0^(-(2 + log10(max_d₁d₂)) / alg_order)) end return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir)) end @@ -224,20 +249,20 @@ function Base.showerror(io::IO, e::TypeNotConstantError) println(io, e.f₀) end -@muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, +@muladd function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, prob::DiffEqBase.AbstractODEProblem{uType, tType, - false}, - integrator) where {uType, tType} + false}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {uType, tType} _tType = eltype(tType) f = prob.f p = prob.p oneunit_tType = oneunit(_tType) dtmax_tdir = tdir * dtmax - dtmin = nextfloat(integrator.opts.dtmin) + dtmin = nextfloat(dtmin) smalldt = convert(_tType, oneunit_tType * 1 // 10^(6)) - if integrator.isdae + if isdae return tdir * max(smalldt, dtmin) end @@ -245,7 +270,7 @@ end d₀ = internalnorm(u0 ./ sk, t) f₀ = f(u0, p, t) - if integrator.opts.verbose && any(x -> any(isnan, x), f₀) + if verbose && any(x -> any(isnan, x), f₀) @warn("First function call produced NaNs. Exiting. Double check that none of the initial conditions, parameters, or timespan values are NaN.") end @@ -279,19 +304,18 @@ end dt₁ = max(smalldt, dt₀ * 1 // 10^(3)) else dt₁ = _tType(oneunit_tType * - 10^(-(2 + log10(max_d₁d₂)) / - get_current_alg_order(integrator.alg, integrator.cache))) + 10^(-(2 + log10(max_d₁d₂)) / alg_order)) end return tdir * max(dtmin, min(100dt₀, dt₁, dtmax_tdir)) end -@inline function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, +@inline function _ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, prob::DiffEqBase.AbstractDAEProblem{duType, uType, - tType}, - integrator) where {duType, uType, tType} + tType}, p, dtmin, isdae, iscomposite, sk, fsallast, current_fsal, is_odeintegrator, + verbose, alg_order, linsolve) where {duType, uType, tType} _tType = eltype(tType) tspan = prob.tspan init_dt = abs(tspan[2] - tspan[1]) init_dt = isfinite(init_dt) ? init_dt : oneunit(_tType) return convert(_tType, init_dt * 1 // 10^(6)) -end +end \ No newline at end of file From f3ac9554ce016d8daff41fbe57ac4b3461b53b80 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 16:57:55 -0400 Subject: [PATCH 2/5] typo --- src/initdt.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/initdt.jl b/src/initdt.jl index 57f473f978..e6e4ef6e49 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -2,7 +2,7 @@ prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, integrator) where {tType, uType} sk = if !(typeof(integrator.alg) <: CompositeAlgorithm) - first(get_tmp_cache(integrator) + first(get_tmp_cache(integrator)) else nothing end From 62d984f6d1fc7dd5eabd23da6898223e9ff32d99 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 17:05:35 -0400 Subject: [PATCH 3/5] fix some dispatching --- src/initdt.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/initdt.jl b/src/initdt.jl index e6e4ef6e49..4ae08e4d87 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -1,6 +1,7 @@ @muladd function ode_determine_initdt(u0, t, tdir, dtmax, abstol, reltol, internalnorm, - prob::DiffEqBase.AbstractODEProblem{uType, tType, true}, integrator) where {tType, uType} + prob, integrator) + iscomposite = !(typeof(integrator.alg) <: CompositeAlgorithm) sk = if !(typeof(integrator.alg) <: CompositeAlgorithm) first(get_tmp_cache(integrator)) else @@ -12,7 +13,7 @@ verbose = integrator.opts.verbose alg_order = get_current_alg_order(integrator.alg, integrator.cache) - linsolve = if haskey(integrator.alg, :linsolve) + linsolve = if hasproperty(integrator.alg, :linsolve) integrator.alg.linsolve else nothing From bf6bff3adcd1fde1b8056d6b0584eb287ea9de65 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 17:10:49 -0400 Subject: [PATCH 4/5] fix tmpcache handling --- src/initdt.jl | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/src/initdt.jl b/src/initdt.jl index 4ae08e4d87..071b50f506 100644 --- a/src/initdt.jl +++ b/src/initdt.jl @@ -3,7 +3,12 @@ iscomposite = !(typeof(integrator.alg) <: CompositeAlgorithm) sk = if !(typeof(integrator.alg) <: CompositeAlgorithm) - first(get_tmp_cache(integrator)) + tmpcache = get_tmp_cache(integrator) + if tmpcache === nothing + nothing + else + first(tmpcache) + end else nothing end From 1328a0c365f5d807868bac5f61c2871520b8bac8 Mon Sep 17 00:00:00 2001 From: Chris Rackauckas Date: Tue, 1 Aug 2023 17:24:31 -0400 Subject: [PATCH 5/5] handle nothing tmp for DAEAlgorithms --- src/integrators/integrator_interface.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/integrators/integrator_interface.jl b/src/integrators/integrator_interface.jl index 699ff67244..12d2ffbeb7 100644 --- a/src/integrators/integrator_interface.jl +++ b/src/integrators/integrator_interface.jl @@ -108,7 +108,8 @@ end for typ in (OrdinaryDiffEqAlgorithm, Union{RadauIIA3, RadauIIA5}, OrdinaryDiffEqNewtonAdaptiveAlgorithm, OrdinaryDiffEqRosenbrockAdaptiveAlgorithm, - Union{SSPRK22, SSPRK33, SSPRK53_2N1, SSPRK53_2N2, SSPRK43, SSPRK432, SSPRK932}) + Union{SSPRK22, SSPRK33, SSPRK53_2N1, SSPRK53_2N2, SSPRK43, SSPRK432, SSPRK932}, + DAEAlgorithm) @eval @inline function DiffEqBase.get_tmp_cache(integrator, alg::$typ, cache::OrdinaryDiffEqConstantCache) nothing