Skip to content

Commit

Permalink
Multiple shooting working once sensealg specified
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed May 3, 2024
1 parent 684621f commit 3d52691
Showing 1 changed file with 15 additions and 10 deletions.
25 changes: 15 additions & 10 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,8 +85,8 @@ function train(data::AD,

# Empirical error
if isnothing(data.kappas)
l_emp = 3.0 * mean(abs2.(u_ .- data.directions))
# The 3 is needed since the mean is computen on a 3xN matrix
l_emp = 3.0 * mean(abs2.(u_ .- data.directions))
# l_emp = 1 - 3.0 * mean(u_ .* data.directions)
else
l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1)
Expand All @@ -112,13 +112,10 @@ function train(data::AD,

# Empirical error
l_emp = 3.0 * mean(abs2.(pred .- data))
# The 3 is needed since the mean is computen on a 3xN matrix
# l_emp = 1 - 3.0 * mean(u_ .* data.directions)

# Regularization
l_reg = 0.0
if !isnothing(params.reg)
# for (order, power, λ) in params.reg
for reg in params.reg
reg₀ = regularization.θ, reg)
l_reg += reg₀
Expand All @@ -129,11 +126,19 @@ function train(data::AD,
end

# Define parameters for Multiple Shooting
group_size = 50
group_size = 10
continuity_term = 100

ps = ComponentArray(θ)
pd, pax = getdata(ps), getaxes(ps)
ps = ComponentArray(θ) # are these necesary?
pd, pax = getdata(ps), getaxes(ps)

function continuity_loss(u_pred, u_initial)
if !isapprox(norm(u_initial), 1.0, atol=1e-6) || !isapprox(norm(u_pred), 1.0, atol=1e-6)
@warn "Directions during multiple shooting are not in the sphere. Small deviations from unit norm observed:"
@show norm(u_initial), norm(u_pred)
end
return sum(abs2, u_pred - u_initial)
end

function loss_multiple_shooting::ComponentVector)

Expand All @@ -149,11 +154,11 @@ function train(data::AD,
p = β.θ)
end

return multiple_shoot.θ, data.directions, data.times, _prob, loss_function, Tsit5(),
group_size; continuity_term)
return multiple_shoot.θ, data.directions, data.times, _prob, loss_function, continuity_loss, params.solver,
group_size; continuity_term, sensealg=params.sensealg)
end

function regularization::ComponentVector, reg::AG; n_nodes=100) where{AG <: AbstractRegularization}
function regularization::ComponentVector, reg::AG; n_nodes=100) where {AG <: AbstractRegularization}

l_ = 0.0
if reg.order==0
Expand Down

0 comments on commit 3d52691

Please sign in to comment.