diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 60ecac4..2aefe6e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -3,10 +3,12 @@ on: push: branches: - main + - up-lux tags: ['*'] pull_request: branches: - main + - up-lux workflow_dispatch: concurrency: # Skip intermediate builds: always. @@ -24,7 +26,7 @@ jobs: fail-fast: false matrix: version: - - '1.9' + - '1' # - 'nightly' python: - 3.9 diff --git a/Project.toml b/Project.toml index fe0c86b..98d68a1 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SphereUDE" uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d" authors = ["Facundo Sapienza ", "Jordi Bolibar "] -version = "0.1.1" +version = "0.1.2" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" @@ -12,44 +12,52 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BenchmarkTools = "1" ComponentArrays = "0.15" +DiffEqFlux = "4" Distributions = "0.25" Infiltrator = "1.2" -Lux = "<0.5.49" +Lux = "1.0" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" -OrdinaryDiffEq = "5, 6" +OrdinaryDiffEqCore = "1.6.0" +OrdinaryDiffEqTsit5 = "1.1.0" PyCall = "1.9" PyPlot = "2.11" Revise = "3.1" SciMLSensitivity = "7.20" Statistics = "1" Zygote = "0.6" -julia = "1.7" +julia = "1.10" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "Random"] diff --git a/examples/Torsvik_2012/APWP-Torsvik.jl b/examples/Torsvik_2012/APWP-Torsvik.jl new file mode 100644 index 0000000..dbb30fb --- /dev/null +++ b/examples/Torsvik_2012/APWP-Torsvik.jl @@ -0,0 +1,83 @@ +using Pkg; Pkg.activate(".") +using Revise +using Lux + +using LinearAlgebra, Statistics, Distributions +using SciMLSensitivity +# using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 +using Optimization, OptimizationOptimisers, OptimizationOptimJL + +using SphereUDE + +# Random seed +using Random +rng = Random.default_rng() +Random.seed!(rng, 613) + +using DataFrames, CSV +using 
Serialization, JLD2 + +df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame, delim=",") + +# Filter the plates that were once part of the supercontinent Gondwana + +Gondwana = ["Amazonia", "Parana", "Colorado", "Southern_Africa", + "East_Antarctica", "Madagascar", "Patagonia", "Northeast_Africa", + "Northwest_Africa", "Somalia", "Arabia", "East_Gondwana"] + +df = filter(row -> row.Plate ∈ Gondwana, df) +df.Times = df.Age .+= rand(sampler(Normal(0,0.1)), nrow(df)) # This needs to be fixed! + +df = sort(df, :Times) +times = df.Times + +# Fill missing values +df.RLat .= coalesce.(df.RLat, df.Lat) +df.RLon .= coalesce.(df.RLon, df.Lon) + +X = sph2cart(Matrix(df[:,["RLat","RLon"]])'; radians=false) + +# Retrieve uncertainties from poles and convert α95 into κ +kappas = (140.0 ./ df.a95).^2 + +data = SphereData(times=times, directions=X, kappas=kappas, L=nothing) + +# Training + +# Expected maximum angular deviation in one unit of time (degrees) +Δω₀ = 1.5 +# Angular velocity +ω₀ = Δω₀ * π / 180.0 + +tspan = [times[begin], times[end]] + +params = SphereParameters(tmin = tspan[1], tmax = tspan[2], + reg = [Regularization(order=1, power=2.0, λ=1e5, diff_mode=FiniteDifferences(1e-4))], + # reg = nothing, + pretrain = false, + u0 = [0.0, 0.0, -1.0], ωmax = ω₀, + reltol = 1e-6, abstol = 1e-6, + niter_ADAM = 5000, niter_LBFGS = 5000, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) + + +init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) +init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) + +# Customized neural network to simulate a weighted moving window in L +U = Lux.Chain( + Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), + Lux.Dense(200,10, gelu), + Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false) +) + +results = train(data, params, rng, nothing, U) +results_dict = convert2dict(data, results) + +# JLD2.@save "examples/Torsvik_2012/results/data.jld2" data +# JLD2.@save "examples/Torsvik_2012/results/results.jld2" results +JLD2.@save "examples/Torsvik_2012/results/results_dict.jld2" results_dict + + +plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation") +plot_L(data, results, saveas="examples/Torsvik_2012/plots/plot_L.pdf", title="Double rotation") diff --git a/examples/benchmark.jl b/examples/benchmark.jl index e375095..d224657 100644 --- a/examples/benchmark.jl +++ b/examples/benchmark.jl @@ -45,6 +45,7 @@ function benchmark() #solvers = [BS5(), OwrenZen5(), OwrenZen3(), BS3(), Tsit5()] solvers = [BS5(), Tsit5()] + # Different sensealgs available at: https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))] #sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)), InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))] diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 597d53a..0e4c5cb 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -2,9 +2,10 @@ using Pkg; Pkg.activate(".") using Revise using LinearAlgebra, Statistics, Distributions -using OrdinaryDiffEq using SciMLSensitivity +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using Optimization, OptimizationOptimisers, OptimizationOptimJL +using Lux using SphereUDE @@ -37,8 +38,8 @@ L0 = ω₀ .* [1.0, 0.0, 0.0] L1 = 0.6ω₀ .* [0.0, 1/sqrt(2),
1/sqrt(2)] # Solver tolerances -reltol = 1e-12 -abstol = 1e-12 +reltol = 1e-6 +abstol = 1e-6 function L_true(t::Float64; τ₀=τ₀, p=[L0, L1]) if t < τ₀ @@ -71,10 +72,19 @@ params = SphereParameters(tmin = tspan[1], tmax = tspan[2], train_initial_condition = false, multiple_shooting = false, u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol, - niter_ADAM = 2000, niter_LBFGS = 1000, + niter_ADAM = 5000, niter_LBFGS = 5000, sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) -results = train(data, params, rng, nothing) +init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) +init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) + +U = Lux.Chain( + Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), + Lux.Dense(200,10, gelu), + Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false) +) + +results = train(data, params, rng, nothing, U) ############################################################## ###################### PyCall Plots ######################### @@ -87,26 +97,75 @@ end # run # Run different experiments -λ₀ = 0.1 -λ₁ = 0.001 -run(; kappa = 50., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="FD"), - Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)], - title = "plots/plot_50_lambda$(λ₁)") +### Finite differences + +# run(; kappa = 50., +# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=FiniteDifferences(1e-5)), +# Regularization(order=0, power=2.0, λ=10.0)], +# title = "plots/FD_plot_50") + +# run(; kappa = 200., +# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)), +# Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/FD_plot_200") + + +# run(; kappa = 1000., +# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)), +# Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/FD_plot_1000") + + +# Complex Step Method + +# run(; kappa = 50., +# regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=ComplexStepDifferentiation(1e-5)), +# Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/CS_plot_50") + +# run(; kappa = 200., +# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)), +# Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/CS_plot_200") + -λ₀ = 0.1 -λ₁ = 0.1 +# run(; kappa = 1000., +# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)), +# Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/CS_plot_1000") + + + +### AD + +run(; kappa = 50., + regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=LuxNestedAD())], + # Regularization(order=0, power=2.0, λ=0.1)], + title = "plots/AD_plot_50") run(; kappa = 200., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), - Regularization(order=0, power=2.0, λ=λ₀)], - title = "plots/plot_200_lambda$(λ₁)") + regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()), + Regularization(order=0, power=2.0, λ=0.1)], + title = "plots/AD_plot_200") + +run(; kappa = 1000., + regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()), + Regularization(order=0, power=2.0, λ=0.1)], + title = "plots/AD_plot_1000") + + +### No first-order regularization + +# run(; kappa = 50., +# regs = [Regularization(order=0, power=2.0, λ=0.1)], +# title = "plots/None_plot_50") + -λ₀ = 0.1 -λ₁ = 0.1 +# run(; kappa = 200., +# regs = [Regularization(order=0, power=2.0, λ=0.1)], +#
title = "plots/None_plot_200") run(; kappa = 1000., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), - Regularization(order=0, power=2.0, λ=λ₀)], - title = "plots/plot_1000_lambda$(λ₁)") \ No newline at end of file + regs = nothing, + title = "plots/_None_plot_1000") \ No newline at end of file diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index 7a49453..6003420 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -1,23 +1,19 @@ __precompile__() module SphereUDE -# types -using Base: @kwdef # utils # training using LinearAlgebra, Statistics, Distributions using FastGaussQuadrature using Lux, Zygote, DiffEqFlux using ChainRules: @ignore_derivatives -using OrdinaryDiffEq -using SciMLSensitivity -using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 +using SciMLSensitivity, ForwardDiff +using Optimization, OptimizationOptimisers, OptimizationOptimJL +using OptimizationPolyalgorithms, LineSearches using ComponentArrays using PyPlot, PyCall -using PrettyTables - -# Testing double-differentiation -# using BatchedRoutines +using PrettyTables, Printf # Debugging using Infiltrator diff --git a/src/plot.jl b/src/plot.jl index c33df4e..ddfa06d 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -40,7 +40,7 @@ function plot_sphere(# ax::PyCall.PyObject, end end - # ax.coastlines() + ax.coastlines() ax.gridlines() ax.set_global() diff --git a/src/train.jl b/src/train.jl index 3d48eac..23f120e 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,31 +1,30 @@ export train function get_NN(params, rng, θ_trained) - # Define neural network - + # Define default neural network + # For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid if isL1reg(params.reg) @warn "[SphereUDE] Using ReLU activation functions for neural network due to L1 regularization." U = Lux.Chain( Lux.Dense(1, 5, sigmoid), Lux.Dense(5, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), Lux.Dense(10, 5, sigmoid), - # Lux.Dense(1, 5, relu_cap), - # Lux.Dense(5, 10, relu_cap), - # Lux.Dense(10, 5, relu_cap), - Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) - # Lux.Dense(5, 3, x -> relu_cap(x; ω₀=params.ωmax)) + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) else U = Lux.Chain( - Lux.Dense(1, 5, sigmoid), - Lux.Dense(5, 10, sigmoid), - Lux.Dense(10, 5, sigmoid), - Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) + Lux.Dense(1, 5, gelu), + Lux.Dense(5, 10, gelu), + Lux.Dense(10, 10, gelu), + Lux.Dense(10, 10, gelu), + Lux.Dense(10, 5, gelu), + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) - end - θ, st = Lux.setup(rng, U) - return U, θ, st + end + return U end """ @@ -35,6 +34,7 @@ Predict value of rotation given by L given by the neural network. """ function predict_L(t, NN, θ, st) return NN([t], θ, st)[1] + # return 0.2 * ( NN([t-1.0], θ, st)[1] + NN([t-0.5], θ, st)[1] + NN([t], θ, st)[1] + NN([t+0.5], θ, st)[1] + NN([t+1.0], θ, st)[1] ) end """ @@ -45,12 +45,20 @@ Training function. 
function train(data::AD, params::AP, rng, - θ_trained=[]) where {AD <: AbstractData, AP <: AbstractParameters} + θ_trained=[], + model::Union{Chain, Nothing}=nothing) where {AD <: AbstractData, AP <: AbstractParameters} # Raise warnings raise_warnings(data::AD, params::AP) - - U, θ, st = get_NN(params, rng, θ_trained) + + if isnothing(model) + U = get_NN(params, rng, θ_trained) + else + U = model + end + θ, st = Lux.setup(rng, U) + # Make it a stateful layer (not sure where this is best placed; it is currently repeated) + smodel = StatefulLuxLayer{true}(U, θ, st) # Set component vector for Optimization if params.train_initial_condition @@ -84,44 +92,120 @@ return Array(sol), sol.retcode end + ##### Definition of loss functions to be used ##### + + """ + General Loss Function + """ function loss(β::ComponentVector) - u_, retcode = predict(β) + + # Record the value of each individual loss to the total loss function for hyperparameter selection. + loss_dict = Dict() + + l_emp = loss_empirical(β) + + loss_dict["Empirical"] = l_emp + + # Regularization + l_reg = 0.0 + if !isnothing(params.reg) + for reg in params.reg + reg₀ = regularization(β.θ, reg) + l_reg += reg₀ + loss_dict["Regularization (order=$(reg.order), power=$(reg.power))"] = reg₀ + end + end + return l_emp + l_reg, loss_dict + end + """ + Empirical loss function + """ + function loss_empirical(β::ComponentVector) + + u_, retcode = predict(β) + # If numerical integration fails or bad choice of parameter, return infinity if retcode != :Success - @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter." + @warn "[SphereUDE] Numerical solver not converging. This can be caused by numerical instabilities around a bad choice of parameters. It can also be due simply to a bad initial condition of the neural network, so it is worth changing the random seed used for initialization." return Inf end - # Record the value of each individual loss to the total loss function for hyperparameter selection.
- loss_dict = Dict() - # Empirical error if isnothing(data.kappas) # The 3 is needed since the mean is computen on a 3xN matrix - l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) + return 3.0 * mean(abs2.(u_ .- data.directions)) # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else - l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1) + return mean(data.kappas .* sum(abs2.(u_ .- data.directions), dims=1)) # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) end - loss_dict["Empirical"] = l_emp + + end + + """ + Regularization + """ + function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where {AG <: AbstractRegularization} - # Regularization - l_reg = 0.0 - if !isnothing(params.reg) - # for (order, power, λ) in params.reg - for reg in params.reg - reg₀ = regularization(β.θ, reg) - l_reg += reg₀ - loss_dict["Regularization (order=$(reg.order), power=$(reg.power))"] = reg₀ - end + l_ = 0.0 + if reg.order==0 + + l_ += quadrature(t -> norm(predict_L(t, U, θ, st))^reg.power, params.tmin, params.tmax, n_nodes) + + elseif reg.order==1 + + if typeof(reg.diff_mode) <: LuxNestedAD + # Automatic Differentiation + nodes, weights = quadrature(params.tmin, params.tmax, n_nodes) + + if reg.diff_mode.method == "ForwardDiff" + # Jac = ForwardDiff.jacobian(smodel, reshape(nodes, 1, n_nodes)) + Jac = batched_jacobian(smodel, AutoForwardDiff(), reshape(nodes, 1, n_nodes)) + elseif reg.diff_mode.method == "Zygote" + # This can also be done with Zygote in reverse mode + # Jac = Zygote.jacobian(smodel, reshape(nodes, 1, n_nodes))[1] + Jac = batched_jacobian(smodel, AutoZygote(), reshape(nodes, 1, n_nodes)) + else + throw("Method for AD backend not implemented.") + end + + # Compute the final aggregation to the loss + l_ += sum([weights[j] * norm(Jac[:,1,j])^reg.power for j in 1:n_nodes]) + + # Every few iterations, test that AD is working properly + if rand(Bernoulli(0.001)) + l_AD = sum([weights[j] * norm(Jac[:,1,j])^reg.power for j in 1:n_nodes]) + l_FD = quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, 1e-5))^reg.power, params.tmin, params.tmax, n_nodes) + if abs(l_AD - l_FD) > 1e-2 * l_FD + @warn "[SphereUDE] Nested AD is giving significantly different results than Finite Differences."
+ @printf "[SphereUDE] Regularization with AD: %.9f vs %.9f using Finite Differences" l_AD l_FD + end + end + + elseif typeof(reg.diff_mode) <: FiniteDifferences + # Finite differences + l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, reg.diff_mode.ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + + elseif typeof(reg.diff_mode) <: ComplexStepDifferentiation + # Complex step differentiation + l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t, reg.diff_mode.ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + + else + throw("Method not implemented.") + end + + else + throw("Method not implemented.") end - return l_emp + l_reg, loss_dict + return reg.λ * l_ end - # Loss function to be called for multiple shooting - function loss_function(data, pred) + """ + Loss function to be called for multiple shooting + This seems duplicated from before, so be careful with this + """ + function _loss_multiple_shooting(data, pred) # Empirical error l_emp = 3.0 * mean(abs2.(pred .- data)) @@ -167,44 +251,18 @@ function train(data::AD, p = β.θ) end - return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, continuity_loss, params.solver, + return multiple_shoot(β.θ, data.directions, data.times, _prob, _loss_multiple_shooting, continuity_loss, params.solver, group_size; continuity_term, sensealg=params.sensealg) end - function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where {AG <: AbstractRegularization} - - l_ = 0.0 - if reg.order==0 - l_ += quadrature(t -> norm(predict_L(t, U, θ, st))^reg.power, params.tmin, params.tmax, n_nodes) - elseif reg.order==1 - if reg.diff_mode=="AD" - throw("Method not working well.") - # Compute gradient using automatic differentiaion in the NN - # This currently doesn't run... too slow. - - # Test this with the new implementation in Lux.jl: - # https://lux.csail.mit.edu/stable/manual/nested_autodiff - elseif reg.diff_mode=="FD" - # Finite differences - ϵ = 0.1 * (params.tmax - params.tmin) / n_nodes - l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) - elseif reg.diff_mode=="CS" - # Complex step differentiation - l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t))^reg.power, params.tmin, params.tmax, n_nodes) - else - throw("Method not implemented.") - end - else - throw("Method not implemented.") - end - return reg.λ * l_ - end + ### Callback function losses = Float64[] callback = function (p, l) push!(losses, l) if length(losses) % 50 == 0 - println("Current loss after $(length(losses)) iterations: $(losses[end])") + @printf "Iteration: [%5d / %5d] \t Loss: %.9f \n" length(losses) (params.niter_ADAM+params.niter_LBFGS) losses[end] + # println("Current loss after $(length(losses)) iterations: $(losses[end])") end if params.train_initial_condition p.u0 ./= norm(p.u0) @@ -212,19 +270,48 @@ function train(data::AD, return false end + println("Loss after initalization: ", loss(β)[1]) + # Dispatch the right loss function f_loss = params.multiple_shooting ? 
loss_multiple_shooting : loss + # Optimization setting with AD adtype = Optimization.AutoZygote() + + """ + Pretraining to find parameters without imposing regularization + """ + if params.pretrain + losses_pretrain = Float64[] + callback_pretrain = function(p, l) + push!(losses_pretrain, l) + if length(losses_pretrain) % 100 == 0 + @printf "[Pretrain with no regularization] Iteration: [%5d / %5d] \t Loss: %.9f \n" length(losses_pretrain) (params.niter_ADAM+params.niter_LBFGS) losses_pretrain[end] + # println("[Pretrain with no regularization] Current loss after $(length(losses_pretrain)) iterations: $(losses_pretrain[end])") + end + return false + end + optf₀ = Optimization.OptimizationFunction((x, β) -> loss_empirical(x), adtype) + optprob₀ = Optimization.OptimizationProblem(optf₀, β) + res₀ = Optimization.solve(optprob₀, ADAM(), callback=callback_pretrain, maxiters=params.niter_ADAM, verbose=false) + optprob₁ = Optimization.OptimizationProblem(optf₀, res₀.u) + res₁ = Optimization.solve(optprob₁, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback_pretrain, maxiters=params.niter_LBFGS) + β = res₁.u + end + + # To do: implement this with polyoptimization to put ADAM and BFGS in one step. + # Maybe better to keep it like this for the line search. + optf = Optimization.OptimizationFunction((x, β) -> (first ∘ f_loss)(x), adtype) optprob = Optimization.OptimizationProblem(optf, β) - res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=false) + res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=true) println("Training loss after $(length(losses)) iterations: $(losses[end])") if params.niter_LBFGS > 0 optprob2 = Optimization.OptimizationProblem(optf, res1.u) - res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) + # res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) + res2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) println("Final training loss after $(length(losses)) iterations: $(losses[end])") else res2 = res1 @@ -242,13 +329,15 @@ end # Final Fit - fit_times = collect(range(params.tmin,params.tmax, length=length(data.times))) + fit_times = collect(range(params.tmin,params.tmax, length=1000)) fit_directions, _ = predict(β_trained, T=fit_times) + fit_rotations = reduce(hcat, (t -> U([t], θ_trained, st)[1]).(fit_times)) # Recover final balance between different terms involved in the loss function to assess hyperparameter selection.
_, loss_dict = loss(β_trained) pretty_table(loss_dict, sortkeys=true, header=["Loss term", "Value"]) return Results(θ=θ_trained, u0=u0_trained, U=U, st=st, - fit_times=fit_times, fit_directions=fit_directions) + fit_times=fit_times, fit_directions=fit_directions, + fit_rotations=fit_rotations, losses=losses) end diff --git a/src/types.jl b/src/types.jl index 315720e..5d5ed8c 100644 --- a/src/types.jl +++ b/src/types.jl @@ -2,11 +2,13 @@ export SphereParameters, AbstractParameters export SphereData, AbstractData export Results, AbstractResult export Regularization, AbstractRegularization +export FiniteDifferences, ComplexStepDifferentiation, LuxNestedAD, AbstractDifferentiation abstract type AbstractParameters end abstract type AbstractData end abstract type AbstractRegularization end abstract type AbstractResult end +abstract type AbstractDifferentiation end """ Training parameters @@ -17,14 +19,15 @@ Training parameters u0::Union{Vector{F}, Nothing} ωmax::F reg::Union{Nothing, Array} - train_initial_condition::Bool - multiple_shooting::Bool + train_initial_condition::Bool = false + multiple_shooting::Bool = false niter_ADAM::I niter_LBFGS::I reltol::F abstol::F - solver::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5() + solver::OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm = Tsit5() sensealg::SciMLBase.AbstractAdjointSensitivityAlgorithm = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)) + pretrain::Bool = false end """ @@ -43,10 +46,12 @@ Final results @kwdef struct Results{F <: AbstractFloat} <: AbstractResult θ::ComponentVector u0::Vector{F} - U::Lux.Chain + U::Lux.AbstractLuxLayer st::NamedTuple fit_times::Vector{F} fit_directions::Matrix{F} + fit_rotations::Matrix{F} + losses::Vector{F} end """ @@ -57,8 +62,23 @@ Regularization information power::F # Power of the Euclidean norm λ::F # Regularization hyperparameter # AD differentiation mode used in regulatization - diff_mode::Union{Nothing, String} = nothing + diff_mode::Union{Nothing, AbstractDifferentiation} = nothing # Include this in the constructor # @assert (order == 0) || (!isnothing(diff_mode)) "Diffentiation methods needs to be provided for regularization with order larger than zero." +end + +""" +Differentiation methods +""" +@kwdef struct FiniteDifferences{F <: AbstractFloat} <: AbstractDifferentiation + ϵ::F = 1e-10 +end + +@kwdef struct ComplexStepDifferentiation{F <: AbstractFloat} <: AbstractDifferentiation + ϵ::F = 1e-10 +end + +@kwdef struct LuxNestedAD <: AbstractDifferentiation + method::Union{Nothing, String} = "ForwardDiff" end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index c632e19..47b86bc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,29 +1,38 @@ export sigmoid, sigmoid_cap export relu, relu_cap +export gelu, rbf export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation export raise_warnings export isL1reg +export convert2dict + +# Import activation function for complex extension +import Lux: relu, gelu +# import Lux: sigmoid, relu, gelu + +### Custom Activation Funtions """ sigmoid_cap(x; ω₀=1.0) Normalization of the neural network last layer """ -function sigmoid_cap(x; ω₀=1.0) +rbf(x) = exp.(-(x .^ 2)) + +sigmoid_cap(x; ω₀=1.0) = sigmoid_cap(x, ω₀) + +function sigmoid_cap(x, ω₀) min_value = - ω₀ max_value = + ω₀ return min_value + (max_value - min_value) * sigmoid(x) end +# For some reason, when I import the Lux.sigmoid function this train badly, +# increasing the value of the loss function over iterations... 
function sigmoid(x) return 1.0 / (1.0 + exp(-x)) -# if x > 0 -# return 1 / ( 1.0 + exp(-x) ) -# else -# return exp(x) / (1.0 + exp(x)) -# end end function sigmoid(z::Complex) @@ -38,7 +47,9 @@ end """ relu_cap(x; ω₀=1.0) """ -function relu_cap(x; ω₀=1.0) +relu_cap(x; ω₀=1.0) = relu_cap(x, ω₀) + +function relu_cap(x, ω₀) min_value = - ω₀ max_value = + ω₀ return relu_cap(x, min_value, max_value) @@ -48,6 +59,8 @@ function relu_cap(x, min_value::Float64, max_value::Float64) return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) end +### Complex Expansion Activation Functions + """ relu(x::Complex) @@ -66,10 +79,39 @@ function relu_cap(z::Complex; ω₀=1.0) # return min_value + (max_value - min_value) * relu(z - relu(z-1)) end +""" + relu_cap(z::Complex, min_value::Float64, max_value::Float64) +""" function relu_cap(z::Complex, min_value::Float64, max_value::Float64) return min_value + (max_value - min_value) * relu(z - relu(z-1)) end +""" + sigmoid(z::Complex) +""" +function sigmoid(z::Complex) + return 1.0 / ( 1.0 + exp(-z) ) + # if real(z) > 0 + # return 1 / ( 1.0 + exp(-z) ) + # else + # return exp(z) / (1.0 + exp(z)) + # end +end + +""" + gelu(x::Complex) + +Extension of the GELU activation function for complex variables. +We use the approximation using tanh() to avoid dealing with the complex error function +""" +function gelu(z::Complex) + # We use the Gelu approximation to avoid complex holomorphic error function + return 0.5 * z * (1 + tanh((sqrt(2/π))*(z + 0.044715 * z^3))) +end + + +### Spherical Utils + """ cart2sph(X::AbstractArray{<:Number}; radians::Bool=true) @@ -129,19 +171,26 @@ function Base.:(+)(X::Array{F, 2}, ϵ::N) where {F <: AbstractFloat, N <: Abstra end end +### Numerics Utils + """ quadrature_integrate Numerical integral using Gaussian quadrature """ function quadrature(f::Function, t₀, t₁, n_nodes::Int) + nodes, weigths = quadrature(t₀, t₁, n_nodes) + return dot(weigths, f.(nodes)) +end + +function quadrature(t₀, t₁, n_nodes::Int) ignore() do # Ignore AD here since FastGaussQuadrature is using mutating arrays nodes, weigths = gausslegendre(n_nodes) end nodes = (t₀+t₁)/2 .+ nodes * (t₁-t₀)/2 weigths = (t₁-t₀) / 2 * weigths - return dot(weigths, f.(nodes)) + return nodes, weigths end """ @@ -152,7 +201,7 @@ Simple central differences implementation. FiniteDifferences.jl does not work with AD so I implemented this manually. Still remains to test this with FiniteDiff.jl """ -function central_fdm(f::Function, x::Float64; ϵ=0.01) +function central_fdm(f::Function, x::Float64, ϵ::Float64) return (f(x+ϵ)-f(x-ϵ)) / (2ϵ) end @@ -161,10 +210,12 @@ end Manual implementation of complex-step differentiation """ -function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10) +function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64) return imag(f(x + ϵ * im)) / ϵ end +### Other Utils + """ raise_warnings(data::AD, params::AP) @@ -188,11 +239,28 @@ end Function to check for the presence of L1 regularization in the loss function. 
""" -function isL1reg(regs::Vector{R}) where {R <: AbstractRegularization} +function isL1reg(regs::Union{Vector{R}, Nothing}) where {R <: AbstractRegularization} + if isnothing(regs) + return false + end for reg in regs if reg.power == 1 return true end end return false +end + +function convert2dict(data::SphereData, results::Results) + _dict = Dict() + _dict["times"] = data.times + _dict["directions"] = data.directions + _dict["kappas"] = data.kappas + _dict["u0"] = results.u0 + _dict["fit_times"] = results.fit_times + _dict["fit_directions"] = results.fit_directions + _dict["fit_rotations"] = results.fit_rotations + _dict["losses"] = results.losses + + return _dict end \ No newline at end of file diff --git a/test/constructors.jl b/test/constructors.jl index 4bd2b41..d37855d 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -1,12 +1,12 @@ function test_reg_constructor() - reg = Regularization(order=1, power=2.0, λ=0.1, diff_mode="AD") + reg = Regularization(order=1, power=2.0, λ=0.1, diff_mode=ComplexStepDifferentiation()) @test reg.order == 1 @test reg.power == 2.0 @test reg.λ == 0.1 - @test reg.diff_mode == "AD" + @test typeof(reg.diff_mode) <: AbstractDifferentiation end @@ -21,8 +21,8 @@ function test_param_constructor() niter_ADAM=1000, niter_LBFGS=300, reltol=1e6, abstol=1e-6) - reg1 = Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD") - reg2 = Regularization(order=1, power=1.0, λ=0.1, diff_mode="AD") + reg1 = Regularization(order=0, power=2.0, λ=0.1, diff_mode=FiniteDifferences()) + reg2 = Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()) params2 = SphereParameters(tmin=0.0, tmax=100., u0=[0. ,0. ,1.], @@ -37,6 +37,6 @@ function test_param_constructor() @test params.tmax == 100.0 @test params2.reg[1].order == 0 - @test params2.reg[2].diff_mode == "AD" + @test typeof(params2.reg[2].diff_mode) <: LuxNestedAD end \ No newline at end of file diff --git a/test/rotation.jl b/test/rotation.jl index 5e78b1c..43b8b55 100644 --- a/test/rotation.jl +++ b/test/rotation.jl @@ -1,5 +1,4 @@ -using LinearAlgebra, Statistics, Distributions -using OrdinaryDiffEq +using LinearAlgebra, Statistics, Distributions using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL @@ -37,7 +36,7 @@ function test_single_rotation() data = SphereData(times=times_samples, directions=X, kappas=nothing, L=nothing) - regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="FD"), + regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=FiniteDifferences(1e-6)), Regularization(order=0, power=2.0, λ=0.001, diff_mode=nothing)] params = SphereParameters(tmin = tspan[1], tmax = tspan[2], diff --git a/test/runtests.jl b/test/runtests.jl index db621eb..4bbdf6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,4 +19,4 @@ end @testset "Inversion" begin test_single_rotation() -end \ No newline at end of file +end