Skip to content

Commit

Permalink
[WIP] Implementation of multiple shooting (#81)
Browse files Browse the repository at this point in the history
  • Loading branch information
JordiBolibar authored May 3, 2024
1 parent 1ce6489 commit 1cf3d40
Show file tree
Hide file tree
Showing 7 changed files with 78 additions and 18 deletions.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.1"
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
Expand All @@ -17,6 +18,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
9 changes: 6 additions & 3 deletions examples/double_rotation/double_rotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,22 @@ X_true = X_noiseless + FisherNoise(kappa=50.)

data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true)

regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"),
# Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")]
# regs = nothing

params = SphereParameters(tmin=tspan[1], tmax=tspan[2],
reg=regs,
train_initial_condition=false,
multiple_shooting=true,
u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=600,
sensealg=GaussAdjoint(autojacvec=ReverseDiffVJP(true)))

results = train(data, params, rng, nothing; train_initial_condition=false)
results = train(data, params, rng, nothing)

##############################################################
###################### PyCall Plots #########################
Expand Down
7 changes: 4 additions & 3 deletions src/SphereUDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ using Base: @kwdef
# training
using LinearAlgebra, Statistics, Distributions
using FastGaussQuadrature
using Lux, Zygote
using Lux, Zygote, DiffEqFlux
using ChainRules: @ignore_derivatives
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays: ComponentVector
using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms
using ComponentArrays
using PyPlot, PyCall
using PrettyTables

Expand Down
70 changes: 59 additions & 11 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ Training function.
function train(data::AD,
params::AP,
rng,
θ_trained=[];
train_initial_condition::Bool=false) where {AD <: AbstractData, AP <: AbstractParameters}
θ_trained=[]) where {AD <: AbstractData, AP <: AbstractParameters}

# Raise warnings
raise_warnings(data::AD, params::AP)

U, θ, st = get_NN(params, rng, θ_trained)

# Set component vector for Optimization
if train_initial_condition
if params.train_initial_condition
β = ComponentVector{Float64}=θ, u0=params.u0)
else
β = ComponentVector{Float64}=θ)
Expand All @@ -57,8 +56,8 @@ function train(data::AD,
prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], β.θ)

function predict::ComponentVector; T=data.times)
if train_initial_condition
_prob = remake(prob_nn, u0=β.u0 / norm.u0), # We enforced the norm=1 condition again here
if params.train_initial_condition
_prob = remake(prob_nn, u0=β.u0 / norm.u0), # We enforce the norm=1 condition again here
tspan=(min(T[1], params.tmin), max(T[end], params.tmax)),
p = β.θ)
else
Expand Down Expand Up @@ -108,7 +107,53 @@ function train(data::AD,
return l_emp + l_reg, loss_dict
end

function regularization::ComponentVector, reg::AbstractRegularization; n_nodes=100)
# Loss function to be called for multiple shooting (it is the `loss_function`
# argument passed to `multiple_shoot`).
# `data` and `pred` are compared element-wise, so they must be broadcast-compatible.
# Returns the scalar sum of the empirical misfit and the regularization penalty.
function loss_function(data, pred)

    # Empirical error.
    # The factor 3 is needed since the mean is computed on a 3xN matrix.
    l_emp = 3.0 * mean(abs2.(pred .- data))
    # Alternative misfit kept for reference:
    # l_emp = 1 - 3.0 * mean(u_ .* data.directions)

    # Regularization
    # NOTE(review): β is not an argument of this function; it is read from the
    # enclosing scope, so this penalty does not depend on `data` or `pred`.
    # Confirm that this is the parameter vector intended during optimization.
    l_reg = 0.0
    if !isnothing(params.reg)
        # for (order, power, λ) in params.reg
        for reg in params.reg
            l_reg += regularization(β.θ, reg)
        end
    end

    return l_emp + l_reg
end

# Define parameters for Multiple Shooting
group_size = 50
continuity_term = 100

ps = ComponentArray(θ)
pd, pax = getdata(ps), getaxes(ps)

# Loss evaluated with multiple shooting.
# Rebuilds the ODE problem from the current parameters `β` and returns whatever
# `multiple_shoot` returns for `data.directions` sampled at `data.times`, using
# `loss_function`, the `Tsit5()` solver, `group_size` and `continuity_term`.
function loss_multiple_shooting(β::ComponentVector)

    # NOTE(review): `ps` is not used below, and this assignment rebinds the `ps`
    # defined in the enclosing scope. Confirm whether it is still needed.
    ps = ComponentArray(β.θ, pax)

    # Integration window covering both the data times and the configured time span.
    _tspan = (min(data.times[1], params.tmin), max(data.times[end], params.tmax))

    if params.train_initial_condition
        # We enforce the norm=1 condition again here
        _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), tspan=_tspan, p=β.θ)
    else
        _prob = remake(prob_nn, u0=params.u0, tspan=_tspan, p=β.θ)
    end

    return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, Tsit5(),
                          group_size; continuity_term)
end

function regularization::ComponentVector, reg::AG; n_nodes=100) where{AG <: AbstractRegularization}

l_ = 0.0
if reg.order==0
Expand Down Expand Up @@ -143,17 +188,20 @@ function train(data::AD,
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
if train_initial_condition
if params.train_initial_condition
p.u0 ./= norm(p.u0)
end
return false
end

# Dispatch the right loss function
f_loss = params.multiple_shooting ? loss_multiple_shooting : loss

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, β) -> (first loss)(x), adtype)
optf = Optimization.OptimizationFunction((x, β) -> (first f_loss)(x), adtype)
optprob = Optimization.OptimizationProblem(optf, β)

res1 = Optimization.solve(optprob, ADAM(0.002), callback=callback, maxiters=params.niter_ADAM)
res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=false)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

if params.niter_LBFGS > 0
Expand All @@ -169,14 +217,14 @@ function train(data::AD,
θ_trained = β_trained.θ

# Optimized initial condition
if train_initial_condition
if params.train_initial_condition
u0_trained = Array(β_trained.u0) # β.u0 is a view type
else
u0_trained = params.u0
end

# Final Fit
fit_times = collect(range(params.tmin,params.tmax, length=200))
fit_times = collect(range(params.tmin,params.tmax, length=length(data.times)))
fit_directions, _ = predict(β_trained, T=fit_times)

# Recover final balance between different terms involved in the loss function to assess hyperparameter selection.
Expand Down
2 changes: 2 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Training parameters
u0::Union{Vector{F}, Nothing}
ωmax::F
reg::Union{Nothing, Array}
train_initial_condition::Bool
multiple_shooting::Bool
niter_ADAM::I
niter_LBFGS::I
reltol::F
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ end
"""
complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10)
Manual implementation of comple-step differentiation
Manual implementation of complex-step differentiation
"""
function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10)
return imag(f(x + ϵ * im)) / ϵ
Expand Down
4 changes: 4 additions & 0 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ function test_param_constructor()
u0=[0. ,0. ,1.],
ωmax=1.0,
reg=nothing,
train_initial_condition=false,
multiple_shooting=true,
niter_ADAM=1000, niter_LBFGS=300,
reltol=1e6, abstol=1e-6)

Expand All @@ -26,6 +28,8 @@ function test_param_constructor()
u0=[0. ,0. ,1.],
ωmax=1.0,
reg=[reg1, reg2],
train_initial_condition=false,
multiple_shooting=true,
niter_ADAM=1000, niter_LBFGS=300,
reltol=1e6, abstol=1e-6)

Expand Down

0 comments on commit 1cf3d40

Please sign in to comment.