-
-
Notifications
You must be signed in to change notification settings - Fork 211
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #3296 from hersle/generic_logged_functions
Generic mechanism for debugging/logging functions
- Loading branch information
Showing
5 changed files
with
93 additions
and
53 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# Debugging | ||
|
||
Every (mortal) modeler writes models that contain mistakes or are susceptible to numerical errors in their hunt for the perfect model. | ||
Debugging such errors is part of the modeling process, and ModelingToolkit includes some functionality that helps with this. | ||
|
||
For example, consider an ODE model with "dangerous" functions (here `√`): | ||
|
||
```@example debug | ||
using ModelingToolkit, OrdinaryDiffEq | ||
using ModelingToolkit: t_nounits as t, D_nounits as D | ||
@variables u1(t) u2(t) u3(t) | ||
eqs = [D(u1) ~ -√(u1), D(u2) ~ -√(u2), D(u3) ~ -√(u3)] | ||
defaults = [u1 => 1.0, u2 => 2.0, u3 => 3.0] | ||
@named sys = ODESystem(eqs, t; defaults) | ||
sys = structural_simplify(sys) | ||
``` | ||
|
||
This problem causes the ODE solver to crash: | ||
|
||
```@repl debug | ||
prob = ODEProblem(sys, [], (0.0, 10.0), []); | ||
sol = solve(prob, Tsit5()); | ||
``` | ||
|
||
This suggests *that* something went wrong, but not exactly *what* went wrong and *where* it did. | ||
In such situations, the `debug_system` function is helpful: | ||
|
||
```@repl debug | ||
dsys = debug_system(sys; functions = [sqrt]); | ||
dprob = ODEProblem(dsys, [], (0.0, 10.0), []); | ||
dsol = solve(dprob, Tsit5()); | ||
``` | ||
|
||
Now we see that it crashed because `u1` decreased so much that it became negative and outside the domain of the `√` function. | ||
We could have figured that out ourselves, but it is not always so obvious for more complex models. | ||
|
||
```@docs | ||
debug_system | ||
``` |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,36 +1,44 @@ | ||
const LOGGED_FUN = Set([log, sqrt, (^), /, inv]) | ||
is_legal(::typeof(/), a, b) = is_legal(inv, b) | ||
is_legal(::typeof(inv), a) = !iszero(a) | ||
is_legal(::Union{typeof(log), typeof(sqrt)}, a) = a isa Complex || a >= zero(a) | ||
is_legal(::typeof(^), a, b) = a isa Complex || b isa Complex || isinteger(b) || a >= zero(a) | ||
|
||
struct LoggedFunctionException <: Exception | ||
msg::String | ||
end | ||
struct LoggedFun{F} | ||
f::F | ||
args::Any | ||
error_nonfinite::Bool | ||
end | ||
function LoggedFunctionException(lf::LoggedFun, args, msg) | ||
LoggedFunctionException( | ||
"Function $(lf.f)($(join(lf.args, ", "))) " * msg * " with input" * | ||
join("\n " .* string.(lf.args .=> args)) # one line for each "var => val" for readability | ||
) | ||
end | ||
Base.showerror(io::IO, err::LoggedFunctionException) = print(io, err.msg) | ||
Base.nameof(lf::LoggedFun) = nameof(lf.f) | ||
SymbolicUtils.promote_symtype(::LoggedFun, Ts...) = Real | ||
function (lf::LoggedFun)(args...) | ||
f = lf.f | ||
symbolic_args = lf.args | ||
if is_legal(f, args...) | ||
f(args...) | ||
else | ||
args_str = join(string.(symbolic_args .=> args), ", ", ", and ") | ||
throw(DomainError(args, "$(lf.f) errors with input(s): $args_str")) | ||
val = try | ||
lf.f(args...) # try to call with numerical input, as usual | ||
catch err | ||
throw(LoggedFunctionException(lf, args, "errors")) # Julia automatically attaches original error message | ||
end | ||
if lf.error_nonfinite && !isfinite(val) | ||
throw(LoggedFunctionException(lf, args, "output non-finite value $val")) | ||
end | ||
return val | ||
end | ||
|
||
function logged_fun(f, args...) | ||
function logged_fun(f, args...; error_nonfinite = true) # remember to update error_nonfinite in debug_system() docstring | ||
# Currently we don't really support complex numbers | ||
term(LoggedFun(f, args), args..., type = Real) | ||
term(LoggedFun(f, args, error_nonfinite), args..., type = Real) | ||
end | ||
|
||
debug_sub(eq::Equation) = debug_sub(eq.lhs) ~ debug_sub(eq.rhs) | ||
function debug_sub(ex) | ||
function debug_sub(eq::Equation, funcs; kw...) | ||
debug_sub(eq.lhs, funcs; kw...) ~ debug_sub(eq.rhs, funcs; kw...) | ||
end | ||
function debug_sub(ex, funcs; kw...) | ||
iscall(ex) || return ex | ||
f = operation(ex) | ||
args = map(debug_sub, arguments(ex)) | ||
f in LOGGED_FUN ? logged_fun(f, args...) : | ||
args = map(ex -> debug_sub(ex, funcs; kw...), arguments(ex)) | ||
f in funcs ? logged_fun(f, args...; kw...) : | ||
maketerm(typeof(ex), f, args, metadata(ex)) | ||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters