Skip to content

Commit

Permalink
Double rotation example with small changes in src
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed Jul 20, 2024
1 parent 3d52691 commit 9fa3354
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 46 deletions.
73 changes: 47 additions & 26 deletions examples/double_rotation/double_rotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ using Random
rng = Random.default_rng()
Random.seed!(rng, 666)

function run()
function run(;kappa=50., regs=regs, title="plot")

# Total time simulation
tspan = [0, 160.0]
Expand All @@ -31,10 +31,10 @@ times_samples = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N_samples))
# Angular velocity
ω₀ = Δω₀ * π / 180.0
# Change point
τ₀ = 65.0
τ₀ = 70.0
# Angular momentum
L0 = ω₀ .* [1.0, 0.0, 0.0]
L1 = 0.5ω.* [0.0, 1/sqrt(2), 1/sqrt(2)]
L1 = 0.6ω.* [0.0, 1/sqrt(2), 1/sqrt(2)]

# Solver tolerances
reltol = 1e-12
Expand All @@ -49,46 +49,67 @@ function L_true(t::Float64; τ₀=τ₀, p=[L0, L1])
end

function true_rotation!(du, u, p, t)
L = L_true(t; τ₀=τ₀, p=p)
L = L_true(t; τ₀ = τ₀, p = p)
du .= cross(L, u)
end

prob = ODEProblem(true_rotation!, [0.0, 0.0, -1.0], tspan, [L0, L1])
true_sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol, saveat=times_samples)
true_sol = solve(prob, Tsit5(), reltol = reltol, abstol = abstol, saveat = times_samples)

# Add Fisher noise to true solution
X_noiseless = Array(true_sol)
X_true = X_noiseless + FisherNoise(kappa=50.)
X_true = X_noiseless + FisherNoise(kappa=kappa)

##############################################################
####################### Training ###########################
##############################################################

data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true)
data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true)

regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"),
# Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)]
# Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")]
# regs = nothing

params = SphereParameters(tmin=tspan[1], tmax=tspan[2],
reg=regs,
train_initial_condition=false,
multiple_shooting=true,
u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=600,
sensealg=GaussAdjoint(autojacvec=ReverseDiffVJP(true)))
params = SphereParameters(tmin = tspan[1], tmax = tspan[2],
reg = regs,
train_initial_condition = false,
multiple_shooting = false,
u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol,
niter_ADAM = 2000, niter_LBFGS = 1000,
sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true)))

results = train(data, params, rng, nothing)

##############################################################
###################### PyCall Plots #########################
##############################################################

plot_sphere(data, results, -20., 150., saveas="examples/double_rotation/plot_sphere.pdf", title="Double rotation")
plot_L(data, results, saveas="examples/double_rotation/plot_L.pdf", title="Double rotation")

end
run()
plot_sphere(data, results, -20., 125., saveas="examples/double_rotation/" * title * "_sphere.pdf", title="Double rotation") # , matplotlib_rcParams=Dict("font.size"=> 50))
plot_L(data, results, saveas="examples/double_rotation/" * title * "_L.pdf", title="Double rotation")

end # run

# Run different experiments

λ₀ = 0.1
λ₁ = 0.1
# λ₀ = 0.1
# λ₁ = 0.01
run(; kappa = 50.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)],
title = "plot_50_lambda$(λ₁)")

λ₀ = 0.1
λ₁ = 0.1
# λ₀ = 0.1
# λ₁ = 0.01
run(; kappa = 200.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=λ₀)],
title = "plot_200_lambda$(λ₁)")

λ₀ = 0.1
λ₁ = 0.1
# λ₀ = 0.1
# λ₁ = 0.01
run(; kappa = 1000.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=λ₀)],
title = "plot_1000_lambda$(λ₁)")
1 change: 1 addition & 0 deletions src/SphereUDE.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ include("train.jl")
include("plot.jl")

# Python libraries
const mpl_base::PyObject = isdefined(SphereUDE, :mpl_base) ? SphereUDE.mpl_base : PyNULL()
const mpl_colors::PyObject = isdefined(SphereUDE, :mpl_colors) ? SphereUDE.mpl_colors : PyNULL()
const mpl_colormap::PyObject = isdefined(SphereUDE, :mpl_colormap) ? SphereUDE.mpl_colormap : PyNULL()
const sns::PyObject = isdefined(SphereUDE, :sns) ? SphereUDE.sns : PyNULL()
Expand Down
14 changes: 12 additions & 2 deletions src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,22 @@ function plot_sphere(# ax::PyCall.PyObject,
central_latitude::Float64,
central_longitude::Float64;
saveas::Union{String, Nothing},
title::String)
title::String,
matplotlib_rcParams::Union{Dict, Nothing} = nothing)

# cmap = mpl_colormap.get_cmap("viridis")

plt.figure(figsize=(10,10))
ax = plt.axes(projection=ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude))

# Set default plot parameters.
# See https://matplotlib.org/stable/users/explain/customizing.html for customizable optionsz
if !isnothing(matplotlib_rcParams)
for (key, value) in matplotlib_rcParams
@warn "Setting Matplotlib parameters with rcParams currently not working. See following GitHub issue: https://github.com/JuliaPy/PyPlot.jl/issues/525"
mpl_base.rcParams[key] = value
end
end

# ax.coastlines()
ax.gridlines()
Expand All @@ -51,7 +61,7 @@ function plot_sphere(# ax::PyCall.PyObject,
linewidth=2, color="black",#cmap(norm(results.fit_times[i])),
transform = ccrs.Geodetic())
end
plt.title(title)
plt.title(title, fontsize=20)
if !isnothing(saveas)
plt.savefig(saveas, format="pdf")
end
Expand Down
5 changes: 3 additions & 2 deletions src/setup/config.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
export mpl_colormap, mpl_colormap, sns, ccrs, feature
export mpl_base, mpl_colormap, mpl_colormap, sns, ccrs, feature

function __init__()

try
copy!(mpl_base, pyimport("matplotlib"))
copy!(mpl_colors, pyimport("matplotlib.colors"))
copy!(mpl_colormap, pyimport("matplotlib.cm"))
copy!(sns, pyimport("seaborn"))
copy!(ccrs, pyimport("cartopy.crs"))
copy!(feature, pyimport("cartopy.feature"))
catch e
@warn "It looks like you have not installed and/or activated the virtual Python environment. \n
Please follow the guidelines in: https://github.com/facusapienza21/SphereUDE.jl#readme"
Please follow the guidelines in: https://github.com/ODINN-SciML/SphereUDE.jl"
@warn exception=(e, catch_backtrace())
end

Expand Down
39 changes: 26 additions & 13 deletions src/train.jl
Original file line number Diff line number Diff line change
@@ -1,17 +1,29 @@
export train

# For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid
function get_NN(params, rng, θ_trained)
# Define neural network
U = Lux.Chain(
Lux.Dense(1, 5, sigmoid), # explore discontinuity function for activation
Lux.Dense(5, 10, sigmoid),
Lux.Dense(10, 5, sigmoid),
# Lux.Dense(1, 5, relu_cap),
# Lux.Dense(5, 10, relu_cap),
# Lux.Dense(10, 5, relu_cap),
Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax))
)

# For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid
if isL1reg(params.reg)
@warn "[SphereUDE] Using ReLU activation functions for neural network due to L1 regularization."
U = Lux.Chain(
Lux.Dense(1, 5, sigmoid),
Lux.Dense(5, 10, sigmoid),
Lux.Dense(10, 5, sigmoid),
# Lux.Dense(1, 5, relu_cap),
# Lux.Dense(5, 10, relu_cap),
# Lux.Dense(10, 5, relu_cap),
Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
# Lux.Dense(5, 3, x -> relu_cap(x; ω₀=params.ωmax))
)
else
U = Lux.Chain(
Lux.Dense(1, 5, sigmoid),
Lux.Dense(5, 10, sigmoid),
Lux.Dense(10, 5, sigmoid),
Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
)
end
θ, st = Lux.setup(rng, U)
return U, θ, st
end
Expand Down Expand Up @@ -67,7 +79,8 @@ function train(data::AD,
end
sol = solve(_prob, params.solver, saveat=T,
abstol=params.abstol, reltol=params.reltol,
sensealg=params.sensealg)
sensealg=params.sensealg,
dtmin=1e-4 * (params.tmax - params.tmin), force_dtmin=true) # Force minimum step in case L(t) changes drastically due to bad behaviour of neural network
return Array(sol), sol.retcode
end

Expand Down Expand Up @@ -101,7 +114,7 @@ function train(data::AD,
for reg in params.reg
reg₀ = regularization.θ, reg)
l_reg += reg₀
loss_dict["Regularization (order=$(reg.order), power=$(reg.power)"] = reg₀
loss_dict["Regularization (order=$(reg.order), power=$(reg.power))"] = reg₀
end
end
return l_emp + l_reg, loss_dict
Expand Down Expand Up @@ -211,7 +224,7 @@ function train(data::AD,

if params.niter_LBFGS > 0
optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")
else
res2 = res1
Expand Down
6 changes: 5 additions & 1 deletion src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -56,5 +56,9 @@ Regularization information
order::I # Order of derivative
power::F # Power of the Euclidean norm
λ::F # Regularization hyperparameter
diff_mode::Union{Nothing, String} # AD differentiation mode used in regulatization
# AD differentiation mode used in regulatization
diff_mode::Union{Nothing, String} = nothing

# Include this in the constructor
# @assert (order == 0) || (!isnothing(diff_mode)) "Diffentiation methods needs to be provided for regularization with order larger than zero."
end
25 changes: 23 additions & 2 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ export cart2sph, sph2cart
export AbstractNoise, FisherNoise
export quadrature, central_fdm, complex_step_differentiation
export raise_warnings

export isL1reg

"""
sigmoid_cap(x; ω₀=1.0)
Expand Down Expand Up @@ -41,9 +41,12 @@ end
function relu_cap(x; ω₀=1.0)
min_value = - ω₀
max_value = + ω₀
return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
return relu_cap(x, min_value, max_value)
end

function relu_cap(x, min_value::Float64, max_value::Float64)
return min_value + (max_value - min_value) * max(0.0, min(x, 1.0))
end

"""
relu(x::Complex)
Expand All @@ -59,6 +62,11 @@ end
function relu_cap(z::Complex; ω₀=1.0)
min_value = - ω₀
max_value = + ω₀
return relu_cap(z, min_value, max_value)
# return min_value + (max_value - min_value) * relu(z - relu(z-1))
end

function relu_cap(z::Complex, min_value::Float64, max_value::Float64)
return min_value + (max_value - min_value) * relu(z - relu(z-1))
end

Expand Down Expand Up @@ -174,4 +182,17 @@ function raise_warnings(data::SphereData, params::SphereParameters)
# end
# end
nothing
end

"""
Function to check for the presence of L1 regularization in the loss function.
"""
function isL1reg(regs::Vector{R}) where {R <: AbstractRegularization}
for reg in regs
if reg.power == 1
return true
end
end
return false
end

0 comments on commit 9fa3354

Please sign in to comment.