Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Multiple shooting #81

Merged
merged 2 commits into from
May 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ version = "0.1.1"
[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
Expand All @@ -17,6 +18,7 @@ Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
Expand Down
9 changes: 6 additions & 3 deletions examples/double_rotation/double_rotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,19 +66,22 @@ X_true = X_noiseless + FisherNoise(kappa=50.)

data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true)

regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"),
# Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")]
# regs = nothing

params = SphereParameters(tmin=tspan[1], tmax=tspan[2],
reg=regs,
train_initial_condition=false,
multiple_shooting=true,
u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=600,
sensealg=GaussAdjoint(autojacvec=ReverseDiffVJP(true)))

results = train(data, params, rng, nothing; train_initial_condition=false)
results = train(data, params, rng, nothing)

##############################################################
###################### PyCall Plots #########################
Expand Down
7 changes: 4 additions & 3 deletions src/SphereUDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@ using Base: @kwdef
# training
using LinearAlgebra, Statistics, Distributions
using FastGaussQuadrature
using Lux, Zygote
using Lux, Zygote, DiffEqFlux
using ChainRules: @ignore_derivatives
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays: ComponentVector
using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms
using ComponentArrays
using PyPlot, PyCall
using PrettyTables

Expand Down
70 changes: 59 additions & 11 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,15 @@ Training function.
function train(data::AD,
params::AP,
rng,
θ_trained=[];
train_initial_condition::Bool=false) where {AD <: AbstractData, AP <: AbstractParameters}
θ_trained=[]) where {AD <: AbstractData, AP <: AbstractParameters}

# Raise warnings
raise_warnings(data::AD, params::AP)

U, θ, st = get_NN(params, rng, θ_trained)

# Set component vector for Optimization
if train_initial_condition
if params.train_initial_condition
β = ComponentVector{Float64}(θ=θ, u0=params.u0)
else
β = ComponentVector{Float64}(θ=θ)
Expand All @@ -57,8 +56,8 @@ function train(data::AD,
prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], β.θ)

function predict(β::ComponentVector; T=data.times)
if train_initial_condition
_prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforced the norm=1 condition again here
if params.train_initial_condition
_prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforce the norm=1 condition again here
tspan=(min(T[1], params.tmin), max(T[end], params.tmax)),
p = β.θ)
else
Expand Down Expand Up @@ -108,7 +107,53 @@ function train(data::AD,
return l_emp + l_reg, loss_dict
end

function regularization(θ::ComponentVector, reg::AbstractRegularization; n_nodes=100)
# Loss function to be called for multiple shooting
# Scores one shooting segment: empirical misfit of the predicted
# directions plus the same regularization terms used by the standard loss.
# NOTE(review): this closure captures the outer `β` (the initial parameter
# vector assembled before training), not the parameters currently being
# optimized — the regularization contribution is therefore constant over
# the optimization. Verify this is intended, or whether the current
# parameters should be threaded in instead.
function loss_function(data, pred)

# Empirical error
# Mean squared difference between predicted and observed unit vectors.
l_emp = 3.0 * mean(abs2.(pred .- data))
# The 3 is needed since the mean is computed on a 3xN matrix
# l_emp = 1 - 3.0 * mean(u_ .* data.directions)

# Regularization
# Accumulate each configured regularization term, if any.
l_reg = 0.0
if !isnothing(params.reg)
# for (order, power, λ) in params.reg
for reg in params.reg
reg₀ = regularization(β.θ, reg)
l_reg += reg₀
end
end

return l_emp + l_reg
end

# Define parameters for Multiple Shooting
# group_size: number of consecutive samples per shooting segment;
# continuity_term: penalty weight on the mismatch between the end of one
# segment and the start of the next.
group_size = 50
continuity_term = 100

# Flatten θ and keep its axes so a ComponentArray can be rebuilt inside
# the loss. NOTE(review): `pd` appears to be unused below — confirm.
ps = ComponentArray(θ)
pd, pax = getdata(ps), getaxes(ps)

# Multiple-shooting objective: rebuilds the ODE problem over the full data
# span and delegates segmenting + continuity penalties to `multiple_shoot`.
function loss_multiple_shooting(β::ComponentVector)

# NOTE(review): `ps` is rebuilt here but never referenced afterwards —
# `multiple_shoot` below is called with `β.θ` directly. Confirm whether
# the reconstructed ComponentArray was meant to be passed instead.
ps = ComponentArray(β.θ, pax)

if params.train_initial_condition
_prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforce the norm=1 condition again here
tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)),
p = β.θ)
else
_prob = remake(prob_nn, u0=params.u0,
tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)),
p = β.θ)
end

return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, Tsit5(),
group_size; continuity_term)
end

function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where{AG <: AbstractRegularization}

l_ = 0.0
if reg.order==0
Expand Down Expand Up @@ -143,17 +188,20 @@ function train(data::AD,
if length(losses) % 50 == 0
println("Current loss after $(length(losses)) iterations: $(losses[end])")
end
if train_initial_condition
if params.train_initial_condition
p.u0 ./= norm(p.u0)
end
return false
end

# Dispatch the right loss function
f_loss = params.multiple_shooting ? loss_multiple_shooting : loss

adtype = Optimization.AutoZygote()
optf = Optimization.OptimizationFunction((x, β) -> (first ∘ loss)(x), adtype)
optf = Optimization.OptimizationFunction((x, β) -> (first ∘ f_loss)(x), adtype)
optprob = Optimization.OptimizationProblem(optf, β)

res1 = Optimization.solve(optprob, ADAM(0.002), callback=callback, maxiters=params.niter_ADAM)
res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=false)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

if params.niter_LBFGS > 0
Expand All @@ -169,14 +217,14 @@ function train(data::AD,
θ_trained = β_trained.θ

# Optimized initial condition
if train_initial_condition
if params.train_initial_condition
u0_trained = Array(β_trained.u0) # β.u0 is a view type
else
u0_trained = params.u0
end

# Final Fit
fit_times = collect(range(params.tmin,params.tmax, length=200))
fit_times = collect(range(params.tmin,params.tmax, length=length(data.times)))
fit_directions, _ = predict(β_trained, T=fit_times)

# Recover final balance between different terms involved in the loss function to assess hyperparameter selection.
Expand Down
2 changes: 2 additions & 0 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ Training parameters
u0::Union{Vector{F}, Nothing}
ωmax::F
reg::Union{Nothing, Array}
train_initial_condition::Bool
multiple_shooting::Bool
niter_ADAM::I
niter_LBFGS::I
reltol::F
Expand Down
2 changes: 1 addition & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ end
"""
complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10)

Manual implementation of comple-step differentiation
Manual implementation of complex-step differentiation
"""
function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10)
return imag(f(x + ϵ * im)) / ϵ
Expand Down
4 changes: 4 additions & 0 deletions test/constructors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ function test_param_constructor()
u0=[0. ,0. ,1.],
ωmax=1.0,
reg=nothing,
train_initial_condition=false,
multiple_shooting=true,
niter_ADAM=1000, niter_LBFGS=300,
reltol=1e6, abstol=1e-6)

Expand All @@ -26,6 +28,8 @@ function test_param_constructor()
u0=[0. ,0. ,1.],
ωmax=1.0,
reg=[reg1, reg2],
train_initial_condition=false,
multiple_shooting=true,
niter_ADAM=1000, niter_LBFGS=300,
reltol=1e6, abstol=1e-6)

Expand Down
Loading