Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Generic mechanism for debugging/logging functions #3296

Merged
merged 10 commits into from
Jan 18, 2025
1 change: 1 addition & 0 deletions docs/pages.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ pages = [
"basics/InputOutput.md",
"basics/MTKLanguage.md",
"basics/Validation.md",
"basics/Debugging.md",
"basics/DependencyGraphs.md",
"basics/Precompilation.md",
"basics/FAQ.md"],
Expand Down
40 changes: 40 additions & 0 deletions docs/src/basics/Debugging.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# Debugging

Every (mortal) modeler writes models that contain mistakes or are susceptible to numerical errors in their hunt for the perfect model.
Debugging such errors is part of the modeling process, and ModelingToolkit includes some functionality that helps with this.

For example, consider an ODE model with "dangerous" functions (here `√`):

```@example debug
using ModelingToolkit, OrdinaryDiffEq
using ModelingToolkit: t_nounits as t, D_nounits as D

@variables u1(t) u2(t) u3(t)
eqs = [D(u1) ~ -√(u1), D(u2) ~ -√(u2), D(u3) ~ -√(u3)]
defaults = [u1 => 1.0, u2 => 2.0, u3 => 3.0]
@named sys = ODESystem(eqs, t; defaults)
sys = structural_simplify(sys)
```

This problem causes the ODE solver to crash:

```@repl debug
prob = ODEProblem(sys, [], (0.0, 10.0), []);
sol = solve(prob, Tsit5());
```

This suggests *that* something went wrong, but not exactly *what* went wrong and *where* it did.
In such situations, the `debug_system` function is helpful:

```@repl debug
dsys = debug_system(sys; functions = [sqrt]);
dprob = ODEProblem(dsys, [], (0.0, 10.0), []);
dsol = solve(dprob, Tsit5());
```

Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function.
We could have figured that out ourselves, but it is not always so obvious for more complex models.

```@docs
debug_system
```
46 changes: 27 additions & 19 deletions src/debugging.jl
Original file line number Diff line number Diff line change
@@ -1,36 +1,44 @@
const LOGGED_FUN = Set([log, sqrt, (^), /, inv])
is_legal(::typeof(/), a, b) = is_legal(inv, b)
is_legal(::typeof(inv), a) = !iszero(a)
is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a)
is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a)

struct LoggedFunctionException <: Exception
msg::String
end
struct LoggedFun{F}
f::F
args::Any
error_nonfinite::Bool
end
function LoggedFunctionException(lf::LoggedFun, args, msg)
LoggedFunctionException(
"Function $(lf.f)($(join(lf.args, ", "))) " * msg * " with input" *
join("\n " .* string.(lf.args .=> args)) # one line for each "var => val" for readability
)
end
Base.showerror(io::IO, err::LoggedFunctionException) = print(io, err.msg)
Base.nameof(lf::LoggedFun) = nameof(lf.f)
SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real
function (lf::LoggedFun)(args...)
f = lf.f
symbolic_args = lf.args
if is_legal(f, args...)
f(args...)
else
args_str = join(string.(symbolic_args .=> args), ", ", ", and ")
throw(DomainError(args, "$(lf.f) errors with input(s): $args_str"))
val = try
lf.f(args...) # try to call with numerical input, as usual
catch err
throw(LoggedFunctionException(lf, args, "errors")) # Julia automatically attaches original error message
end
if lf.error_nonfinite && !isfinite(val)
throw(LoggedFunctionException(lf, args, "output non-finite value $val"))
end
return val
end

function logged_fun(f, args...)
function logged_fun(f, args...; error_nonfinite = true) # remember to update error_nonfinite in debug_system() docstring
# Currently we don't really support complex numbers
term(LoggedFun(f, args), args..., type = Real)
term(LoggedFun(f, args, error_nonfinite), args..., type = Real)
end

debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs)
function debug_sub(ex)
function debug_sub(eq::Equation, funcs; kw...)
debug_sub(eq.lhs, funcs; kw...) ~ debug_sub(eq.rhs, funcs; kw...)
end
function debug_sub(ex, funcs; kw...)
iscall(ex) || return ex
f = operation(ex)
args = map(debug_sub, arguments(ex))
f in LOGGED_FUN ? logged_fun(f, args...) :
args = map(ex -> debug_sub(ex, funcs; kw...), arguments(ex))
f in funcs ? logged_fun(f, args...; kw...) :
maketerm(typeof(ex), f, args, metadata(ex))
end
38 changes: 19 additions & 19 deletions src/systems/abstractsystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2260,37 +2260,37 @@ macro mtkbuild(exprs...)
end

"""
$(SIGNATURES)
debug_system(sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], error_nonfinite = true)

Replace functions with singularities with a function that errors with symbolic
information. E.g.
Wrap `functions` in `sys` so any error thrown in them shows helpful symbolic-numeric
information about its input. If `error_nonfinite`, functions that output nonfinite
values (like `Inf` or `NaN`) also display errors, even though the raw function itself
does not throw an exception (like `1/0`). For example:

```julia-repl
julia> sys = debug_system(sys);

julia> sys = complete(sys);
julia> sys = debug_system(complete(sys))

julia> prob = ODEProblem(sys, [], (0, 1.0));
julia> prob = ODEProblem(sys, [0.0, 2.0], (0.0, 1.0))

julia> du = zero(prob.u0);

julia> prob.f(du, prob.u0, prob.p, 0.0)
ERROR: DomainError with (-1.0,):
log errors with input(s): -cos(Q(t)) => -1.0
Stacktrace:
[1] (::ModelingToolkit.LoggedFun{typeof(log)})(args::Float64)
...
julia> prob.f(prob.u0, prob.p, 0.0)
ERROR: Function /(1, sin(P(t))) output non-finite value Inf with input
1 => 1
sin(P(t)) => 0.0
```
"""
function debug_system(sys::AbstractSystem)
function debug_system(
sys::AbstractSystem; functions = [log, sqrt, (^), /, inv, asin, acos], kw...)
hersle marked this conversation as resolved.
Show resolved Hide resolved
if !(functions isa Set)
functions = Set(functions) # more efficient "in" lookup
end
if has_systems(sys) && !isempty(get_systems(sys))
error("debug_system only works on systems with no sub-systems!")
error("debug_system(sys) only works on systems with no sub-systems! Consider flattening it with flatten(sys) or structural_simplify(sys) first.")
end
if has_eqs(sys)
@set! sys.eqs = debug_sub.(equations(sys))
@set! sys.eqs = debug_sub.(equations(sys), Ref(functions); kw...)
end
if has_observed(sys)
@set! sys.observed = debug_sub.(observed(sys))
@set! sys.observed = debug_sub.(observed(sys), Ref(functions); kw...)
end
return sys
end
Expand Down
21 changes: 6 additions & 15 deletions test/odesystem.jl
Original file line number Diff line number Diff line change
Expand Up @@ -931,22 +931,13 @@ testdict = Dict([:name => "test"])
@named sys = ODESystem(eqs, t, metadata = testdict)
@test get_metadata(sys) == testdict

@variables P(t)=0 Q(t)=2
∂t = D

eqs = [∂t(Q) ~ 1 / sin(P)
∂t(P) ~ log(-cos(Q))]
@variables P(t)=NaN Q(t)=NaN
eqs = [D(Q) ~ 1 / sin(P), D(P) ~ log(-cos(Q))]
@named sys = ODESystem(eqs, t, [P, Q], [])
sys = complete(debug_system(sys));
prob = ODEProblem(sys, [], (0, 1.0));
du = zero(prob.u0);
if VERSION < v"1.8"
@test_throws DomainError prob.f(du, [1, 0], prob.p, 0.0)
@test_throws DomainError prob.f(du, [0, 2], prob.p, 0.0)
else
@test_throws "-cos(Q(t))" prob.f(du, [1, 0], prob.p, 0.0)
@test_throws "sin(P(t))" prob.f(du, [0, 2], prob.p, 0.0)
end
sys = complete(debug_system(sys))
prob = ODEProblem(sys, [], (0.0, 1.0))
@test_throws "log(-cos(Q(t))) errors" prob.f([1, 0], prob.p, 0.0)
@test_throws "/(1, sin(P(t))) output non-finite value" prob.f([0, 2], prob.p, 0.0)

let
@variables x(t) = 1
Expand Down
Loading