Skip to content

Commit

Permalink
simplify dolinsolve
Browse files Browse the repository at this point in the history
  • Loading branch information
oscardssmith committed Aug 24, 2024
1 parent b609a4d commit 1e1d569
Show file tree
Hide file tree
Showing 21 changed files with 348 additions and 603 deletions.
91 changes: 38 additions & 53 deletions lib/OrdinaryDiffEqBDF/src/algorithms.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,11 +6,10 @@ an Adaptive BDF2 Formula and Comparison with The MATLAB Ode15s. Procedia Compute
ABDF2: Multistep Method
An adaptive order 2 L-stable fixed leading coefficient multistep BDF method.
"""
struct ABDF2{CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
struct ABDF2{CS, AD, F, F2, FDT, ST, CJ, K, T, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
κ::K
tol::T
smooth_est::Bool
Expand All @@ -20,14 +19,14 @@ struct ABDF2{CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
end
function ABDF2(; chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
κ = nothing, tol = nothing, linsolve = nothing, precs = DEFAULT_PRECS,
κ = nothing, tol = nothing, linsolve = nothing,
nlsolve = NLNewton(),
smooth_est = true, extrapolant = :linear,
controller = :Standard, step_limiter! = trivial_limiter!)
ABDF2{
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(tol), typeof(step_limiter!)}(linsolve, nlsolve, precs, κ, tol,
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(tol), typeof(step_limiter!)}(linsolve, nlsolve, κ, tol,
smooth_est, extrapolant, controller, step_limiter!)
end

Expand All @@ -36,11 +35,10 @@ Uri M. Ascher, Steven J. Ruuth, Brian T. R. Wetton. Implicit-Explicit Methods fo
Dependent Partial Differential Equations. 1995 Society for Industrial and Applied Mathematics
Journal on Numerical Analysis, 32(3), pp 797-823, 1995. doi: https://doi.org/10.1137/0732037
"""
struct SBDF{CS, AD, F, F2, P, FDT, ST, CJ, K, T} <:
struct SBDF{CS, AD, F, F2, FDT, ST, CJ, K, T} <:
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
κ::K
tol::T
extrapolant::Symbol
Expand All @@ -50,14 +48,13 @@ end

function SBDF(order; chunk_size = Val{0}(), autodiff = Val{true}(),
standardtag = Val{true}(), concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear, ark = false)
SBDF{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(tol)}(linsolve,
nlsolve,
precs,
κ,
tol,
extrapolant,
Expand All @@ -68,15 +65,14 @@ end
# All keyword form needed for remake
function SBDF(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear,
order, ark = false)
SBDF{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(κ), typeof(tol)}(linsolve,
nlsolve,
precs,
κ,
tol,
extrapolant,
Expand Down Expand Up @@ -136,11 +132,10 @@ Optional parameter kappa defaults to Shampine's accuracy-optimal -0.1850.
See also `QNDF`.
"""
struct QNDF1{CS, AD, F, F2, P, FDT, ST, CJ, κType, StepLimiter} <:
struct QNDF1{CS, AD, F, F2, FDT, ST, CJ, κType, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
extrapolant::Symbol
kappa::κType
controller::Symbol
Expand All @@ -149,15 +144,14 @@ end

function QNDF1(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
linsolve = nothing, nlsolve = NLNewton(),
extrapolant = :linear, kappa = -37//200,
controller = :Standard, step_limiter! = trivial_limiter!)
QNDF1{
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(kappa), typeof(step_limiter!)}(linsolve,
nlsolve,
precs,
extrapolant,
kappa,
controller,
Expand All @@ -170,11 +164,10 @@ An adaptive order 2 quasi-constant timestep L-stable numerical differentiation f
See also `QNDF`.
"""
struct QNDF2{CS, AD, F, F2, P, FDT, ST, CJ, κType, StepLimiter} <:
struct QNDF2{CS, AD, F, F2, FDT, ST, CJ, κType, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
extrapolant::Symbol
kappa::κType
controller::Symbol
Expand All @@ -183,15 +176,14 @@ end

function QNDF2(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
linsolve = nothing, nlsolve = NLNewton(),
extrapolant = :linear, kappa = -1 // 9,
controller = :Standard, step_limiter! = trivial_limiter!)
QNDF2{
_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve), typeof(nlsolve),
typeof(precs), diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
diff_type, _unwrap_val(standardtag), _unwrap_val(concrete_jac),
typeof(kappa), typeof(step_limiter!)}(linsolve,
nlsolve,
precs,
extrapolant,
kappa,
controller,
Expand All @@ -214,12 +206,11 @@ year={1997},
publisher={SIAM}
}
"""
struct QNDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, κType, StepLimiter} <:
struct QNDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T, κType, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
max_order::Val{MO}
linsolve::F
nlsolve::F2
precs::P
κ::K
tol::T
extrapolant::Symbol
Expand All @@ -231,15 +222,15 @@ end
function QNDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear, kappa = (-37//200, -1//9, -823//10000, -83//2000, 0//1),
controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
QNDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac),
typeof(κ), typeof(tol), typeof(kappa), typeof(step_limiter!)}(
max_order, linsolve, nlsolve, precs, κ, tol,
max_order, linsolve, nlsolve, κ, tol,
extrapolant, kappa, controller, step_limiter!)
end

Expand All @@ -250,22 +241,20 @@ MEBDF2: Multistep Method
The second order Modified Extended BDF method, which has improved stability properties over the standard BDF.
Fixed timestep only.
"""
struct MEBDF2{CS, AD, F, F2, P, FDT, ST, CJ} <:
struct MEBDF2{CS, AD, F, F2, FDT, ST, CJ} <:
OrdinaryDiffEqNewtonAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
extrapolant::Symbol
end
function MEBDF2(; chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
linsolve = nothing, nlsolve = NLNewton(),
extrapolant = :constant)
MEBDF2{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve,
nlsolve,
precs,
extrapolant)
end

Expand All @@ -282,12 +271,11 @@ year={2002},
publisher={Walter de Gruyter GmbH \\& Co. KG}
}
"""
struct FBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T, StepLimiter} <:
struct FBDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T, StepLimiter} <:
OrdinaryDiffEqNewtonAdaptiveAlgorithm{CS, AD, FDT, ST, CJ}
max_order::Val{MO}
linsolve::F
nlsolve::F2
precs::P
κ::K
tol::T
extrapolant::Symbol
Expand All @@ -298,14 +286,14 @@ end
function FBDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear, controller = :Standard, step_limiter! = trivial_limiter!) where {MO}
FBDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac),
typeof(κ), typeof(tol), typeof(step_limiter!)}(
max_order, linsolve, nlsolve, precs, κ, tol, extrapolant,
max_order, linsolve, nlsolve, κ, tol, extrapolant,
controller, step_limiter!)
end

Expand Down Expand Up @@ -389,41 +377,39 @@ See also `SBDF`, `IMEXEuler`.
"""
IMEXEulerARK(; kwargs...) = SBDF(1; ark = true, kwargs...)

struct DImplicitEuler{CS, AD, F, F2, P, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
struct DImplicitEuler{CS, AD, F, F2, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
extrapolant::Symbol
controller::Symbol
end
function DImplicitEuler(;
chunk_size = Val{0}(), autodiff = true, standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
linsolve = nothing, nlsolve = NLNewton(),
extrapolant = :constant,
controller = :Standard)
DImplicitEuler{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve,
nlsolve, precs, extrapolant, controller)
nlsolve, extrapolant, controller)
end

struct DABDF2{CS, AD, F, F2, P, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
struct DABDF2{CS, AD, F, F2, FDT, ST, CJ} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
linsolve::F
nlsolve::F2
precs::P
extrapolant::Symbol
controller::Symbol
end
function DABDF2(; chunk_size = Val{0}(), autodiff = Val{true}(), standardtag = Val{true}(),
concrete_jac = nothing, diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(),
linsolve = nothing, nlsolve = NLNewton(),
extrapolant = :constant,
controller = :Standard)
DABDF2{_unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac)}(linsolve,
nlsolve, precs, extrapolant, controller)
nlsolve, extrapolant, controller)
end

#=
Expand All @@ -440,11 +426,10 @@ DBDF(;chunk_size=Val{0}(),autodiff=Val{true}(), standardtag = Val{true}(), concr
linsolve,nlsolve,precs,extrapolant)
=#

struct DFBDF{MO, CS, AD, F, F2, P, FDT, ST, CJ, K, T} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
struct DFBDF{MO, CS, AD, F, F2, FDT, ST, CJ, K, T} <: DAEAlgorithm{CS, AD, FDT, ST, CJ}
max_order::Val{MO}
linsolve::F
nlsolve::F2
precs::P
κ::K
tol::T
extrapolant::Symbol
Expand All @@ -453,13 +438,13 @@ end
function DFBDF(; max_order::Val{MO} = Val{5}(), chunk_size = Val{0}(),
autodiff = Val{true}(), standardtag = Val{true}(), concrete_jac = nothing,
diff_type = Val{:forward},
linsolve = nothing, precs = DEFAULT_PRECS, nlsolve = NLNewton(), κ = nothing,
linsolve = nothing, nlsolve = NLNewton(), κ = nothing,
tol = nothing,
extrapolant = :linear, controller = :Standard) where {MO}
DFBDF{MO, _unwrap_val(chunk_size), _unwrap_val(autodiff), typeof(linsolve),
typeof(nlsolve), typeof(precs), diff_type, _unwrap_val(standardtag),
typeof(nlsolve), diff_type, _unwrap_val(standardtag),
_unwrap_val(concrete_jac),
typeof(κ), typeof(tol)}(max_order, linsolve, nlsolve, precs, κ, tol, extrapolant,
typeof(κ), typeof(tol)}(max_order, linsolve, nlsolve, κ, tol, extrapolant,
controller)
end

Expand Down
2 changes: 1 addition & 1 deletion lib/OrdinaryDiffEqCore/src/OrdinaryDiffEqCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ end
end
const TryAgain = SlowConvergence

DEFAULT_PRECS(W, du, u, p, t, newW, Plprev, Prprev, solverdata) = nothing, nothing
DEFAULT_PRECS(W, p) = nothing, nothing
isdiscretecache(cache) = false

include("doc_utils.jl")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ import StaticArrays: SArray, MVector, SVector, @SVector, StaticArray, MMatrix, S
using DiffEqBase: TimeGradientWrapper,
UJacobianWrapper, TimeDerivativeWrapper,
UDerivativeWrapper
using SciMLBase: AbstractSciMLOperator
using SciMLBase: AbstractSciMLOperator, DEIntegrator
import OrdinaryDiffEqCore
using OrdinaryDiffEqCore: OrdinaryDiffEqAlgorithm, OrdinaryDiffEqAdaptiveImplicitAlgorithm,
DAEAlgorithm,
Expand Down
47 changes: 21 additions & 26 deletions lib/OrdinaryDiffEqDifferentiation/src/linsolve_utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,20 @@ issuccess_W(W::Number) = !iszero(W)
issuccess_W(::Any) = true

function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothing,
du = nothing, u = nothing, p = nothing, t = nothing,
weight = nothing, solverdata = nothing,
reltol = integrator === nothing ? nothing : integrator.opts.reltol)
A !== nothing && (linsolve.A = A)
b !== nothing && (linsolve.b = b)
linu !== nothing && (linsolve.u = linu)

Plprev = linsolve.Pl isa LinearSolve.ComposePreconditioner ? linsolve.Pl.outer :
linsolve.Pl
Prprev = linsolve.Pr isa LinearSolve.ComposePreconditioner ? linsolve.Pr.outer :
linsolve.Pr

_alg = unwrap_alg(integrator, true)

_Pl, _Pr = _alg.precs(linsolve.A, du, u, p, t, A !== nothing, Plprev, Prprev,
solverdata)
if (_Pl !== nothing || _Pr !== nothing)
__Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pl
__Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(integrator.u)) : _Pr
linsolve.Pl = __Pl
linsolve.Pr = __Pr
if !isnothing(A)
if integrator isa DEIntegrator
(;u, p, t) = integrator
du = hasproperty(integrator, :du) ? integrator.du : nothing
p = (du, u, p, t)
reinit!(linsolve; A, p)
else
reinit!(linsolve; A)
end
end

linres = solve!(linsolve; reltol)
Expand All @@ -44,16 +37,18 @@ function dolinsolve(integrator, linsolve; A = nothing, linu = nothing, b = nothi
return linres
end

function wrapprecs(_Pl::Nothing, _Pr::Nothing, weight, u)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
Pl, Pr
end

function wrapprecs(_Pl, _Pr, weight, u)
Pl = _Pl === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pl
Pr = _Pr === nothing ? SciMLOperators.IdentityOperator(length(u)) : _Pr
Pl, Pr
function wrapprecs(linsolver, W, weight)
if isnothing(linsolver)
linsolver = LinearSolve.defaultalg(W, weight, LinearSolve.OperatorAssumptions(true))
end
if hasproperty(linsolver, :precs) && isnothing(linsolver.precs)
Pl = LinearSolve.InvPreconditioner(Diagonal(_vec(weight)))
Pr = Diagonal(_vec(weight))
precs = Returns((Pl, Pr))
return remake(linsolver; precs)
else
return linsolver
end
end

Base.resize!(p::LinearSolve.LinearCache, i) = p
Loading

0 comments on commit 1e1d569

Please sign in to comment.