From 8f88211c1e8be4ae53325e02533b6613d69da873 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Wed, 10 Apr 2024 01:46:48 +0000 Subject: [PATCH 01/29] Update Project with new dependencies --- Project.toml | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index a6a29c2..f4bdf00 100644 --- a/Project.toml +++ b/Project.toml @@ -7,6 +7,8 @@ version = "0.1.1" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" +FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -19,6 +21,7 @@ PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" +TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] @@ -26,7 +29,6 @@ BenchmarkTools = "1" ComponentArrays = "0.15" Distributions = "0.25" Infiltrator = "1.2" -Lux = "0.5" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" From ad54394e8b9751de95f20524715618170176876a Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Wed, 17 Apr 2024 02:26:47 +0000 Subject: [PATCH 02/29] remove FiniteDifferences from dependencies --- Project.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/Project.toml b/Project.toml index f4bdf00..6557314 100644 --- a/Project.toml +++ b/Project.toml @@ -8,7 +8,6 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" From a4790d3d4dc52d52da1ea53a06f5a8a0a0779dbf Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Wed, 17 Apr 2024 19:03:09 +0000 Subject: [PATCH 03/29] Example of APWP fit based on Jupp1987 --- Project.toml | 2 + .../APWP_Jupp1987/Jupp-etal-1987_dataset.csv | 27 ++++++++++++ examples/APWP_Jupp1987/apwp.jl | 41 +++++++++++++++++++ src/plot.jl | 4 +- src/train.jl | 24 +++++------ src/utils.jl | 33 ++++++++++++++- 6 files changed, 114 insertions(+), 17 deletions(-) create mode 100644 examples/APWP_Jupp1987/Jupp-etal-1987_dataset.csv create mode 100644 examples/APWP_Jupp1987/apwp.jl diff --git a/Project.toml b/Project.toml index 6557314..968e8a7 100644 --- a/Project.toml +++ b/Project.toml @@ -5,7 +5,9 @@ version = "0.1.1" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" +CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" +DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" diff --git a/examples/APWP_Jupp1987/Jupp-etal-1987_dataset.csv b/examples/APWP_Jupp1987/Jupp-etal-1987_dataset.csv new file mode 100644 index 0000000..322e2d4 --- /dev/null +++ 
b/examples/APWP_Jupp1987/Jupp-etal-1987_dataset.csv @@ -0,0 +1,27 @@ +Obs,Time,Lat,Lon +A,0.3,80.6,349.2 +B,1,86.2,123.7 +C,1,86.6,205.5 +D,14.9,83.3,294 +E,26.6,84.8,67.6 +F,32.5,81,41 +G,55.5,86,178 +H,68.9,76,147 +I,98,86,298 +J,106,70.8,15.4 +K,107.8,61.5,244.2 +L,109.6,87,49.5 +M,118,36,296.3 +N,144,86,344 +O,171.9,56.5,12 +P,172,58,38 +Q,172,45,39 +R,172,59,41 +S,189,54.2,40.2 +T,189,44.1,51.5 +U,189,53.8,42.6 +V,209.5,82,290 +W,249.5,51,206 +X,481.5,9.3,206.7 +Y,481.5,28,190 +Z,519.7,1.5,208.5 \ No newline at end of file diff --git a/examples/APWP_Jupp1987/apwp.jl b/examples/APWP_Jupp1987/apwp.jl new file mode 100644 index 0000000..df03292 --- /dev/null +++ b/examples/APWP_Jupp1987/apwp.jl @@ -0,0 +1,41 @@ +using Pkg; Pkg.activate(".") +using Revise + +using SphereUDE +using DataFrames, CSV +using Random +rng = Random.default_rng() +Random.seed!(rng, 666) + + +df = CSV.read("examples/APWP_Jupp1987/Jupp-etal-1987_dataset.csv", DataFrame; header=true) + +times = df.Time +times .+= rand(sampler(Normal(0,0.1)), length(times)) # Needs to fix this! +X = sph2cart(Matrix(df[:,["Lat","Lon"]])'; radians=false) + +data = SphereData(times=times, directions=X, kappas=nothing, L=nothing) + +# Expected maximum angular deviation in one unit of time (degrees) +Δω₀ = 1.0 +# Angular velocity +ω₀ = Δω₀ * π / 180.0 + +# regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="Finite Differences"), +# Regularization(order=0, power=2.0, λ=0.1, diff_mode="Finite Differences")] +# regs = [Regularization(order=1, power=2.0, λ=0.1, diff_mode="Finite Differences")] +regs = nothing + +params = SphereParameters(tmin=0., tmax=520., + reg=regs, + u0=[0.0, 0.0, 1.0], ωmax=ω₀, reltol=1e-7, abstol=1e-7, + niter_ADAM=1000, niter_LBFGS=400) + +results = train(data, params, rng, nothing) + +############################################################## +###################### PyCall Plots ######################### +############################################################## + +plot_sphere(data, results, mean(df.Lat), mean(df.Lon), saveas="examples/APWP_Jupp1987/plot_sphere.pdf", title="Double rotation") +plot_L(data, results, saveas="examples/APWP_Jupp1987/plot_L.pdf", title="Double rotation") \ No newline at end of file diff --git a/src/plot.jl b/src/plot.jl index e37d03b..bbf3697 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -30,7 +30,7 @@ function plot_sphere(# ax::PyCall.PyObject, plt.figure(figsize=(10,10)) ax = plt.axes(projection=ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)) - ax.coastlines() + # ax.coastlines() ax.gridlines() ax.set_global() @@ -39,7 +39,7 @@ function plot_sphere(# ax::PyCall.PyObject, X_fit_path = cart2sph(results.fit_directions, radians=false) sns.scatterplot(ax=ax, x = X_true_points[1,:], y=X_true_points[2, :], - hue = data.times, s=50, + hue = data.times, s=150, palette="viridis", transform = ccrs.PlateCarree()); diff --git a/src/train.jl b/src/train.jl index 641b72e..1b846df 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,11 +1,15 @@ export train +# For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid function get_NN(params, rng, θ_trained) # Define neural network U = Lux.Chain( Lux.Dense(1, 5, relu_cap), # explore discontinuity function for activation Lux.Dense(5, 10, relu_cap), - Lux.Dense(10, 5, relu_cap), + Lux.Dense(10, 5, relu_cap), + # Lux.Dense(1, 5, sigmoid), + # Lux.Dense(5, 10, sigmoid), + # Lux.Dense(10, 5, sigmoid), Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax)) ) θ, st = 
Lux.setup(rng, U) @@ -19,9 +23,6 @@ function train(data::AD, U, θ, st = get_NN(params, rng, θ_trained) - # one option is to restrict where the NN is evaluated to discrete t to - # generate piece-wise dynamics. - function ude_rotation!(du, u, p, t) # Angular momentum given by network prediction L = U([t], p, st)[1] @@ -66,17 +67,14 @@ function train(data::AD, # Create (uniform) spacing time # Δt = (params.tmax - params.tmin) / n_nodes # times_reg = collect(params.tmin:Δt:params.tmax) - # LinRange does not propagate thought the backprop step! # times_reg = collect(LinRange(params.tmin, params.tmax, n_nodes)) l_ = 0.0 if reg.order==0 l_ += quadrature(t -> norm(U([t], θ, st)[1])^reg.power, params.tmin, params.tmax, n_nodes) - # for t in times_reg - # l_ += norm(U([t], θ, st)[1])^reg.power - # end elseif reg.order==1 if reg.diff_mode=="AD" + throw("Method not working well.") # Compute gradient using automatic differentiaion in the NN # This currently doesn't run... too slow. for t in times_reg @@ -85,10 +83,10 @@ function train(data::AD, l_ += norm(grad)^reg.power end elseif reg.diff_mode=="Finite Differences" - # Compute finite differences - L_estimated = map(t -> (first ∘ U)([t], θ, st), times_reg) - dLdt = diff(L_estimated) ./ diff(times_reg) - l_ += sum(norm.(dLdt).^reg.power) + # Using FiniteDifferences break precompilation becuase of name collission + # l_ += quadrature(t -> norm(FiniteDifferences.jacobian(FiniteDifferences.central_fdm(2,1), τ -> (first ∘ U)([τ], θ, st), t)[1])^reg.power, params.tmin, params.tmax, n_nodes) + ϵ = 0.1 * (params.tmax - params.tmin) / n_nodes + l_ += quadrature(t -> norm(central_fdm(τ -> (first ∘ U)([τ], θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) else throw("Method not implemented.") end @@ -122,7 +120,7 @@ function train(data::AD, θ_trained = res2.u # Final Fit - fit_times = collect(params.tmin:0.1:params.tmax) + fit_times = collect(range(params.tmin,params.tmax, length=200)) fit_directions = predict(θ_trained, T=fit_times) return Results(θ_trained=θ_trained, U=U, st=st, diff --git a/src/utils.jl b/src/utils.jl index d441b79..9dfc1fb 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,7 @@ export sigmoid_cap, relu_cap, step_cap -export cart2sph +export cart2sph, sph2cart export AbstractNoise, FisherNoise -export quadrature +export quadrature, central_fdm # Normalization of the NN. Ideally we want to do this with L2 norm . @@ -37,6 +37,23 @@ function cart2sph(X::AbstractArray{<:Number}; radians::Bool=true) return Y end + +""" + sph2cart(X::AbstractArray{<:Number}; radians::Bool=true) + +Convert spherical coordinates to cartesian +""" +function sph2cart(X::AbstractArray{<:Number}; radians::Bool=true) + @assert size(X)[1] == 2 "Input array must have two rows corresponding to Latitude and Longitude." + if !radians + X *= π / 180. + end + Y = mapslices(x -> [cos(x[1])*cos(x[2]), + cos(x[1])*sin(x[2]), + sin(x[1])] , X, dims=1) + return Y +end + """ Add Fisher noise to matrix of three dimensional unit vectors @@ -78,4 +95,16 @@ function quadrature(f::Function, t₀, t₁, n_nodes::Int) nodes = (t₀+t₁)/2 .+ nodes * (t₁-t₀)/2 weigths = (t₁-t₀) / 2 * weigths return dot(weigths, f.(nodes)) +end + +""" + central_fdm(f::Function, x::Float64; ϵ=0.01) + +Simple central differences implementation. + +FiniteDifferences.jl does not work with AD so I implemented this manually. 
+Still remains to test this with FiniteDiff.jl +""" +function central_fdm(f::Function, x::Float64; ϵ=0.01) + return (f(x+ϵ)-f(x-ϵ)) / (2ϵ) end \ No newline at end of file From 29b8c78450fa93d3a586db334734f75237d6d8c4 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 18 Apr 2024 19:46:37 -0700 Subject: [PATCH 04/29] add complex-step method --- src/train.jl | 14 +++++++++++--- src/utils.jl | 11 ++++++++++- 2 files changed, 21 insertions(+), 4 deletions(-) diff --git a/src/train.jl b/src/train.jl index 1b846df..647c33d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -29,6 +29,12 @@ function train(data::AD, du .= cross(L, u) end + # function ude_rotation!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t) + # # Angular momentum given by network prediction + # L = U([t], p, st)[1] + # du .= cross(L, u) + # end + prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ) function predict(θ::ComponentVector; u0=params.u0, T=data.times) @@ -82,11 +88,13 @@ function train(data::AD, grad = Zygote.jacobian(first ∘ U, [t], θ, st)[1] l_ += norm(grad)^reg.power end - elseif reg.diff_mode=="Finite Differences" - # Using FiniteDifferences break precompilation becuase of name collission - # l_ += quadrature(t -> norm(FiniteDifferences.jacobian(FiniteDifferences.central_fdm(2,1), τ -> (first ∘ U)([τ], θ, st), t)[1])^reg.power, params.tmin, params.tmax, n_nodes) + elseif reg.diff_mode=="FD" + # Finite differences ϵ = 0.1 * (params.tmax - params.tmin) / n_nodes l_ += quadrature(t -> norm(central_fdm(τ -> (first ∘ U)([τ], θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + elseif reg.diff_mode=="CS" + # Complex step differentiation + l_ += quadrature(t -> norm(complex_step_differentiation(τ -> (first ∘ U)([τ], θ, st), t))^reg.power, params.tmin, params.tmax, n_nodes) else throw("Method not implemented.") end diff --git a/src/utils.jl b/src/utils.jl index 9dfc1fb..9ad00c9 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,7 +1,7 @@ export sigmoid_cap, relu_cap, step_cap export cart2sph, sph2cart export AbstractNoise, FisherNoise -export quadrature, central_fdm +export quadrature, central_fdm, complex_step_differentiation # Normalization of the NN. Ideally we want to do this with L2 norm . @@ -107,4 +107,13 @@ Still remains to test this with FiniteDiff.jl """ function central_fdm(f::Function, x::Float64; ϵ=0.01) return (f(x+ϵ)-f(x-ϵ)) / (2ϵ) +end + +""" + complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10) + +Manual implementation of comple-step differentiation +""" +function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10) + return imag(f(x + ϵ * im)) / ϵ end \ No newline at end of file From a3852d16258370884346cab348700b6b104262a5 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 18 Apr 2024 23:10:00 -0700 Subject: [PATCH 05/29] Double differentiation working with complex-step method --- examples/double_rotation/double_rotation.jl | 6 ++-- src/SphereUDE.jl | 5 +++- src/train.jl | 11 +++---- src/utils.jl | 32 ++++++++++++++++++++- 4 files changed, 42 insertions(+), 12 deletions(-) diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 1e0e5b8..d23e5b0 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -64,9 +64,9 @@ X_true = X_noiseless + FisherNoise(kappa=200.) 
data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true) -# regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="Finite Differences"), - # Regularization(order=0, power=2.0, λ=0.1, diff_mode="Finite Differences")] -regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD")] +regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="CS"), + Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD")] +# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="CS")] params = SphereParameters(tmin=tspan[1], tmax=tspan[2], reg=regs, diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index ac216c0..ebce6ae 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -14,11 +14,14 @@ using Optimization, OptimizationOptimisers, OptimizationOptimJL using ComponentArrays: ComponentVector using PyPlot, PyCall +# Testing double-differentiation +# using BatchedRoutines + # Debugging using Infiltrator -include("utils.jl") include("types.jl") +include("utils.jl") include("train.jl") include("plot.jl") diff --git a/src/train.jl b/src/train.jl index 647c33d..d0f561a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -21,6 +21,9 @@ function train(data::AD, rng, θ_trained=[]) where{AD <: AbstractData, AP <: AbstractParameters} + # Raise warnings + raise_warnings(data::AD, params::AP) + U, θ, st = get_NN(params, rng, θ_trained) function ude_rotation!(du, u, p, t) @@ -29,12 +32,6 @@ function train(data::AD, du .= cross(L, u) end - # function ude_rotation!(du::Array{Complex{Float64}}, u::Array{Complex{Float64}}, p, t) - # # Angular momentum given by network prediction - # L = U([t], p, st)[1] - # du .= cross(L, u) - # end - prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ) function predict(θ::ComponentVector; u0=params.u0, T=data.times) @@ -107,7 +104,7 @@ function train(data::AD, losses = Float64[] callback = function (p, l) push!(losses, l) - if length(losses) % 200 == 0 + if length(losses) % 100 == 0 println("Current loss after $(length(losses)) iterations: $(losses[end])") end return false diff --git a/src/utils.jl b/src/utils.jl index 9ad00c9..c8f2eed 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -2,6 +2,7 @@ export sigmoid_cap, relu_cap, step_cap export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation +export raise_warnings # Normalization of the NN. Ideally we want to do this with L2 norm . 
@@ -14,6 +15,10 @@ function sigmoid_cap(x; ω₀=1.0) return min_value + (max_value - min_value) / ( 1.0 + exp(-x) ) end +function sigmoid(x::Complex) + return 1 / ( 1.0 + exp(-x) ) +end + """ relu_cap(x; ω₀=1.0) """ @@ -23,6 +28,12 @@ function relu_cap(x; ω₀=1.0) return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) end +function relu_cap(x::Complex; ω₀=1.0) + min_value = - ω₀ + max_value = + ω₀ + return min_value + (max_value - min_value) * max(0.0, min(real(x), 1.0)) + (max_value - min_value) * max(0.0, min(imag(x), 1.0)) * im +end + """ cart2sph(X::AbstractArray{<:Number}; radians::Bool=true) @@ -92,7 +103,7 @@ function quadrature(f::Function, t₀, t₁, n_nodes::Int) # Ignore AD here since FastGaussQuadrature is using mutating arrays nodes, weigths = gausslegendre(n_nodes) end - nodes = (t₀+t₁)/2 .+ nodes * (t₁-t₀)/2 + nodes = (t₀+t₁)/2 .+ nodes * (t₁-t₀)/2 weigths = (t₁-t₀) / 2 * weigths return dot(weigths, f.(nodes)) end @@ -116,4 +127,23 @@ Manual implementation of comple-step differentiation """ function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10) return imag(f(x + ϵ * im)) / ϵ +end + +""" + raise_warnings(data::AD, params::AP) + +Raise warnings. +""" +function raise_warnings(data::SphereData, params::SphereParameters) + if length(unique(data.times)) < length(data.times) + @warn "[SphereUDE] Timeseries includes duplicated times. \n This can produce unexpected errors." + end + if !isnothing(params.reg) + for reg in params.reg + if reg.diff_mode=="CS" + @warn "[SphereUDE] Complex-step differentiation inside the loss function \n This just work for cases where the activation function of the neural network admits complex numbers \n Change predefined activation functions to be complex numbers." + end + end + end + nothing end \ No newline at end of file From cc5ce31b3d82d1109e44e7d397fb4727066c494e Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 19 Apr 2024 12:53:51 -0700 Subject: [PATCH 06/29] Initial condition u0 fitting implemented --- examples/double_rotation/double_rotation.jl | 13 +++-- src/plot.jl | 2 +- src/train.jl | 56 +++++++++++++-------- src/types.jl | 3 +- 4 files changed, 47 insertions(+), 27 deletions(-) diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index d23e5b0..04ddc3f 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -64,16 +64,19 @@ X_true = X_noiseless + FisherNoise(kappa=200.) 
data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true) -regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="CS"), - Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD")] -# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="CS")] +# regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="CS"), +# Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD")] +regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing), + Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")] +# regs = nothing params = SphereParameters(tmin=tspan[1], tmax=tspan[2], reg=regs, u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol, - niter_ADAM=1000, niter_LBFGS=600) + niter_ADAM=1000, niter_LBFGS=500) -results = train(data, params, rng, nothing) +# results = train(data, params, rng, nothing) +results = train(data, params, rng, nothing; train_initial_condition=true) ############################################################## ###################### PyCall Plots ######################### diff --git a/src/plot.jl b/src/plot.jl index bbf3697..8fafd15 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -69,7 +69,7 @@ function plot_L(data::AbstractData, fig, ax = plt.subplots(figsize=(10,5)) times_smooth = collect(LinRange(results.fit_times[begin], results.fit_times[end], 1000)) - Ls = reduce(hcat, (t -> results.U([t], results.θ_trained, results.st)[1]).(times_smooth)) + Ls = reduce(hcat, (t -> results.U([t], results.θ, results.st)[1]).(times_smooth)) angular_velocity = mapslices(x -> norm(x), Ls, dims=1)[1,:] diff --git a/src/train.jl b/src/train.jl index d0f561a..60a837a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -19,32 +19,46 @@ end function train(data::AD, params::AP, rng, - θ_trained=[]) where{AD <: AbstractData, AP <: AbstractParameters} + θ_trained=[]; + train_initial_condition::Bool=false) where {AD <: AbstractData, AP <: AbstractParameters} # Raise warnings raise_warnings(data::AD, params::AP) U, θ, st = get_NN(params, rng, θ_trained) + # Set component vector for Optimization + if train_initial_condition + β = ComponentVector{Float64}(θ=θ, u0=params.u0) + else + β = ComponentVector{Float64}(θ=θ) + end + function ude_rotation!(du, u, p, t) # Angular momentum given by network prediction L = U([t], p, st)[1] du .= cross(L, u) end - prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], θ) + prob_nn = ODEProblem(ude_rotation!, params.u0, [params.tmin, params.tmax], β.θ) - function predict(θ::ComponentVector; u0=params.u0, T=data.times) - _prob = remake(prob_nn, u0=u0, - tspan=(min(T[1], params.tmin), max(T[end], params.tmax)), - p = θ) + function predict(β::ComponentVector; T=data.times) + if train_initial_condition + _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), + tspan=(min(T[1], params.tmin), max(T[end], params.tmax)), + p = β.θ) + else + _prob = remake(prob_nn, u0=params.u0, + tspan=(min(T[1], params.tmin), max(T[end], params.tmax)), + p = β.θ) + end Array(solve(_prob, params.solver, saveat=T, abstol=params.abstol, reltol=params.reltol, sensealg=params.sensealg)) end - function loss(θ::ComponentVector) - u_ = predict(θ) + function loss(β::ComponentVector) + u_ = predict(β) # Empirical error # l_emp = mean(abs2, u_ .- data.directions) if isnothing(data.kappas) @@ -58,8 +72,7 @@ function train(data::AD, if !isnothing(params.reg) # for (order, power, λ) in params.reg for reg in params.reg - # l_reg += reg.λ * regularization(θ; order=reg.order, power=reg.power) - l_reg += regularization(θ, reg) + l_reg += 
regularization(β.θ, reg) end end return l_emp + l_reg @@ -67,11 +80,6 @@ function train(data::AD, function regularization(θ::ComponentVector, reg::AbstractRegularization; n_nodes=100) - # Create (uniform) spacing time - # Δt = (params.tmax - params.tmin) / n_nodes - # times_reg = collect(params.tmin:Δt:params.tmax) - # LinRange does not propagate thought the backprop step! - # times_reg = collect(LinRange(params.tmin, params.tmax, n_nodes)) l_ = 0.0 if reg.order==0 l_ += quadrature(t -> norm(U([t], θ, st)[1])^reg.power, params.tmin, params.tmax, n_nodes) @@ -111,8 +119,8 @@ function train(data::AD, end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((x, θ) -> loss(x), adtype) - optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(θ)) + optf = Optimization.OptimizationFunction((x, β) -> loss(x), adtype) + optprob = Optimization.OptimizationProblem(optf, β) res1 = Optimization.solve(optprob, ADAM(0.002), callback=callback, maxiters=params.niter_ADAM) println("Training loss after $(length(losses)) iterations: $(losses[end])") @@ -122,12 +130,20 @@ function train(data::AD, println("Final training loss after $(length(losses)) iterations: $(losses[end])") # Optimized NN parameters - θ_trained = res2.u + β_trained = res2.u + θ_trained = β_trained.θ + + # Optimized initial condition + if train_initial_condition + u0_trained = Array(β_trained.u0) # β.u0 is a view type + else + u0_trained = params.u0 + end # Final Fit fit_times = collect(range(params.tmin,params.tmax, length=200)) - fit_directions = predict(θ_trained, T=fit_times) + fit_directions = predict(β_trained, T=fit_times) - return Results(θ_trained=θ_trained, U=U, st=st, + return Results(θ=θ_trained, u0=u0_trained, U=U, st=st, fit_times=fit_times, fit_directions=fit_directions) end diff --git a/src/types.jl b/src/types.jl index b1c3001..a8c6237 100644 --- a/src/types.jl +++ b/src/types.jl @@ -39,7 +39,8 @@ end Final results """ @kwdef struct Results{F <: AbstractFloat} <: AbstractResult - θ_trained::ComponentVector + θ::ComponentVector + u0::Vector{F} U::Lux.Chain st::NamedTuple fit_times::Vector{F} From 92f00d6cc70ccd737f43d9cc21ab48a6176211a1 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 19 Apr 2024 17:48:50 -0700 Subject: [PATCH 07/29] Projected gradient descent working for u0. 
Some more tests --- src/plot.jl | 8 +++++--- src/train.jl | 14 +++++++++----- src/utils.jl | 6 +++--- test/runtests.jl | 5 ++++- test/utils.jl | 22 +++++++++++++++++++--- 5 files changed, 40 insertions(+), 15 deletions(-) diff --git a/src/plot.jl b/src/plot.jl index 8fafd15..7deb758 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -38,14 +38,16 @@ function plot_sphere(# ax::PyCall.PyObject, # X_true_path = cart2sph(X_path, radians=false) X_fit_path = cart2sph(results.fit_directions, radians=false) - sns.scatterplot(ax=ax, x = X_true_points[1,:], y=X_true_points[2, :], + # Plots in Python follow the long, lat ordering + + sns.scatterplot(ax=ax, x = X_true_points[2,:], y=X_true_points[1, :], hue = data.times, s=150, palette="viridis", transform = ccrs.PlateCarree()); for i in 1:(length(results.fit_times)-1) - plt.plot([X_fit_path[1,i], X_fit_path[1,i+1]], - [X_fit_path[2,i], X_fit_path[2,i+1]], + plt.plot([X_fit_path[2,i], X_fit_path[2,i+1]], + [X_fit_path[1,i], X_fit_path[1,i+1]], linewidth=2, color="black",#cmap(norm(results.fit_times[i])), transform = ccrs.Geodetic()) end diff --git a/src/train.jl b/src/train.jl index 60a837a..2587e1d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -44,7 +44,7 @@ function train(data::AD, function predict(β::ComponentVector; T=data.times) if train_initial_condition - _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), + _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforced the norm=1 condition again here tspan=(min(T[1], params.tmin), max(T[end], params.tmax)), p = β.θ) else @@ -60,12 +60,13 @@ function train(data::AD, function loss(β::ComponentVector) u_ = predict(β) # Empirical error - # l_emp = mean(abs2, u_ .- data.directions) if isnothing(data.kappas) + l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) # The 3 is needed since the mean is computen on a 3xN matrix - l_emp = 1 - 3.0 * mean(u_ .* data.directions) + # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else - l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) + # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) + l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1) end # Regularization l_reg = 0.0 @@ -112,9 +113,12 @@ function train(data::AD, losses = Float64[] callback = function (p, l) push!(losses, l) - if length(losses) % 100 == 0 + if length(losses) % 10 == 0 println("Current loss after $(length(losses)) iterations: $(losses[end])") end + if train_initial_condition + p.u0 ./= norm(p.u0) + end return false end diff --git a/src/utils.jl b/src/utils.jl index c8f2eed..4ca6894 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,4 +1,4 @@ -export sigmoid_cap, relu_cap, step_cap +export sigmoid_cap, sigmoid, relu_cap, step_cap export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation @@ -41,7 +41,7 @@ Convert cartesian coordinates to spherical """ function cart2sph(X::AbstractArray{<:Number}; radians::Bool=true) @assert size(X)[1] == 3 "Input array must have three rows." - Y = mapslices(x -> [angle(x[1] + x[2]*im), asin(x[3])] , X, dims=1) + Y = mapslices(x -> [asin(x[3]), angle(x[1] + x[2]*im)] , X, dims=1) if !radians Y *= 180. 
/ π end @@ -61,7 +61,7 @@ function sph2cart(X::AbstractArray{<:Number}; radians::Bool=true) end Y = mapslices(x -> [cos(x[1])*cos(x[2]), cos(x[1])*sin(x[2]), - sin(x[1])] , X, dims=1) + sin(x[1])], X, dims=1) return Y end diff --git a/test/runtests.jl b/test/runtests.jl index f0b11a1..f33bf74 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,5 +1,6 @@ using SphereUDE using Test +using Lux include("constructors.jl") include("utils.jl") @@ -11,5 +12,7 @@ include("utils.jl") end @testset "Utils" begin - test_cart2sph() + test_coordinate() + test_complex_activation() + test_integration() end diff --git a/test/utils.jl b/test/utils.jl index e12e0e9..172ba3b 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -1,6 +1,22 @@ -function test_cart2sph() - X₀ = [1 0 0 √2; 0 1 0 √2; 0 0 1 0] - Y₀ = [0.0 90.0 0.0 45.0; 0.0 0.0 90.0 0.0] +function test_coordinate() + X₀ = [1 0 0 1/√2; 0 0 1 1/√2; 0 1 0 0] + Y₀ = [0.0 90.0 0.0 0.0; 0.0 0.0 90.0 45.0] @test all(isapprox.(Y₀, cart2sph(X₀, radians=false), atol=1e-6)) + @test all(isapprox.(Y₀ * π / 180., cart2sph(X₀, radians=true), atol=1e-6)) + @test all(isapprox.(X₀, sph2cart(Y₀, radians=false), atol=1e-6)) + @test all(isapprox.(X₀, sph2cart(Y₀ * π / 180., radians=true), atol=1e-6)) +end + +function test_complex_activation() + pure_real = 1.0 + pure_complex = 1.0 + 1.0 * im + @test isapprox(Lux.sigmoid(pure_real), 0.7310585786300049, atol=1e-6) + @test isapprox(imag(SphereUDE.sigmoid(pure_complex)), 0.2019482276580129, atol=1e-6) +end + +function test_integration() + @test isapprox(quadrature(x->1, 0.0, 1.0, 100), 1.0, rtol=1e-6) + @test isapprox(quadrature(x->x, -1.0, 1.0, 100), 0.0, atol=1e-6) + @test isapprox(quadrature(x->x^2, -1.0, 1.0, 100), 2/3., rtol=1e-6) end \ No newline at end of file From c69a0596f25ae3a3089b4d0ee42910d06e303c88 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Sun, 28 Apr 2024 19:34:32 -0700 Subject: [PATCH 08/29] Testing activation functions with complex-step --- Project.toml | 1 + examples/double_rotation/double_rotation.jl | 15 ++++---- src/train.jl | 30 +++++++++++---- src/utils.jl | 42 +++++++++++++++++---- 4 files changed, 66 insertions(+), 22 deletions(-) diff --git a/Project.toml b/Project.toml index 968e8a7..e59eec6 100644 --- a/Project.toml +++ b/Project.toml @@ -8,6 +8,7 @@ BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b" ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66" DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" +DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 04ddc3f..f927deb 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -56,7 +56,7 @@ true_sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol, saveat=times_samp # Add Fisher noise to true solution X_noiseless = Array(true_sol) -X_true = X_noiseless + FisherNoise(kappa=200.) +X_true = X_noiseless + FisherNoise(kappa=50.) ############################################################## ####################### Training ########################### @@ -64,19 +64,18 @@ X_true = X_noiseless + FisherNoise(kappa=200.) 
data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true) -# regs = [Regularization(order=1, power=1.0, λ=0.001, diff_mode="CS"), -# Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD")] -regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing), - Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")] +regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"), + Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)] +# regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)] + # Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")] # regs = nothing params = SphereParameters(tmin=tspan[1], tmax=tspan[2], reg=regs, u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol, - niter_ADAM=1000, niter_LBFGS=500) + niter_ADAM=1000, niter_LBFGS=800) -# results = train(data, params, rng, nothing) -results = train(data, params, rng, nothing; train_initial_condition=true) +results = train(data, params, rng, nothing; train_initial_condition=false) ############################################################## ###################### PyCall Plots ######################### diff --git a/src/train.jl b/src/train.jl index 2587e1d..aec7ee7 100644 --- a/src/train.jl +++ b/src/train.jl @@ -4,18 +4,29 @@ export train function get_NN(params, rng, θ_trained) # Define neural network U = Lux.Chain( - Lux.Dense(1, 5, relu_cap), # explore discontinuity function for activation - Lux.Dense(5, 10, relu_cap), - Lux.Dense(10, 5, relu_cap), - # Lux.Dense(1, 5, sigmoid), - # Lux.Dense(5, 10, sigmoid), - # Lux.Dense(10, 5, sigmoid), + Lux.Dense(1, 5, sigmoid), # explore discontinuity function for activation + Lux.Dense(5, 10, sigmoid), + Lux.Dense(10, 5, sigmoid), + # Lux.Dense(1, 5, relu_cap), + # Lux.Dense(5, 10, relu_cap), + # Lux.Dense(10, 5, relu_cap), Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax)) ) θ, st = Lux.setup(rng, U) return U, θ, st end +""" + predict_L(t, NN, θ, st) + +Predict value of rotation given by L given by the neural network. + +% To do: replace all the calls in U by predict_L +""" +function predict_L(t, NN, θ, st) + return NN([t], θ, st)[1] +end + function train(data::AD, params::AP, rng, @@ -37,6 +48,7 @@ function train(data::AD, function ude_rotation!(du, u, p, t) # Angular momentum given by network prediction L = U([t], p, st)[1] + # L = predict_L(t, U, p, st) du .= cross(L, u) end @@ -65,8 +77,8 @@ function train(data::AD, # The 3 is needed since the mean is computen on a 3xN matrix # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else - # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1) + # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) end # Regularization l_reg = 0.0 @@ -89,6 +101,10 @@ function train(data::AD, throw("Method not working well.") # Compute gradient using automatic differentiaion in the NN # This currently doesn't run... too slow. 
+ + # Test this with the new implementation in Lux.jl: + # https://lux.csail.mit.edu/stable/manual/nested_autodiff + for t in times_reg # Try ReverseDiff grad = Zygote.jacobian(first ∘ U, [t], θ, st)[1] diff --git a/src/utils.jl b/src/utils.jl index 4ca6894..b8cb4f5 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,22 +1,38 @@ -export sigmoid_cap, sigmoid, relu_cap, step_cap +export sigmoid, sigmoid_cap +export relu, relu_cap export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation export raise_warnings -# Normalization of the NN. Ideally we want to do this with L2 norm . """ sigmoid_cap(x; ω₀=1.0) + +Normalization of the neural network last layer """ function sigmoid_cap(x; ω₀=1.0) min_value = - ω₀ max_value = + ω₀ - return min_value + (max_value - min_value) / ( 1.0 + exp(-x) ) + return min_value + (max_value - min_value) * sigmoid(x) +end + +function sigmoid(x) + return 1.0 / (1.0 + exp(-x)) +# if x > 0 +# return 1 / ( 1.0 + exp(-x) ) +# else +# return exp(x) / (1.0 + exp(x)) +# end end -function sigmoid(x::Complex) - return 1 / ( 1.0 + exp(-x) ) +function sigmoid(z::Complex) + return 1.0 / ( 1.0 + exp(-z) ) + # if real(z) > 0 + # return 1 / ( 1.0 + exp(-z) ) + # else + # return exp(z) / (1.0 + exp(z)) + # end end """ @@ -28,10 +44,22 @@ function relu_cap(x; ω₀=1.0) return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) end -function relu_cap(x::Complex; ω₀=1.0) + +""" + relu(x::Complex) + +Extension of ReLU function to complex numbers based on the complex cardioid introduced in +Virtue et al. (2017), "Better than Real: Complex-valued Neural Nets for MRI Fingerprinting". +This function is equivalent to relu when x is real (and hence angle(x)=0 or angle(x)=π). +""" +function relu(z::Complex) + return 0.5 * (1 + cos(angle(z))) * z +end + +function relu_cap(z::Complex; ω₀=1.0) min_value = - ω₀ max_value = + ω₀ - return min_value + (max_value - min_value) * max(0.0, min(real(x), 1.0)) + (max_value - min_value) * max(0.0, min(imag(x), 1.0)) * im + return min_value + (max_value - min_value) * relu(z - relu(z-1)) end """ From 3e3016070ba30248e6061b45aaf0d835788d390a Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Mon, 29 Apr 2024 19:39:19 -0700 Subject: [PATCH 09/29] predict function, return multiple losses --- Project.toml | 1 + examples/double_rotation/double_rotation.jl | 16 +++-- src/SphereUDE.jl | 1 + src/train.jl | 67 +++++++++++++-------- src/utils.jl | 14 ++--- 5 files changed, 63 insertions(+), 36 deletions(-) diff --git a/Project.toml b/Project.toml index e59eec6..2edae2f 100644 --- a/Project.toml +++ b/Project.toml @@ -18,6 +18,7 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index f927deb..00bca5a 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -17,10 +17,12 @@ using Random rng = Random.default_rng() Random.seed!(rng, 666) +function run() + # Total time simulation tspan = [0, 160.0] # Number of sample points -N_samples = 50 +N_samples = 100 # Times where 
we sample points times_samples = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N_samples)) @@ -35,8 +37,8 @@ L0 = ω₀ .* [1.0, 0.0, 0.0] L1 = 0.5ω₀ .* [0.0, 1/sqrt(2), 1/sqrt(2)] # Solver tolerances -reltol = 1e-7 -abstol = 1e-7 +reltol = 1e-12 +abstol = 1e-12 function L_true(t::Float64; τ₀=τ₀, p=[L0, L1]) if t < τ₀ @@ -73,7 +75,8 @@ regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"), params = SphereParameters(tmin=tspan[1], tmax=tspan[2], reg=regs, u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol, - niter_ADAM=1000, niter_LBFGS=800) + niter_ADAM=1000, niter_LBFGS=600, + sensealg=GaussAdjoint(autojacvec=ReverseDiffVJP(true))) results = train(data, params, rng, nothing; train_initial_condition=false) @@ -82,4 +85,7 @@ results = train(data, params, rng, nothing; train_initial_condition=false) ############################################################## plot_sphere(data, results, -20., 150., saveas="examples/double_rotation/plot_sphere.pdf", title="Double rotation") -plot_L(data, results, saveas="examples/double_rotation/plot_L.pdf", title="Double rotation") \ No newline at end of file +plot_L(data, results, saveas="examples/double_rotation/plot_L.pdf", title="Double rotation") + +end +run() \ No newline at end of file diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index ebce6ae..746d615 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -13,6 +13,7 @@ using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL using ComponentArrays: ComponentVector using PyPlot, PyCall +using PrettyTables # Testing double-differentiation # using BatchedRoutines diff --git a/src/train.jl b/src/train.jl index aec7ee7..ba946c2 100644 --- a/src/train.jl +++ b/src/train.jl @@ -20,13 +20,16 @@ end predict_L(t, NN, θ, st) Predict value of rotation given by L given by the neural network. - -% To do: replace all the calls in U by predict_L """ function predict_L(t, NN, θ, st) return NN([t], θ, st)[1] end +""" + train() + +Training function. +""" function train(data::AD, params::AP, rng, @@ -47,8 +50,7 @@ function train(data::AD, function ude_rotation!(du, u, p, t) # Angular momentum given by network prediction - L = U([t], p, st)[1] - # L = predict_L(t, U, p, st) + L = predict_L(t, U, p, st) du .= cross(L, u) end @@ -64,13 +66,24 @@ function train(data::AD, tspan=(min(T[1], params.tmin), max(T[end], params.tmax)), p = β.θ) end - Array(solve(_prob, params.solver, saveat=T, + sol = solve(_prob, params.solver, saveat=T, abstol=params.abstol, reltol=params.reltol, - sensealg=params.sensealg)) + sensealg=params.sensealg) + return Array(sol), sol.retcode end function loss(β::ComponentVector) - u_ = predict(β) + u_, retcode = predict(β) + + # If numerical integration fails or bad choice of parameter, return infinity + if retcode != :Success + @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter." + return Inf + end + + # Record the value of each individual loss to the total loss function for hyperparameter selection. 
+ loss_dict = Dict() + # Empirical error if isnothing(data.kappas) l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) @@ -80,22 +93,26 @@ function train(data::AD, l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1) # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) end + loss_dict["Empirical"] = l_emp + # Regularization l_reg = 0.0 if !isnothing(params.reg) # for (order, power, λ) in params.reg for reg in params.reg - l_reg += regularization(β.θ, reg) + reg₀ = regularization(β.θ, reg) + l_reg += reg₀ + loss_dict["Regularization (order=$(reg.order), power=$(reg.power)"] = reg₀ end end - return l_emp + l_reg + return l_emp + l_reg, loss_dict end function regularization(θ::ComponentVector, reg::AbstractRegularization; n_nodes=100) l_ = 0.0 if reg.order==0 - l_ += quadrature(t -> norm(U([t], θ, st)[1])^reg.power, params.tmin, params.tmax, n_nodes) + l_ += quadrature(t -> norm(predict_L(t, U, θ, st))^reg.power, params.tmin, params.tmax, n_nodes) elseif reg.order==1 if reg.diff_mode=="AD" throw("Method not working well.") @@ -104,19 +121,13 @@ function train(data::AD, # Test this with the new implementation in Lux.jl: # https://lux.csail.mit.edu/stable/manual/nested_autodiff - - for t in times_reg - # Try ReverseDiff - grad = Zygote.jacobian(first ∘ U, [t], θ, st)[1] - l_ += norm(grad)^reg.power - end elseif reg.diff_mode=="FD" # Finite differences ϵ = 0.1 * (params.tmax - params.tmin) / n_nodes - l_ += quadrature(t -> norm(central_fdm(τ -> (first ∘ U)([τ], θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) elseif reg.diff_mode=="CS" # Complex step differentiation - l_ += quadrature(t -> norm(complex_step_differentiation(τ -> (first ∘ U)([τ], θ, st), t))^reg.power, params.tmin, params.tmax, n_nodes) + l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t))^reg.power, params.tmin, params.tmax, n_nodes) else throw("Method not implemented.") end @@ -129,7 +140,7 @@ function train(data::AD, losses = Float64[] callback = function (p, l) push!(losses, l) - if length(losses) % 10 == 0 + if length(losses) % 50 == 0 println("Current loss after $(length(losses)) iterations: $(losses[end])") end if train_initial_condition @@ -139,15 +150,19 @@ function train(data::AD, end adtype = Optimization.AutoZygote() - optf = Optimization.OptimizationFunction((x, β) -> loss(x), adtype) + optf = Optimization.OptimizationFunction((x, β) -> (first ∘ loss)(x), adtype) optprob = Optimization.OptimizationProblem(optf, β) res1 = Optimization.solve(optprob, ADAM(0.002), callback=callback, maxiters=params.niter_ADAM) println("Training loss after $(length(losses)) iterations: $(losses[end])") - optprob2 = Optimization.OptimizationProblem(optf, res1.u) - res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) - println("Final training loss after $(length(losses)) iterations: $(losses[end])") + if params.niter_LBFGS > 0 + optprob2 = Optimization.OptimizationProblem(optf, res1.u) + res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) + println("Final training loss after $(length(losses)) iterations: $(losses[end])") + else + res2 = res1 + end # Optimized NN parameters β_trained = res2.u @@ -162,7 +177,11 @@ function train(data::AD, # Final Fit fit_times = collect(range(params.tmin,params.tmax, length=200)) - fit_directions = 
predict(β_trained, T=fit_times) + fit_directions, _ = predict(β_trained, T=fit_times) + + # Recover final balance between different terms involved in the loss function to assess hyperparameter selection. + _, loss_dict = loss(β_trained) + pretty_table(loss_dict, sortkeys=true, header=["Loss term", "Value"]) return Results(θ=θ_trained, u0=u0_trained, U=U, st=st, fit_times=fit_times, fit_directions=fit_directions) diff --git a/src/utils.jl b/src/utils.jl index b8cb4f5..b7dfb54 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -166,12 +166,12 @@ function raise_warnings(data::SphereData, params::SphereParameters) if length(unique(data.times)) < length(data.times) @warn "[SphereUDE] Timeseries includes duplicated times. \n This can produce unexpected errors." end - if !isnothing(params.reg) - for reg in params.reg - if reg.diff_mode=="CS" - @warn "[SphereUDE] Complex-step differentiation inside the loss function \n This just work for cases where the activation function of the neural network admits complex numbers \n Change predefined activation functions to be complex numbers." - end - end - end + # if !isnothing(params.reg) + # for reg in params.reg + # if reg.diff_mode=="CS" + # @warn "[SphereUDE] Complex-step differentiation inside the loss function \n This just work for cases where the activation function of the neural network admits complex numbers \n Change predefined activation functions to be complex numbers." + # end + # end + # end nothing end \ No newline at end of file From 10ae8cd995badca8889a2d20685cffbec2d90069 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Mon, 29 Apr 2024 19:43:13 -0700 Subject: [PATCH 10/29] Co-authored-by: Jordi Bolibar --- examples/double_rotation/double_rotation.jl | 1 - src/utils.jl | 5 ----- 2 files changed, 6 deletions(-) diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 6510fa9..00bca5a 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -80,7 +80,6 @@ params = SphereParameters(tmin=tspan[1], tmax=tspan[2], results = train(data, params, rng, nothing; train_initial_condition=false) - ############################################################## ###################### PyCall Plots ######################### ############################################################## diff --git a/src/utils.jl b/src/utils.jl index 4c17b43..b7dfb54 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -35,10 +35,6 @@ function sigmoid(z::Complex) # end end -function sigmoid(x::Complex) - return 1 / ( 1.0 + exp(-x) ) -end - """ relu_cap(x; ω₀=1.0) """ @@ -64,7 +60,6 @@ function relu_cap(z::Complex; ω₀=1.0) min_value = - ω₀ max_value = + ω₀ return min_value + (max_value - min_value) * relu(z - relu(z-1)) - end """ From 3d5269180a9388efb2d01a342e548a758ceba7f8 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 2 May 2024 19:44:57 -0700 Subject: [PATCH 11/29] Multiple shooting working once sensealg specified --- src/train.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/train.jl b/src/train.jl index e819b28..863d869 100644 --- a/src/train.jl +++ b/src/train.jl @@ -85,8 +85,8 @@ function train(data::AD, # Empirical error if isnothing(data.kappas) - l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) # The 3 is needed since the mean is computen on a 3xN matrix + l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else l_emp = mean(data.kappas .* abs2.(u_ .- 
data.directions), dims=1) @@ -112,13 +112,10 @@ function train(data::AD, # Empirical error l_emp = 3.0 * mean(abs2.(pred .- data)) - # The 3 is needed since the mean is computen on a 3xN matrix - # l_emp = 1 - 3.0 * mean(u_ .* data.directions) # Regularization l_reg = 0.0 if !isnothing(params.reg) - # for (order, power, λ) in params.reg for reg in params.reg reg₀ = regularization(β.θ, reg) l_reg += reg₀ @@ -129,11 +126,19 @@ function train(data::AD, end # Define parameters for Multiple Shooting - group_size = 50 + group_size = 10 continuity_term = 100 - ps = ComponentArray(θ) - pd, pax = getdata(ps), getaxes(ps) + ps = ComponentArray(θ) # are these necesary? + pd, pax = getdata(ps), getaxes(ps) + + function continuity_loss(u_pred, u_initial) + if !isapprox(norm(u_initial), 1.0, atol=1e-6) || !isapprox(norm(u_pred), 1.0, atol=1e-6) + @warn "Directions during multiple shooting are not in the sphere. Small deviations from unit norm observed:" + @show norm(u_initial), norm(u_pred) + end + return sum(abs2, u_pred - u_initial) + end function loss_multiple_shooting(β::ComponentVector) @@ -149,11 +154,11 @@ function train(data::AD, p = β.θ) end - return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, Tsit5(), - group_size; continuity_term) + return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, continuity_loss, params.solver, + group_size; continuity_term, sensealg=params.sensealg) end - function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where{AG <: AbstractRegularization} + function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where {AG <: AbstractRegularization} l_ = 0.0 if reg.order==0 From 9fa33541fb8703f74ce186294ac922380ab6cc90 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 19 Jul 2024 21:49:28 -0300 Subject: [PATCH 12/29] Double rotation example with small changes in src --- examples/double_rotation/double_rotation.jl | 73 +++++++++++++-------- src/SphereUDE.jl | 1 + src/plot.jl | 14 +++- src/setup/config.jl | 5 +- src/train.jl | 39 +++++++---- src/types.jl | 6 +- src/utils.jl | 25 ++++++- 7 files changed, 117 insertions(+), 46 deletions(-) diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 27682f1..5587565 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -17,7 +17,7 @@ using Random rng = Random.default_rng() Random.seed!(rng, 666) -function run() +function run(;kappa=50., regs=regs, title="plot") # Total time simulation tspan = [0, 160.0] @@ -31,10 +31,10 @@ times_samples = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N_samples)) # Angular velocity ω₀ = Δω₀ * π / 180.0 # Change point -τ₀ = 65.0 +τ₀ = 70.0 # Angular momentum L0 = ω₀ .* [1.0, 0.0, 0.0] -L1 = 0.5ω₀ .* [0.0, 1/sqrt(2), 1/sqrt(2)] +L1 = 0.6ω₀ .* [0.0, 1/sqrt(2), 1/sqrt(2)] # Solver tolerances reltol = 1e-12 @@ -49,37 +49,30 @@ function L_true(t::Float64; τ₀=τ₀, p=[L0, L1]) end function true_rotation!(du, u, p, t) - L = L_true(t; τ₀=τ₀, p=p) + L = L_true(t; τ₀ = τ₀, p = p) du .= cross(L, u) end prob = ODEProblem(true_rotation!, [0.0, 0.0, -1.0], tspan, [L0, L1]) -true_sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol, saveat=times_samples) +true_sol = solve(prob, Tsit5(), reltol = reltol, abstol = abstol, saveat = times_samples) # Add Fisher noise to true solution X_noiseless = Array(true_sol) -X_true = X_noiseless + FisherNoise(kappa=50.) 
+X_true = X_noiseless + FisherNoise(kappa=kappa) ############################################################## ####################### Training ########################### ############################################################## -data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true) +data = SphereData(times=times_samples, directions=X_true, kappas=nothing, L=L_true) -regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)] -# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode="CS"), -# Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)] -# regs = [Regularization(order=0, power=2.0, λ=0.1, diff_mode=nothing)] - # Regularization(order=1, power=1.1, λ=0.01, diff_mode="CS")] -# regs = nothing - -params = SphereParameters(tmin=tspan[1], tmax=tspan[2], - reg=regs, - train_initial_condition=false, - multiple_shooting=true, - u0=[0.0, 0.0, -1.0], ωmax=ω₀, reltol=reltol, abstol=abstol, - niter_ADAM=1000, niter_LBFGS=600, - sensealg=GaussAdjoint(autojacvec=ReverseDiffVJP(true))) +params = SphereParameters(tmin = tspan[1], tmax = tspan[2], + reg = regs, + train_initial_condition = false, + multiple_shooting = false, + u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol, + niter_ADAM = 2000, niter_LBFGS = 1000, + sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) results = train(data, params, rng, nothing) @@ -87,8 +80,36 @@ results = train(data, params, rng, nothing) ###################### PyCall Plots ######################### ############################################################## -plot_sphere(data, results, -20., 150., saveas="examples/double_rotation/plot_sphere.pdf", title="Double rotation") -plot_L(data, results, saveas="examples/double_rotation/plot_L.pdf", title="Double rotation") - -end -run() \ No newline at end of file +plot_sphere(data, results, -20., 125., saveas="examples/double_rotation/" * title * "_sphere.pdf", title="Double rotation") # , matplotlib_rcParams=Dict("font.size"=> 50)) +plot_L(data, results, saveas="examples/double_rotation/" * title * "_L.pdf", title="Double rotation") + +end # run + +# Run different experiments + +λ₀ = 0.1 +λ₁ = 0.1 +# λ₀ = 0.1 +# λ₁ = 0.01 +run(; kappa = 50., + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)], + title = "plot_50_lambda$(λ₁)") + +λ₀ = 0.1 +λ₁ = 0.1 +# λ₀ = 0.1 +# λ₁ = 0.01 +run(; kappa = 200., + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + Regularization(order=0, power=2.0, λ=λ₀)], + title = "plot_200_lambda$(λ₁)") + +λ₀ = 0.1 +λ₁ = 0.1 +# λ₀ = 0.1 +# λ₁ = 0.01 +run(; kappa = 1000., + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + Regularization(order=0, power=2.0, λ=λ₀)], + title = "plot_1000_lambda$(λ₁)") \ No newline at end of file diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index be87a6b..7a49453 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -28,6 +28,7 @@ include("train.jl") include("plot.jl") # Python libraries +const mpl_base::PyObject = isdefined(SphereUDE, :mpl_base) ? SphereUDE.mpl_base : PyNULL() const mpl_colors::PyObject = isdefined(SphereUDE, :mpl_colors) ? SphereUDE.mpl_colors : PyNULL() const mpl_colormap::PyObject = isdefined(SphereUDE, :mpl_colormap) ? SphereUDE.mpl_colormap : PyNULL() const sns::PyObject = isdefined(SphereUDE, :sns) ? 
SphereUDE.sns : PyNULL() diff --git a/src/plot.jl b/src/plot.jl index 7deb758..c33df4e 100644 --- a/src/plot.jl +++ b/src/plot.jl @@ -23,12 +23,22 @@ function plot_sphere(# ax::PyCall.PyObject, central_latitude::Float64, central_longitude::Float64; saveas::Union{String, Nothing}, - title::String) + title::String, + matplotlib_rcParams::Union{Dict, Nothing} = nothing) # cmap = mpl_colormap.get_cmap("viridis") plt.figure(figsize=(10,10)) ax = plt.axes(projection=ccrs.Orthographic(central_latitude=central_latitude, central_longitude=central_longitude)) + + # Set default plot parameters. + # See https://matplotlib.org/stable/users/explain/customizing.html for customizable optionsz + if !isnothing(matplotlib_rcParams) + for (key, value) in matplotlib_rcParams + @warn "Setting Matplotlib parameters with rcParams currently not working. See following GitHub issue: https://github.com/JuliaPy/PyPlot.jl/issues/525" + mpl_base.rcParams[key] = value + end + end # ax.coastlines() ax.gridlines() @@ -51,7 +61,7 @@ function plot_sphere(# ax::PyCall.PyObject, linewidth=2, color="black",#cmap(norm(results.fit_times[i])), transform = ccrs.Geodetic()) end - plt.title(title) + plt.title(title, fontsize=20) if !isnothing(saveas) plt.savefig(saveas, format="pdf") end diff --git a/src/setup/config.jl b/src/setup/config.jl index 7472e96..824b52d 100644 --- a/src/setup/config.jl +++ b/src/setup/config.jl @@ -1,8 +1,9 @@ -export mpl_colormap, mpl_colormap, sns, ccrs, feature +export mpl_base, mpl_colormap, mpl_colormap, sns, ccrs, feature function __init__() try + copy!(mpl_base, pyimport("matplotlib")) copy!(mpl_colors, pyimport("matplotlib.colors")) copy!(mpl_colormap, pyimport("matplotlib.cm")) copy!(sns, pyimport("seaborn")) @@ -10,7 +11,7 @@ function __init__() copy!(feature, pyimport("cartopy.feature")) catch e @warn "It looks like you have not installed and/or activated the virtual Python environment. \n - Please follow the guidelines in: https://github.com/facusapienza21/SphereUDE.jl#readme" + Please follow the guidelines in: https://github.com/ODINN-SciML/SphereUDE.jl" @warn exception=(e, catch_backtrace()) end diff --git a/src/train.jl b/src/train.jl index 863d869..3d48eac 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,17 +1,29 @@ export train -# For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid function get_NN(params, rng, θ_trained) # Define neural network - U = Lux.Chain( - Lux.Dense(1, 5, sigmoid), # explore discontinuity function for activation - Lux.Dense(5, 10, sigmoid), - Lux.Dense(10, 5, sigmoid), - # Lux.Dense(1, 5, relu_cap), - # Lux.Dense(5, 10, relu_cap), - # Lux.Dense(10, 5, relu_cap), - Lux.Dense(5, 3, x->sigmoid_cap(x; ω₀=params.ωmax)) - ) + + # For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid + if isL1reg(params.reg) + @warn "[SphereUDE] Using ReLU activation functions for neural network due to L1 regularization." 
+ U = Lux.Chain( + Lux.Dense(1, 5, sigmoid), + Lux.Dense(5, 10, sigmoid), + Lux.Dense(10, 5, sigmoid), + # Lux.Dense(1, 5, relu_cap), + # Lux.Dense(5, 10, relu_cap), + # Lux.Dense(10, 5, relu_cap), + Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) + # Lux.Dense(5, 3, x -> relu_cap(x; ω₀=params.ωmax)) + ) + else + U = Lux.Chain( + Lux.Dense(1, 5, sigmoid), + Lux.Dense(5, 10, sigmoid), + Lux.Dense(10, 5, sigmoid), + Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) + ) + end θ, st = Lux.setup(rng, U) return U, θ, st end @@ -67,7 +79,8 @@ function train(data::AD, end sol = solve(_prob, params.solver, saveat=T, abstol=params.abstol, reltol=params.reltol, - sensealg=params.sensealg) + sensealg=params.sensealg, + dtmin=1e-4 * (params.tmax - params.tmin), force_dtmin=true) # Force minimum step in case L(t) changes drastically due to bad behaviour of neural network return Array(sol), sol.retcode end @@ -101,7 +114,7 @@ function train(data::AD, for reg in params.reg reg₀ = regularization(β.θ, reg) l_reg += reg₀ - loss_dict["Regularization (order=$(reg.order), power=$(reg.power)"] = reg₀ + loss_dict["Regularization (order=$(reg.order), power=$(reg.power))"] = reg₀ end end return l_emp + l_reg, loss_dict @@ -211,7 +224,7 @@ function train(data::AD, if params.niter_LBFGS > 0 optprob2 = Optimization.OptimizationProblem(optf, res1.u) - res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) + res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) println("Final training loss after $(length(losses)) iterations: $(losses[end])") else res2 = res1 diff --git a/src/types.jl b/src/types.jl index 8e3783a..315720e 100644 --- a/src/types.jl +++ b/src/types.jl @@ -56,5 +56,9 @@ Regularization information order::I # Order of derivative power::F # Power of the Euclidean norm λ::F # Regularization hyperparameter - diff_mode::Union{Nothing, String} # AD differentiation mode used in regulatization + # AD differentiation mode used in regulatization + diff_mode::Union{Nothing, String} = nothing + + # Include this in the constructor + # @assert (order == 0) || (!isnothing(diff_mode)) "Diffentiation methods needs to be provided for regularization with order larger than zero." end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index ac049a2..c632e19 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -4,7 +4,7 @@ export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation export raise_warnings - +export isL1reg """ sigmoid_cap(x; ω₀=1.0) @@ -41,9 +41,12 @@ end function relu_cap(x; ω₀=1.0) min_value = - ω₀ max_value = + ω₀ - return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) + return relu_cap(x, min_value, max_value) end +function relu_cap(x, min_value::Float64, max_value::Float64) + return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) +end """ relu(x::Complex) @@ -59,6 +62,11 @@ end function relu_cap(z::Complex; ω₀=1.0) min_value = - ω₀ max_value = + ω₀ + return relu_cap(z, min_value, max_value) + # return min_value + (max_value - min_value) * relu(z - relu(z-1)) +end + +function relu_cap(z::Complex, min_value::Float64, max_value::Float64) return min_value + (max_value - min_value) * relu(z - relu(z-1)) end @@ -174,4 +182,17 @@ function raise_warnings(data::SphereData, params::SphereParameters) # end # end nothing +end + +""" + +Function to check for the presence of L1 regularization in the loss function. 
+""" +function isL1reg(regs::Vector{R}) where {R <: AbstractRegularization} + for reg in regs + if reg.power == 1 + return true + end + end + return false end \ No newline at end of file From e2116c39128915ffe3a4c31755a7b98ebcbc2989 Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 21 Sep 2024 15:01:51 -0400 Subject: [PATCH 13/29] feat: update to support Lux 1.0 --- Project.toml | 13 ++++++++----- src/SphereUDE.jl | 4 +--- src/types.jl | 4 ++-- 3 files changed, 11 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 3db8f61..6241b8a 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SphereUDE" uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d" authors = ["Facundo Sapienza ", "Jordi Bolibar "] -version = "0.1.1" +version = "0.1.2" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" @@ -19,32 +19,35 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BenchmarkTools = "1" ComponentArrays = "0.15" +DiffEqFlux = "4" Distributions = "0.25" Infiltrator = "1.2" +Lux = "1" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" -OrdinaryDiffEq = "5, 6" +OrdinaryDiffEqCore = "1.6.0" +OrdinaryDiffEqTsit5 = "1.1.0" PyCall = "1.9" PyPlot = "2.11" Revise = "3.1" SciMLSensitivity = "7.20" Statistics = "1" Zygote = "0.6" -julia = "1.7" +julia = "1.10" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index be87a6b..d341fc7 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -1,15 +1,13 @@ __precompile__() module SphereUDE -# types -using Base: @kwdef # utils # training using LinearAlgebra, Statistics, Distributions using FastGaussQuadrature using Lux, Zygote, DiffEqFlux using ChainRules: @ignore_derivatives -using OrdinaryDiffEq +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms using ComponentArrays diff --git a/src/types.jl b/src/types.jl index 8e3783a..fe43787 100644 --- a/src/types.jl +++ b/src/types.jl @@ -23,7 +23,7 @@ Training parameters niter_LBFGS::I reltol::F abstol::F - solver::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5() + solver::OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm = Tsit5() sensealg::SciMLBase.AbstractAdjointSensitivityAlgorithm = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)) end @@ -43,7 +43,7 @@ Final results @kwdef struct Results{F <: AbstractFloat} <: AbstractResult θ::ComponentVector u0::Vector{F} - U::Lux.Chain + U::Lux.AbstractLuxLayer st::NamedTuple fit_times::Vector{F} fit_directions::Matrix{F} From f4eee7a53bc6495e17e5e48718531968a1d35c9b Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 26 Sep 2024 15:05:05 -0700 Subject: 
[PATCH 14/29] Example with double rotation working with non-updated Lux --- Project.toml | 3 ++- README.md | 8 ++++++-- examples/double_rotation/double_rotation.jl | 19 ++++++++----------- 3 files changed, 16 insertions(+), 14 deletions(-) diff --git a/Project.toml b/Project.toml index 3db8f61..48bc457 100644 --- a/Project.toml +++ b/Project.toml @@ -34,6 +34,7 @@ BenchmarkTools = "1" ComponentArrays = "0.15" Distributions = "0.25" Infiltrator = "1.2" +Lux = "<0.5.49" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" @@ -50,4 +51,4 @@ julia = "1.7" Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] -test = ["Test"] +test = ["Test"] \ No newline at end of file diff --git a/README.md b/README.md index 9ee6936..481ae68 100644 --- a/README.md +++ b/README.md @@ -64,7 +64,11 @@ To make plots using Matplotlib, Cartopy, and PMagPy, we install both [PyCall.jl] - Create a Python conda environment, based on [this conda environment file](https://raw.githubusercontent.com/facusapienza21/SphereUDE.jl/main/environment.yml), with all the required packages using `conda env create -f environment.yml`. - Inside the Julia REPL, install both `PyCall.jl` and `PyPlot.jl` with `] add PyCall, Pyplot`. -- Specify the Python path of the new environment with `ENV["PYTHON"] = ...`, where you should complete the path of the Python installation that shows when you do `conda activate SphereUDE`, `which python`. -- Inside the Julia REPL, execute `Pkg.build("PyCall")` to re-build PyCall with the new Python path. +- Specify the Python path of the new environment with `ENV["PYTHON"] = ...`, where you should complete the path of the Python installation that shows when you do `conda activate SphereUDE`, `which python`. Inside the Julia REPL, execute `Pkg.build("PyCall")` to re-build PyCall with the new Python path: +``` +julia> ENV["PYTHON"] = read(`which python`, String)[1:end-1] # trim backspace +julia> import Pkg; Pkg.build("PyCall") +julia> exit() +``` You are ready to use Python from your Julia session! 
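A quick way to confirm that the rebuild picked up the intended interpreter (a minimal sketch, assuming the conda environment is named `SphereUDE` as in the steps above; `PyCall.python` reports the Python executable PyCall was built against, and Cartopy/seaborn are the libraries the plotting utilities import):
```
julia> using PyCall
julia> println(PyCall.python)      # should point inside the SphereUDE environment
julia> pyimport("cartopy.crs")     # sanity check that Cartopy is importable
julia> pyimport("seaborn")         # seaborn is used by the plotting utilities
```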
diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 5587565..597d53a 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -88,28 +88,25 @@ end # run # Run different experiments λ₀ = 0.1 -λ₁ = 0.1 -# λ₀ = 0.1 -# λ₁ = 0.01 +λ₁ = 0.001 + run(; kappa = 50., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="FD"), Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)], - title = "plot_50_lambda$(λ₁)") + title = "plots/plot_50_lambda$(λ₁)") λ₀ = 0.1 λ₁ = 0.1 -# λ₀ = 0.1 -# λ₁ = 0.01 + run(; kappa = 200., regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), Regularization(order=0, power=2.0, λ=λ₀)], - title = "plot_200_lambda$(λ₁)") + title = "plots/plot_200_lambda$(λ₁)") λ₀ = 0.1 λ₁ = 0.1 -# λ₀ = 0.1 -# λ₁ = 0.01 + run(; kappa = 1000., regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), Regularization(order=0, power=2.0, λ=λ₀)], - title = "plot_1000_lambda$(λ₁)") \ No newline at end of file + title = "plots/plot_1000_lambda$(λ₁)") \ No newline at end of file From 8a01ca5a16da67448bf173718d971766b0099075 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 26 Sep 2024 16:18:13 -0700 Subject: [PATCH 15/29] Integration test of inversion --- test/rotation.jl | 63 ++++++++++++++++++++++++++++++++++++++++++++++++ test/runtests.jl | 6 ++++- 2 files changed, 68 insertions(+), 1 deletion(-) create mode 100644 test/rotation.jl diff --git a/test/rotation.jl b/test/rotation.jl new file mode 100644 index 0000000..5e78b1c --- /dev/null +++ b/test/rotation.jl @@ -0,0 +1,63 @@ +using LinearAlgebra, Statistics, Distributions +using OrdinaryDiffEq +using SciMLSensitivity +using Optimization, OptimizationOptimisers, OptimizationOptimJL + +using Random +rng = Random.default_rng() +Random.seed!(rng, 666) + +############################################################## +############### Simulation of Simple Rotation ############### +############################################################## + +function test_single_rotation() + + # Total time simulation + tspan = [0, 160.0] + # Number of sample points + N_samples = 10 + # Times where we sample points + times_samples = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N_samples)) + + # Expected maximum angular deviation in one unit of time (degrees) + Δω₀ = 1.0 + # Angular velocity + ω₀ = Δω₀ * π / 180.0 + + # Create simple example + X = zeros(3, N_samples) + X[3, :] .= 1 + X[1, :] = LinRange(0,1,N_samples) + X = X ./ norm.(eachcol(X))' + + ############################################################## + ####################### Training ########################### + ############################################################## + + data = SphereData(times=times_samples, directions=X, kappas=nothing, L=nothing) + + regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="FD"), + Regularization(order=0, power=2.0, λ=0.001, diff_mode=nothing)] + + params = SphereParameters(tmin = tspan[1], tmax = tspan[2], + reg = regs, + train_initial_condition = false, + multiple_shooting = false, + u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = 1e-12, abstol = 1e-12, + niter_ADAM = 20, niter_LBFGS = 10, + sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) + + results = train(data, params, rng, nothing) + + @test true + +end + +############################################################## +###################### PyCall Plots 
######################### +############################################################## + +# plot_sphere(data, results, -20., 125., saveas="examples/double_rotation/" * title * "_sphere.pdf", title="Double rotation") # , matplotlib_rcParams=Dict("font.size"=> 50)) +# plot_L(data, results, saveas="examples/double_rotation/" * title * "_L.pdf", title="Double rotation") + diff --git a/test/runtests.jl b/test/runtests.jl index f33bf74..db621eb 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -4,7 +4,7 @@ using Lux include("constructors.jl") include("utils.jl") - +include("rotation.jl") @testset "Constructors" begin test_reg_constructor() @@ -16,3 +16,7 @@ end test_complex_activation() test_integration() end + +@testset "Inversion" begin + test_single_rotation() +end \ No newline at end of file From 8ee17ed3454c8139d240a7b6640b6ea32a682dd2 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Thu, 26 Sep 2024 16:37:07 -0700 Subject: [PATCH 16/29] Added Random as test dependency --- Project.toml | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 48bc457..fe0c86b 100644 --- a/Project.toml +++ b/Project.toml @@ -49,6 +49,7 @@ julia = "1.7" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" +Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" [targets] -test = ["Test"] \ No newline at end of file +test = ["Test", "Random"] From aad95b41ebe07bf01cda2e001de59aab8d3e0804 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 27 Sep 2024 14:15:20 -0700 Subject: [PATCH 17/29] Fix Lux version Co-authored-by: Avik Pal --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index ddc6c1f..4b3f55d 100644 --- a/Project.toml +++ b/Project.toml @@ -35,7 +35,7 @@ ComponentArrays = "0.15" DiffEqFlux = "4" Distributions = "0.25" Infiltrator = "1.2" -Lux = "1" +Lux = "1.0" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" From 9711584e39d6e8dc993acecef6c6402b18edebe1 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 27 Sep 2024 14:17:40 -0700 Subject: [PATCH 18/29] [WIP] Working around `Lux=1` (#108) * Update Project with new dependencies * remove FiniteDifferences from dependencies * Example of APWP fit based on Jupp1987 * add complex-step method * Double differentiation working with complex-step method * Initial condition u0 fitting implemented * Projected gradient descent working for u0. 
Some more tests * Testing activation functions with complex-step * predict function, return multiple losses * Co-authored-by: Jordi Bolibar * Multiple shooting working once sensealg specified * Double rotation example with small changes in src * feat: update to support Lux 1.0 * Example with double rotation working with non-updated Lux * Integration test of inversion * Added Random as test dependency * Fix Lux version Co-authored-by: Avik Pal --------- Co-authored-by: Avik Pal --- Project.toml | 14 ++++++++------ src/SphereUDE.jl | 4 +--- src/types.jl | 4 ++-- 3 files changed, 11 insertions(+), 11 deletions(-) diff --git a/Project.toml b/Project.toml index fe0c86b..4b3f55d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "SphereUDE" uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d" authors = ["Facundo Sapienza ", "Jordi Bolibar "] -version = "0.1.1" +version = "0.1.2" [deps] BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" @@ -19,33 +19,35 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" -OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed" +OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" +OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1" Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" -TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77" Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" [compat] BenchmarkTools = "1" ComponentArrays = "0.15" +DiffEqFlux = "4" Distributions = "0.25" Infiltrator = "1.2" -Lux = "<0.5.49" +Lux = "1.0" Optimization = "3.12" OptimizationOptimJL = "0.1.5" OptimizationOptimisers = "0.1.2" -OrdinaryDiffEq = "5, 6" +OrdinaryDiffEqCore = "1.6.0" +OrdinaryDiffEqTsit5 = "1.1.0" PyCall = "1.9" PyPlot = "2.11" Revise = "3.1" SciMLSensitivity = "7.20" Statistics = "1" Zygote = "0.6" -julia = "1.7" +julia = "1.10" [extras] Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index 7a49453..d534790 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -1,15 +1,13 @@ __precompile__() module SphereUDE -# types -using Base: @kwdef # utils # training using LinearAlgebra, Statistics, Distributions using FastGaussQuadrature using Lux, Zygote, DiffEqFlux using ChainRules: @ignore_derivatives -using OrdinaryDiffEq +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms using ComponentArrays diff --git a/src/types.jl b/src/types.jl index 315720e..e34792f 100644 --- a/src/types.jl +++ b/src/types.jl @@ -23,7 +23,7 @@ Training parameters niter_LBFGS::I reltol::F abstol::F - solver::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5() + solver::OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm = Tsit5() sensealg::SciMLBase.AbstractAdjointSensitivityAlgorithm = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)) end @@ -43,7 +43,7 @@ Final results @kwdef struct Results{F <: AbstractFloat} <: AbstractResult θ::ComponentVector u0::Vector{F} - U::Lux.Chain + U::Lux.AbstractLuxLayer st::NamedTuple fit_times::Vector{F} fit_directions::Matrix{F} From 
d1a1fcd4bd0be31cdbeaa421e53a5115f2f60274 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 27 Sep 2024 23:18:48 +0200 Subject: [PATCH 19/29] CI on `up-lux` branch --- .github/workflows/CI.yml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index 60ecac4..d243331 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -3,10 +3,12 @@ on: push: branches: - main + - up-lux tags: ['*'] pull_request: branches: - main + - up-lux workflow_dispatch: concurrency: # Skip intermediate builds: always. From 62848d9f56d3bf452229e3518dac5b462a81d02d Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 27 Sep 2024 23:40:28 +0200 Subject: [PATCH 20/29] Update CI.yml - Update CI version --- .github/workflows/CI.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/CI.yml b/.github/workflows/CI.yml index d243331..2aefe6e 100644 --- a/.github/workflows/CI.yml +++ b/.github/workflows/CI.yml @@ -26,7 +26,7 @@ jobs: fail-fast: false matrix: version: - - '1.9' + - '1' # - 'nightly' python: - 3.9 From 003ef375034290361b6f340ff2a347428d0e9eee Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Sat, 28 Sep 2024 02:35:49 +0200 Subject: [PATCH 21/29] Remove OrdinaryDiffEq from test --- test/rotation.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rotation.jl b/test/rotation.jl index 5e78b1c..7732f82 100644 --- a/test/rotation.jl +++ b/test/rotation.jl @@ -1,5 +1,5 @@ using LinearAlgebra, Statistics, Distributions -using OrdinaryDiffEq +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL From fce57574cf504aa00f1bfc4bffeb9a5ac4ae022d Mon Sep 17 00:00:00 2001 From: Avik Pal Date: Sat, 28 Sep 2024 17:34:30 -0400 Subject: [PATCH 22/29] feat: update to support Lux 1.0 (#94) * feat: update to support Lux 1.0 * ci: up to 1.10 * fix: make the activations more Lux friendly --- src/train.jl | 6 +++--- src/utils.jl | 8 ++++++-- test/rotation.jl | 3 +-- test/runtests.jl | 2 +- 4 files changed, 11 insertions(+), 8 deletions(-) diff --git a/src/train.jl b/src/train.jl index 3d48eac..91ebf78 100644 --- a/src/train.jl +++ b/src/train.jl @@ -13,15 +13,15 @@ function get_NN(params, rng, θ_trained) # Lux.Dense(1, 5, relu_cap), # Lux.Dense(5, 10, relu_cap), # Lux.Dense(10, 5, relu_cap), - Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) - # Lux.Dense(5, 3, x -> relu_cap(x; ω₀=params.ωmax)) + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) + # Lux.Dense(5, 3, Base.Fix2(relu_cap, params.ωmax)) ) else U = Lux.Chain( Lux.Dense(1, 5, sigmoid), Lux.Dense(5, 10, sigmoid), Lux.Dense(10, 5, sigmoid), - Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax)) + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) end θ, st = Lux.setup(rng, U) diff --git a/src/utils.jl b/src/utils.jl index c632e19..320ecdc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -11,7 +11,9 @@ export isL1reg Normalization of the neural network last layer """ -function sigmoid_cap(x; ω₀=1.0) +sigmoid_cap(x; ω₀=1.0) = sigmoid_cap(x, ω₀) + +function sigmoid_cap(x, ω₀) min_value = - ω₀ max_value = + ω₀ return min_value + (max_value - min_value) * sigmoid(x) @@ -38,7 +40,9 @@ end """ relu_cap(x; ω₀=1.0) """ -function relu_cap(x; ω₀=1.0) +relu_cap(x; ω₀=1.0) = relu_cap(x, ω₀) + +function relu_cap(x, ω₀) min_value = - ω₀ max_value = + ω₀ return relu_cap(x, min_value, max_value) diff --git a/test/rotation.jl b/test/rotation.jl index 7732f82..4601eff 
100644 --- a/test/rotation.jl +++ b/test/rotation.jl @@ -1,5 +1,4 @@ -using LinearAlgebra, Statistics, Distributions -using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 +using LinearAlgebra, Statistics, Distributions using SciMLSensitivity using Optimization, OptimizationOptimisers, OptimizationOptimJL diff --git a/test/runtests.jl b/test/runtests.jl index db621eb..4bbdf6b 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -19,4 +19,4 @@ end @testset "Inversion" begin test_single_rotation() -end \ No newline at end of file +end From a48195909f3d7918219df27da21a94312f8dcab3 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Mon, 30 Sep 2024 16:08:36 -0700 Subject: [PATCH 23/29] Define abstract types for different differentiation modes --- examples/double_rotation/double_rotation.jl | 16 +++++++------- src/train.jl | 23 ++++++++++++--------- src/types.jl | 19 ++++++++++++++++- test/constructors.jl | 10 ++++----- test/rotation.jl | 2 +- 5 files changed, 46 insertions(+), 24 deletions(-) diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 597d53a..4890388 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -2,8 +2,8 @@ using Pkg; Pkg.activate(".") using Revise using LinearAlgebra, Statistics, Distributions -using OrdinaryDiffEq using SciMLSensitivity +using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using Optimization, OptimizationOptimisers, OptimizationOptimJL using SphereUDE @@ -71,7 +71,7 @@ params = SphereParameters(tmin = tspan[1], tmax = tspan[2], train_initial_condition = false, multiple_shooting = false, u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol, - niter_ADAM = 2000, niter_LBFGS = 1000, + niter_ADAM = 2000, niter_LBFGS = 2000, sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) results = train(data, params, rng, nothing) @@ -87,19 +87,21 @@ end # run # Run different experiments +ϵ = 1e-5 + λ₀ = 0.1 -λ₁ = 0.001 +λ₁ = 0.01 run(; kappa = 50., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="FD"), - Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)], + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)), + Regularization(order=0, power=2.0, λ=λ₀)], title = "plots/plot_50_lambda$(λ₁)") λ₀ = 0.1 λ₁ = 0.1 run(; kappa = 200., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)), Regularization(order=0, power=2.0, λ=λ₀)], title = "plots/plot_200_lambda$(λ₁)") @@ -107,6 +109,6 @@ run(; kappa = 200., λ₁ = 0.1 run(; kappa = 1000., - regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"), + regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)), Regularization(order=0, power=2.0, λ=λ₀)], title = "plots/plot_1000_lambda$(λ₁)") \ No newline at end of file diff --git a/src/train.jl b/src/train.jl index 91ebf78..fa2911c 100644 --- a/src/train.jl +++ b/src/train.jl @@ -175,25 +175,28 @@ function train(data::AD, l_ = 0.0 if reg.order==0 + l_ += quadrature(t -> norm(predict_L(t, U, θ, st))^reg.power, params.tmin, params.tmax, n_nodes) + elseif reg.order==1 - if reg.diff_mode=="AD" + + if typeof(reg.diff_mode) <: LuxNestedAD throw("Method not working well.") - # Compute gradient using automatic differentiaion in the NN - # This currently doesn't run... too slow. 
- - # Test this with the new implementation in Lux.jl: - # https://lux.csail.mit.edu/stable/manual/nested_autodiff - elseif reg.diff_mode=="FD" + + elseif typeof(reg.diff_mode) <: FiniteDifferences # Finite differences - ϵ = 0.1 * (params.tmax - params.tmin) / n_nodes + ϵ = reg.diff_mode.ϵ l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) - elseif reg.diff_mode=="CS" + + elseif typeof(reg.diff_mode) <: ComplexStepDifferentiation # Complex step differentiation - l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t))^reg.power, params.tmin, params.tmax, n_nodes) + ϵ = reg.diff_mode.ϵ + l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t, ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + else throw("Method not implemented.") end + else throw("Method not implemented.") end diff --git a/src/types.jl b/src/types.jl index e34792f..29d3af6 100644 --- a/src/types.jl +++ b/src/types.jl @@ -2,11 +2,13 @@ export SphereParameters, AbstractParameters export SphereData, AbstractData export Results, AbstractResult export Regularization, AbstractRegularization +export FiniteDifferences, ComplexStepDifferentiation, LuxNestedAD, AbstractDifferentiation abstract type AbstractParameters end abstract type AbstractData end abstract type AbstractRegularization end abstract type AbstractResult end +abstract type AbstractDifferentiation end """ Training parameters @@ -57,8 +59,23 @@ Regularization information power::F # Power of the Euclidean norm λ::F # Regularization hyperparameter # AD differentiation mode used in regulatization - diff_mode::Union{Nothing, String} = nothing + diff_mode::Union{Nothing, AbstractDifferentiation} = nothing # Include this in the constructor # @assert (order == 0) || (!isnothing(diff_mode)) "Diffentiation methods needs to be provided for regularization with order larger than zero." +end + +""" +Differentiation methods +""" +@kwdef struct FiniteDifferences{F <: AbstractFloat} <: AbstractDifferentiation + ϵ::F = 1e-10 +end + +@kwdef struct ComplexStepDifferentiation{F <: AbstractFloat} <: AbstractDifferentiation + ϵ::F = 1e-10 +end + +@kwdef struct LuxNestedAD <: AbstractDifferentiation + method::Union{Nothing, String} = nothing end \ No newline at end of file diff --git a/test/constructors.jl b/test/constructors.jl index 4bd2b41..d37855d 100644 --- a/test/constructors.jl +++ b/test/constructors.jl @@ -1,12 +1,12 @@ function test_reg_constructor() - reg = Regularization(order=1, power=2.0, λ=0.1, diff_mode="AD") + reg = Regularization(order=1, power=2.0, λ=0.1, diff_mode=ComplexStepDifferentiation()) @test reg.order == 1 @test reg.power == 2.0 @test reg.λ == 0.1 - @test reg.diff_mode == "AD" + @test typeof(reg.diff_mode) <: AbstractDifferentiation end @@ -21,8 +21,8 @@ function test_param_constructor() niter_ADAM=1000, niter_LBFGS=300, reltol=1e6, abstol=1e-6) - reg1 = Regularization(order=0, power=2.0, λ=0.1, diff_mode="AD") - reg2 = Regularization(order=1, power=1.0, λ=0.1, diff_mode="AD") + reg1 = Regularization(order=0, power=2.0, λ=0.1, diff_mode=FiniteDifferences()) + reg2 = Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()) params2 = SphereParameters(tmin=0.0, tmax=100., u0=[0. ,0. 
,1.], @@ -37,6 +37,6 @@ function test_param_constructor() @test params.tmax == 100.0 @test params2.reg[1].order == 0 - @test params2.reg[2].diff_mode == "AD" + @test typeof(params2.reg[2].diff_mode) <: LuxNestedAD end \ No newline at end of file diff --git a/test/rotation.jl b/test/rotation.jl index 4601eff..43b8b55 100644 --- a/test/rotation.jl +++ b/test/rotation.jl @@ -36,7 +36,7 @@ function test_single_rotation() data = SphereData(times=times_samples, directions=X, kappas=nothing, L=nothing) - regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode="FD"), + regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=FiniteDifferences(1e-6)), Regularization(order=0, power=2.0, λ=0.001, diff_mode=nothing)] params = SphereParameters(tmin = tspan[1], tmax = tspan[2], From a3a93485eb2d6b94414dfac7d8f2737955b28949 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Tue, 1 Oct 2024 13:43:05 -0700 Subject: [PATCH 24/29] Implemented regularization with Lux nested AD --- Project.toml | 3 ++- src/SphereUDE.jl | 2 +- src/train.jl | 45 +++++++++++++++++++++++++++++++++++---------- src/types.jl | 2 +- src/utils.jl | 12 ++++++++++-- 5 files changed, 49 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 4b3f55d..4181537 100644 --- a/Project.toml +++ b/Project.toml @@ -12,6 +12,7 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" @@ -50,8 +51,8 @@ Zygote = "0.6" julia = "1.10" [extras] -Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" +Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40" [targets] test = ["Test", "Random"] diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index d534790..55ff797 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -8,7 +8,7 @@ using FastGaussQuadrature using Lux, Zygote, DiffEqFlux using ChainRules: @ignore_derivatives using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 -using SciMLSensitivity +using SciMLSensitivity, ForwardDiff using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms using ComponentArrays using PyPlot, PyCall diff --git a/src/train.jl b/src/train.jl index fa2911c..d7b6f0a 100644 --- a/src/train.jl +++ b/src/train.jl @@ -9,18 +9,22 @@ function get_NN(params, rng, θ_trained) U = Lux.Chain( Lux.Dense(1, 5, sigmoid), Lux.Dense(5, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), Lux.Dense(10, 5, sigmoid), - # Lux.Dense(1, 5, relu_cap), - # Lux.Dense(5, 10, relu_cap), - # Lux.Dense(10, 5, relu_cap), + # Lux.Dense(1, 5, Lux.relu), + # Lux.Dense(5, 10, Lux.relu), + # Lux.Dense(10, 5, Lux.relu), Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) # Lux.Dense(5, 3, Base.Fix2(relu_cap, params.ωmax)) ) else U = Lux.Chain( - Lux.Dense(1, 5, sigmoid), - Lux.Dense(5, 10, sigmoid), - Lux.Dense(10, 5, sigmoid), + Lux.Dense(1, 5, gelu), + Lux.Dense(5, 10, gelu), + Lux.Dense(10, 10, gelu), + Lux.Dense(10, 10, gelu), + Lux.Dense(10, 5, gelu), Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) end @@ -51,6 +55,8 @@ function train(data::AD, raise_warnings(data::AD, params::AP) U, θ, st = get_NN(params, rng, θ_trained) + # Make it a stateful layer (this I don't know where 
is best to add it, it repeats) + smodel = StatefulLuxLayer{true}(U, θ, st) # Set component vector for Optimization if params.train_initial_condition @@ -89,7 +95,7 @@ function train(data::AD, # If numerical integration fails or bad choice of parameter, return infinity if retcode != :Success - @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter." + @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter. This can be due to just a bad initial condition of the neural network, so it is worth changing the randon number used for initialization. " return Inf end @@ -102,7 +108,7 @@ function train(data::AD, l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else - l_emp = mean(data.kappas .* abs2.(u_ .- data.directions), dims=1) + l_emp = mean(data.kappas .* sum(abs2.(u_ .- data.directions), dims=1)) # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) end loss_dict["Empirical"] = l_emp @@ -181,7 +187,24 @@ function train(data::AD, elseif reg.order==1 if typeof(reg.diff_mode) <: LuxNestedAD - throw("Method not working well.") + # Automatic Differentiation + nodes, weights = quadrature(params.tmin, params.tmax, n_nodes) + + if reg.diff_mode.method == "ForwardDiff" + # Jac = ForwardDiff.jacobian(smodel, reshape(nodes, 1, n_nodes)) + Jac = batched_jacobian(smodel, AutoForwardDiff(), reshape(nodes, 1, n_nodes)) + elseif reg.diff_mode.method == "Zygote" + # This can also be done with Zygote in reverse mode + # Jac = Zygote.jacobian(smodel, reshape(nodes, 1, n_nodes))[1] + Jac = batched_jacobian(smodel, AutoZygote(), reshape(nodes, 1, n_nodes)) + else + throw("Method for AD backend no implemented.") + end + + # Compute the final agregation to the loss + for j in 1:n_nodes + l_ += weights[j] * norm(Jac[:,1,j])^reg.power + end elseif typeof(reg.diff_mode) <: FiniteDifferences # Finite differences @@ -206,7 +229,7 @@ function train(data::AD, losses = Float64[] callback = function (p, l) push!(losses, l) - if length(losses) % 50 == 0 + if length(losses) % 20 == 0 println("Current loss after $(length(losses)) iterations: $(losses[end])") end if params.train_initial_condition @@ -215,6 +238,8 @@ function train(data::AD, return false end + println("Loss after initalization: ", loss(β)[1]) + # Dispatch the right loss function f_loss = params.multiple_shooting ? 
loss_multiple_shooting : loss diff --git a/src/types.jl b/src/types.jl index 29d3af6..adfc248 100644 --- a/src/types.jl +++ b/src/types.jl @@ -77,5 +77,5 @@ end end @kwdef struct LuxNestedAD <: AbstractDifferentiation - method::Union{Nothing, String} = nothing + method::Union{Nothing, String} = "ForwardDiff" end \ No newline at end of file diff --git a/src/utils.jl b/src/utils.jl index 320ecdc..8fc2377 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -139,13 +139,18 @@ end Numerical integral using Gaussian quadrature """ function quadrature(f::Function, t₀, t₁, n_nodes::Int) + nodes, weigths = quadrature(t₀, t₁, n_nodes) + return dot(weigths, f.(nodes)) +end + +function quadrature(t₀, t₁, n_nodes::Int) ignore() do # Ignore AD here since FastGaussQuadrature is using mutating arrays nodes, weigths = gausslegendre(n_nodes) end nodes = (t₀+t₁)/2 .+ nodes * (t₁-t₀)/2 weigths = (t₁-t₀) / 2 * weigths - return dot(weigths, f.(nodes)) + return nodes, weigths end """ @@ -192,7 +197,10 @@ end Function to check for the presence of L1 regularization in the loss function. """ -function isL1reg(regs::Vector{R}) where {R <: AbstractRegularization} +function isL1reg(regs::Union{Vector{R}, Nothing}) where {R <: AbstractRegularization} + if isnothing(regs) + return false + end for reg in regs if reg.power == 1 return true From c6d33df43cc5785b9734fc75dd0dfeff43ca75a9 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Wed, 2 Oct 2024 19:11:11 -0700 Subject: [PATCH 25/29] [WIP] Torsvik + working on complex-step with Lux --- Project.toml | 1 + examples/Torsvik_2012/APWP-Torsvik.jl | 67 +++++++++ src/SphereUDE.jl | 3 +- src/train.jl | 201 ++++++++++++++++---------- src/types.jl | 5 +- src/utils.jl | 53 ++++++- 6 files changed, 241 insertions(+), 89 deletions(-) create mode 100644 examples/Torsvik_2012/APWP-Torsvik.jl diff --git a/Project.toml b/Project.toml index 4181537..1d10b12 100644 --- a/Project.toml +++ b/Project.toml @@ -14,6 +14,7 @@ Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" +LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" diff --git a/examples/Torsvik_2012/APWP-Torsvik.jl b/examples/Torsvik_2012/APWP-Torsvik.jl new file mode 100644 index 0000000..2cf7aab --- /dev/null +++ b/examples/Torsvik_2012/APWP-Torsvik.jl @@ -0,0 +1,67 @@ +using Pkg; Pkg.activate(".") +using Revise + +using LinearAlgebra, Statistics, Distributions +using SciMLSensitivity +# using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 +using Optimization, OptimizationOptimisers, OptimizationOptimJL + +using SphereUDE + +# Random seed +using Random +rng = Random.default_rng() +Random.seed!(rng, 613) + + +using DataFrames, CSV + +df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame, delim=",") + +# Filter the plates that were once part of the supercontinent Gondwana + +Gondwana = ["Amazonia", "Parana", "Colorado", "Southern_Africa", + "East_Antarctica", "Madagascar", "Patagonia", "Northeast_Africa", + "Northwest_Africa", "Somalia", "Arabia"] + +df = filter(row -> row.Plate ∈ Gondwana, df) +df.Times = df.Age .+= rand(sampler(Normal(0,0.1)), nrow(df)) # Needs to fix this! 
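# Note (a sketch, not part of the original script): the random jitter above appears to be
# there only so that repeated ages become distinct (and, after sorting, strictly increasing)
# sample times for the solver. A deterministic, hypothetical alternative would be, e.g.,
#   df.Times = df.Age .+ 1e-3 .* (0:(nrow(df) - 1))
# which keeps the times unique without depending on the random seed.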
+ +df = sort(df, :Times) +times = df.Times + +# Fill missing values +df.RLat .= coalesce.(df.RLat, df.Lat) +df.RLon .= coalesce.(df.RLon, df.Lon) + +X = sph2cart(Matrix(df[:,["RLat","RLon"]])'; radians=false) + +# Retrieve uncertanties from poles and convert α95 into κ +kappas = (140.0 ./ df.a95).^2 + +data = SphereData(times=times, directions=X, kappas=kappas, L=nothing) + +# Training + +# Expected maximum angular deviation in one unit of time (degrees) +Δω₀ = 1.5 +# Angular velocity +ω₀ = Δω₀ * π / 180.0 + +tspan = [times[begin], times[end]] + +params = SphereParameters(tmin = tspan[1], tmax = tspan[2], + # reg = [Regularization(order=1, power=2.0, λ=1e2, diff_mode=ComplexStepDifferentiation())], + reg = nothing, + pretrain = false, + u0 = [0.0, 0.0, -1.0], ωmax = ω₀, + reltol = 1e-8, abstol = 1e-8, + niter_ADAM = 3000, niter_LBFGS = 8000, + sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) + + + +results = train(data, params, rng, nothing) + +plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation") # , matplotlib_rcParams=Dict("font.size"=> 50)) +plot_L(data, results, saveas="examples/Torsvik_2012/plots/plot_L.pdf", title="Double rotation") diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index 55ff797..43621df 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -9,7 +9,8 @@ using Lux, Zygote, DiffEqFlux using ChainRules: @ignore_derivatives using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using SciMLSensitivity, ForwardDiff -using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms +using Optimization, OptimizationOptimisers, OptimizationOptimJL +using OptimizationPolyalgorithms, LineSearches using ComponentArrays using PyPlot, PyCall using PrettyTables diff --git a/src/train.jl b/src/train.jl index d7b6f0a..d7a72d3 100644 --- a/src/train.jl +++ b/src/train.jl @@ -9,23 +9,19 @@ function get_NN(params, rng, θ_trained) U = Lux.Chain( Lux.Dense(1, 5, sigmoid), Lux.Dense(5, 10, sigmoid), - Lux.Dense(10, 10, sigmoid), - Lux.Dense(10, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), + Lux.Dense(10, 10, sigmoid), Lux.Dense(10, 5, sigmoid), - # Lux.Dense(1, 5, Lux.relu), - # Lux.Dense(5, 10, Lux.relu), - # Lux.Dense(10, 5, Lux.relu), - Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) - # Lux.Dense(5, 3, Base.Fix2(relu_cap, params.ωmax)) + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) else U = Lux.Chain( Lux.Dense(1, 5, gelu), Lux.Dense(5, 10, gelu), - Lux.Dense(10, 10, gelu), - Lux.Dense(10, 10, gelu), + Lux.Dense(10, 10, gelu), + Lux.Dense(10, 10, gelu), Lux.Dense(10, 5, gelu), - Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) + Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) end θ, st = Lux.setup(rng, U) @@ -90,33 +86,23 @@ function train(data::AD, return Array(sol), sol.retcode end - function loss(β::ComponentVector) - u_, retcode = predict(β) - - # If numerical integration fails or bad choice of parameter, return infinity - if retcode != :Success - @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter. This can be due to just a bad initial condition of the neural network, so it is worth changing the randon number used for initialization. " - return Inf - end + ##### Definition of loss functions to be used ##### + """ + General Loss Function + """ + function loss(β::ComponentVector) + # Record the value of each individual loss to the total loss function for hyperparameter selection. 
loss_dict = Dict() - - # Empirical error - if isnothing(data.kappas) - # The 3 is needed since the mean is computen on a 3xN matrix - l_emp = 3.0 * mean(abs2.(u_ .- data.directions)) - # l_emp = 1 - 3.0 * mean(u_ .* data.directions) - else - l_emp = mean(data.kappas .* sum(abs2.(u_ .- data.directions), dims=1)) - # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) - end + + l_emp = loss_empirical(β) + loss_dict["Empirical"] = l_emp # Regularization l_reg = 0.0 if !isnothing(params.reg) - # for (order, power, λ) in params.reg for reg in params.reg reg₀ = regularization(β.θ, reg) l_reg += reg₀ @@ -126,57 +112,34 @@ function train(data::AD, return l_emp + l_reg, loss_dict end - # Loss function to be called for multiple shooting - function loss_function(data, pred) - - # Empirical error - l_emp = 3.0 * mean(abs2.(pred .- data)) - - # Regularization - l_reg = 0.0 - if !isnothing(params.reg) - for reg in params.reg - reg₀ = regularization(β.θ, reg) - l_reg += reg₀ - end - end - - return l_emp + l_reg - end - - # Define parameters for Multiple Shooting - group_size = 10 - continuity_term = 100 - - ps = ComponentArray(θ) # are these necesary? - pd, pax = getdata(ps), getaxes(ps) + """ + Empirical loss function + """ + function loss_empirical(β::ComponentVector) - function continuity_loss(u_pred, u_initial) - if !isapprox(norm(u_initial), 1.0, atol=1e-6) || !isapprox(norm(u_pred), 1.0, atol=1e-6) - @warn "Directions during multiple shooting are not in the sphere. Small deviations from unit norm observed:" - @show norm(u_initial), norm(u_pred) + u_, retcode = predict(β) + + # If numerical integration fails or bad choice of parameter, return infinity + if retcode != :Success + @warn "[SphereUDE] Numerical solver not converging. This can be causes by numerical innestabilities around a bad choice of parameter. This can be due to just a bad initial condition of the neural network, so it is worth changing the randon number used for initialization. 
" + return Inf end - return sum(abs2, u_pred - u_initial) - end - function loss_multiple_shooting(β::ComponentVector) - - ps = ComponentArray(β.θ, pax) - - if params.train_initial_condition - _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforce the norm=1 condition again here - tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)), - p = β.θ) + # Empirical error + if isnothing(data.kappas) + # The 3 is needed since the mean is computen on a 3xN matrix + return 3.0 * mean(abs2.(u_ .- data.directions)) + # l_emp = 1 - 3.0 * mean(u_ .* data.directions) else - _prob = remake(prob_nn, u0=params.u0, - tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)), - p = β.θ) + return mean(data.kappas .* sum(abs2.(u_ .- data.directions), dims=1)) + # l_emp = norm(data.kappas)^2 - 3.0 * mean(data.kappas .* u_ .* data.directions) end - - return multiple_shoot(β.θ, data.directions, data.times, _prob, loss_function, continuity_loss, params.solver, - group_size; continuity_term, sensealg=params.sensealg) + end + """ + Regularization + """ function regularization(θ::ComponentVector, reg::AG; n_nodes=100) where {AG <: AbstractRegularization} l_ = 0.0 @@ -208,13 +171,11 @@ function train(data::AD, elseif typeof(reg.diff_mode) <: FiniteDifferences # Finite differences - ϵ = reg.diff_mode.ϵ - l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, ϵ=ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, reg.diff_mode.ϵ))^reg.power, params.tmin, params.tmax, n_nodes) elseif typeof(reg.diff_mode) <: ComplexStepDifferentiation # Complex step differentiation - ϵ = reg.diff_mode.ϵ - l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t, ϵ))^reg.power, params.tmin, params.tmax, n_nodes) + l_ += quadrature(t -> norm(complex_step_differentiation(τ -> predict_L(τ, U, θ, st), t, reg.diff_mode.ϵ))^reg.power, params.tmin, params.tmax, n_nodes) else throw("Method not implemented.") @@ -226,6 +187,62 @@ function train(data::AD, return reg.λ * l_ end + """ + Loss function to be called for multiple shooting + This seems duplicated from before, so be careful with this + """ + function _loss_multiple_shooting(data, pred) + + # Empirical error + l_emp = 3.0 * mean(abs2.(pred .- data)) + + # Regularization + l_reg = 0.0 + if !isnothing(params.reg) + for reg in params.reg + reg₀ = regularization(β.θ, reg) + l_reg += reg₀ + end + end + + return l_emp + l_reg + end + + # Define parameters for Multiple Shooting + group_size = 10 + continuity_term = 100 + + ps = ComponentArray(θ) # are these necesary? + pd, pax = getdata(ps), getaxes(ps) + + function continuity_loss(u_pred, u_initial) + if !isapprox(norm(u_initial), 1.0, atol=1e-6) || !isapprox(norm(u_pred), 1.0, atol=1e-6) + @warn "Directions during multiple shooting are not in the sphere. 
Small deviations from unit norm observed:" + @show norm(u_initial), norm(u_pred) + end + return sum(abs2, u_pred - u_initial) + end + + function loss_multiple_shooting(β::ComponentVector) + + ps = ComponentArray(β.θ, pax) + + if params.train_initial_condition + _prob = remake(prob_nn, u0=β.u0 / norm(β.u0), # We enforce the norm=1 condition again here + tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)), + p = β.θ) + else + _prob = remake(prob_nn, u0=params.u0, + tspan=(min(data.times[1], params.tmin), max(data.times[end], params.tmax)), + p = β.θ) + end + + return multiple_shoot(β.θ, data.directions, data.times, _prob, _loss_multiple_shooting, continuity_loss, params.solver, + group_size; continuity_term, sensealg=params.sensealg) + end + + + ### Callback function losses = Float64[] callback = function (p, l) push!(losses, l) @@ -243,16 +260,42 @@ function train(data::AD, # Dispatch the right loss function f_loss = params.multiple_shooting ? loss_multiple_shooting : loss + # Optimization setting with AD adtype = Optimization.AutoZygote() + + """ + Pretraining to find parameters without impossing regularization + """ + # if params.pretrain + # losses_pretrain = Float64[] + # callback_pretrain = function(p, l) + # push!(losses_pretrain, l) + # if length(losses_pretrain) % 100 == 0 + # println("[Pretrain with no regularization] Current loss after $(length(losses_pretrain)) iterations: $(losses_pretrain[end])") + # end + # return false + # end + # optf₀ = Optimization.OptimizationFunction((x, β) -> loss_empirical(x), adtype) + # optprob₀ = Optimization.OptimizationProblem(optf₀, β) + # res₀ = Optimization.solve(optprob₀, ADAM(), callback=callback_pretrain, maxiters=params.niter_ADAM, verbose=false) + # optprob₁ = Optimization.OptimizationProblem(optf₀, res₀.u) + # res₁ = Optimization.solve(optprob₁, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback_pretrain, maxiters=params.niter_LBFGS) + # β = res₁.u + # end + + # To do: implement this with polyoptimizaion to put ADAM and BFGS in one step. + # Maybe better to keep like this for the line search. 
+ optf = Optimization.OptimizationFunction((x, β) -> (first ∘ f_loss)(x), adtype) optprob = Optimization.OptimizationProblem(optf, β) - res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=false) + res1 = Optimization.solve(optprob, ADAM(), callback=callback, maxiters=params.niter_ADAM, verbose=true) println("Training loss after $(length(losses)) iterations: $(losses[end])") if params.niter_LBFGS > 0 optprob2 = Optimization.OptimizationProblem(optf, res1.u) - res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) + # res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) + res2 = Optimization.solve(optprob2, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback, maxiters=params.niter_LBFGS) #, reltol=1e-6) println("Final training loss after $(length(losses)) iterations: $(losses[end])") else res2 = res1 diff --git a/src/types.jl b/src/types.jl index adfc248..7cc9ce2 100644 --- a/src/types.jl +++ b/src/types.jl @@ -19,14 +19,15 @@ Training parameters u0::Union{Vector{F}, Nothing} ωmax::F reg::Union{Nothing, Array} - train_initial_condition::Bool - multiple_shooting::Bool + train_initial_condition::Bool = false + multiple_shooting::Bool = false niter_ADAM::I niter_LBFGS::I reltol::F abstol::F solver::OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm = Tsit5() sensealg::SciMLBase.AbstractAdjointSensitivityAlgorithm = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)) + pretrain::Bool = false end """ diff --git a/src/utils.jl b/src/utils.jl index 8fc2377..1aaf52c 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,11 +1,18 @@ export sigmoid, sigmoid_cap export relu, relu_cap +export gelu export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation export raise_warnings export isL1reg +# Import activation function for complex extension +import Lux: relu, gelu +# import Lux: sigmoid, relu, gelu + +### Custom Activation Funtions + """ sigmoid_cap(x; ω₀=1.0) @@ -19,13 +26,10 @@ function sigmoid_cap(x, ω₀) return min_value + (max_value - min_value) * sigmoid(x) end +# For some reason, when I import the Lux.sigmoid function this train badly, +# increasing the value of the loss function over iterations... function sigmoid(x) return 1.0 / (1.0 + exp(-x)) -# if x > 0 -# return 1 / ( 1.0 + exp(-x) ) -# else -# return exp(x) / (1.0 + exp(x)) -# end end function sigmoid(z::Complex) @@ -52,6 +56,8 @@ function relu_cap(x, min_value::Float64, max_value::Float64) return min_value + (max_value - min_value) * max(0.0, min(x, 1.0)) end +### Complex Expansion Activation Functions + """ relu(x::Complex) @@ -70,10 +76,39 @@ function relu_cap(z::Complex; ω₀=1.0) # return min_value + (max_value - min_value) * relu(z - relu(z-1)) end +""" + relu_cap(z::Complex, min_value::Float64, max_value::Float64) +""" function relu_cap(z::Complex, min_value::Float64, max_value::Float64) return min_value + (max_value - min_value) * relu(z - relu(z-1)) end +""" + sigmoid(z::Complex) +""" +function sigmoid(z::Complex) + return 1.0 / ( 1.0 + exp(-z) ) + # if real(z) > 0 + # return 1 / ( 1.0 + exp(-z) ) + # else + # return exp(z) / (1.0 + exp(z)) + # end +end + +""" + gelu(x::Complex) + +Extension of the GELU activation function for complex variables. 
+We use the approximation using tanh() to avoid dealing with the complex error function +""" +function gelu(z::Complex) + # We use the Gelu approximation to avoid complex holomorphic error function + return 0.5 * z * (1 + tanh((sqrt(2/π))*(z + 0.044715 * z^3))) +end + + +### Spherical Utils + """ cart2sph(X::AbstractArray{<:Number}; radians::Bool=true) @@ -133,6 +168,8 @@ function Base.:(+)(X::Array{F, 2}, ϵ::N) where {F <: AbstractFloat, N <: Abstra end end +### Numerics Utils + """ quadrature_integrate @@ -161,7 +198,7 @@ Simple central differences implementation. FiniteDifferences.jl does not work with AD so I implemented this manually. Still remains to test this with FiniteDiff.jl """ -function central_fdm(f::Function, x::Float64; ϵ=0.01) +function central_fdm(f::Function, x::Float64, ϵ::Float64) return (f(x+ϵ)-f(x-ϵ)) / (2ϵ) end @@ -170,10 +207,12 @@ end Manual implementation of complex-step differentiation """ -function complex_step_differentiation(f::Function, x::Float64; ϵ=1e-10) +function complex_step_differentiation(f::Function, x::Float64, ϵ::Float64) return imag(f(x + ϵ * im)) / ϵ end +### Other Utils + """ raise_warnings(data::AD, params::AP) From f4f6063f2f05f7c44d560da202c67118c91350d7 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Mon, 7 Oct 2024 17:58:22 -0700 Subject: [PATCH 26/29] Example working with Torsvik data --- examples/Torsvik_2012/APWP-Torsvik.jl | 44 +++++++++++++++----- src/train.jl | 58 +++++++++++++++------------ src/utils.jl | 4 +- 3 files changed, 70 insertions(+), 36 deletions(-) diff --git a/examples/Torsvik_2012/APWP-Torsvik.jl b/examples/Torsvik_2012/APWP-Torsvik.jl index 2cf7aab..7cb4db6 100644 --- a/examples/Torsvik_2012/APWP-Torsvik.jl +++ b/examples/Torsvik_2012/APWP-Torsvik.jl @@ -1,5 +1,6 @@ using Pkg; Pkg.activate(".") using Revise +using Lux using LinearAlgebra, Statistics, Distributions using SciMLSensitivity @@ -13,8 +14,8 @@ using Random rng = Random.default_rng() Random.seed!(rng, 613) - using DataFrames, CSV +using Serialization df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame, delim=",") @@ -51,17 +52,40 @@ data = SphereData(times=times, directions=X, kappas=kappas, L=nothing) tspan = [times[begin], times[end]] params = SphereParameters(tmin = tspan[1], tmax = tspan[2], - # reg = [Regularization(order=1, power=2.0, λ=1e2, diff_mode=ComplexStepDifferentiation())], - reg = nothing, + reg = [Regularization(order=1, power=2.0, λ=1e5, diff_mode=LuxNestedAD())], + # reg = nothing, pretrain = false, u0 = [0.0, 0.0, -1.0], ωmax = ω₀, - reltol = 1e-8, abstol = 1e-8, - niter_ADAM = 3000, niter_LBFGS = 8000, + reltol = 1e-6, abstol = 1e-6, + niter_ADAM = 2000, niter_LBFGS = 4000, sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) - - -results = train(data, params, rng, nothing) - -plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation") # , matplotlib_rcParams=Dict("font.size"=> 50)) +# Linear interpolation function +# currently not working + +train_model = true + +if train_model + + init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) + init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) + + # Customized neural network to similate weighted moving window in L + U = Lux.Chain( + Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), + Lux.Dense(200,10, gelu), + Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false) + ) + + results = train(data, params, rng, nothing, 
U) + serialize("examples/Torsvik_2012/results.dat", Dict("data" => data, + "params" => params, + "results"=>results)) +else + # Read results + res = deserialize("resuls.dat") + results = res["results"] +end + +plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation") plot_L(data, results, saveas="examples/Torsvik_2012/plots/plot_L.pdf", title="Double rotation") diff --git a/src/train.jl b/src/train.jl index d7a72d3..d02322d 100644 --- a/src/train.jl +++ b/src/train.jl @@ -1,8 +1,8 @@ export train function get_NN(params, rng, θ_trained) - # Define neural network - + # Define default neural network + # For L1 regularization relu_cap works better, but for L2 I think is better to include sigmoid if isL1reg(params.reg) @warn "[SphereUDE] Using ReLU activation functions for neural network due to L1 regularization." @@ -23,9 +23,8 @@ function get_NN(params, rng, θ_trained) Lux.Dense(10, 5, gelu), Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax)) ) - end - θ, st = Lux.setup(rng, U) - return U, θ, st + end + return U end """ @@ -35,6 +34,7 @@ Predict value of rotation given by L given by the neural network. """ function predict_L(t, NN, θ, st) return NN([t], θ, st)[1] + # return 0.2 * ( NN([t-1.0], θ, st)[1] + NN([t-0.5], θ, st)[1] + NN([t], θ, st)[1] + NN([t+0.5], θ, st)[1] + NN([t+1.0], θ, st)[1] ) end """ @@ -45,12 +45,18 @@ Training function. function train(data::AD, params::AP, rng, - θ_trained=[]) where {AD <: AbstractData, AP <: AbstractParameters} + θ_trained=[], + model::Chain=nothing) where {AD <: AbstractData, AP <: AbstractParameters} # Raise warnings raise_warnings(data::AD, params::AP) - - U, θ, st = get_NN(params, rng, θ_trained) + + if isnothing(model) + U = get_NN(params, rng, θ_trained) + else + U = model + end + θ, st = Lux.setup(rng, U) # Make it a stateful layer (this I don't know where is best to add it, it repeats) smodel = StatefulLuxLayer{true}(U, θ, st) @@ -247,7 +253,8 @@ function train(data::AD, callback = function (p, l) push!(losses, l) if length(losses) % 20 == 0 - println("Current loss after $(length(losses)) iterations: $(losses[end])") + @printf "Iteration: [%5d / %5d] \t Loss: %.9f \n" length(losses) (params.niter_ADAM+params.niter_LBFGS) losses[end] + # println("Current loss after $(length(losses)) iterations: $(losses[end])") end if params.train_initial_condition p.u0 ./= norm(p.u0) @@ -266,22 +273,23 @@ function train(data::AD, """ Pretraining to find parameters without impossing regularization """ - # if params.pretrain - # losses_pretrain = Float64[] - # callback_pretrain = function(p, l) - # push!(losses_pretrain, l) - # if length(losses_pretrain) % 100 == 0 - # println("[Pretrain with no regularization] Current loss after $(length(losses_pretrain)) iterations: $(losses_pretrain[end])") - # end - # return false - # end - # optf₀ = Optimization.OptimizationFunction((x, β) -> loss_empirical(x), adtype) - # optprob₀ = Optimization.OptimizationProblem(optf₀, β) - # res₀ = Optimization.solve(optprob₀, ADAM(), callback=callback_pretrain, maxiters=params.niter_ADAM, verbose=false) - # optprob₁ = Optimization.OptimizationProblem(optf₀, res₀.u) - # res₁ = Optimization.solve(optprob₁, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback_pretrain, maxiters=params.niter_LBFGS) - # β = res₁.u - # end + if params.pretrain + losses_pretrain = Float64[] + callback_pretrain = function(p, l) + push!(losses_pretrain, l) + if length(losses_pretrain) % 100 == 0 + @printf 
"[Pretrain with no regularization] Iteration: [%5d / %5d] \t Loss: %.9f \n" length(losses_pretrain) (params.niter_ADAM+params.niter_LBFGS) losses_pretrain[end] + # println("[Pretrain with no regularization] Current loss after $(length(losses_pretrain)) iterations: $(losses_pretrain[end])") + end + return false + end + optf₀ = Optimization.OptimizationFunction((x, β) -> loss_empirical(x), adtype) + optprob₀ = Optimization.OptimizationProblem(optf₀, β) + res₀ = Optimization.solve(optprob₀, ADAM(), callback=callback_pretrain, maxiters=params.niter_ADAM, verbose=false) + optprob₁ = Optimization.OptimizationProblem(optf₀, res₀.u) + res₁ = Optimization.solve(optprob₁, Optim.BFGS(; initial_stepnorm=0.01, linesearch=LineSearches.BackTracking()), callback=callback_pretrain, maxiters=params.niter_LBFGS) + β = res₁.u + end # To do: implement this with polyoptimizaion to put ADAM and BFGS in one step. # Maybe better to keep like this for the line search. diff --git a/src/utils.jl b/src/utils.jl index 1aaf52c..4bcce72 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -1,6 +1,6 @@ export sigmoid, sigmoid_cap export relu, relu_cap -export gelu +export gelu, rbf export cart2sph, sph2cart export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation @@ -18,6 +18,8 @@ import Lux: relu, gelu Normalization of the neural network last layer """ +rbf(x) = exp.(-(x .^ 2)) + sigmoid_cap(x; ω₀=1.0) = sigmoid_cap(x, ω₀) function sigmoid_cap(x, ω₀) From ecfeef97a8bf0f205b37535cd0a828e59c8ff09c Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 11 Oct 2024 11:57:53 -0700 Subject: [PATCH 27/29] Improvements in desing of NN and Torskvik example --- Project.toml | 6 +++- examples/Torsvik_2012/APWP-Torsvik.jl | 52 ++++++++++++--------------- src/SphereUDE.jl | 5 +-- src/plot.jl | 2 +- src/train.jl | 24 +++++++++---- src/utils.jl | 15 ++++++++ 6 files changed, 61 insertions(+), 43 deletions(-) diff --git a/Project.toml b/Project.toml index 1d10b12..4a6b99f 100644 --- a/Project.toml +++ b/Project.toml @@ -12,18 +12,22 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0" DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0" Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f" FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838" +FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549" ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210" Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" +JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Lux = "b2108857-7c20-44ae-9111-449ecde12c47" +Lux = "b210s8857-7c20-44ae-9111-449ecde12c47" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1" OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282" OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8" OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a" +Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80" PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d" +Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7" PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0" PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee" Revise = "295af30f-e4ad-537b-8983-00126c2a3abe" diff --git a/examples/Torsvik_2012/APWP-Torsvik.jl b/examples/Torsvik_2012/APWP-Torsvik.jl index 7cb4db6..dbb30fb 100644 --- a/examples/Torsvik_2012/APWP-Torsvik.jl +++ 
b/examples/Torsvik_2012/APWP-Torsvik.jl @@ -15,7 +15,7 @@ rng = Random.default_rng() Random.seed!(rng, 613) using DataFrames, CSV -using Serialization +using Serialization, JLD2 df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame, delim=",") @@ -23,7 +23,7 @@ df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame Gondwana = ["Amazonia", "Parana", "Colorado", "Southern_Africa", "East_Antarctica", "Madagascar", "Patagonia", "Northeast_Africa", - "Northwest_Africa", "Somalia", "Arabia"] + "Northwest_Africa", "Somalia", "Arabia", "East_Gondwana"] df = filter(row -> row.Plate ∈ Gondwana, df) df.Times = df.Age .+= rand(sampler(Normal(0,0.1)), nrow(df)) # Needs to fix this! @@ -52,40 +52,32 @@ data = SphereData(times=times, directions=X, kappas=kappas, L=nothing) tspan = [times[begin], times[end]] params = SphereParameters(tmin = tspan[1], tmax = tspan[2], - reg = [Regularization(order=1, power=2.0, λ=1e5, diff_mode=LuxNestedAD())], + reg = [Regularization(order=1, power=2.0, λ=1e5, diff_mode=FiniteDifferences(1e-4))], # reg = nothing, pretrain = false, u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = 1e-6, abstol = 1e-6, - niter_ADAM = 2000, niter_LBFGS = 4000, + niter_ADAM = 5000, niter_LBFGS = 5000, sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true))) -# Linear interpolation function -# currently not working - -train_model = true - -if train_model - - init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) - init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) - - # Customized neural network to similate weighted moving window in L - U = Lux.Chain( - Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), - Lux.Dense(200,10, gelu), - Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false) - ) - - results = train(data, params, rng, nothing, U) - serialize("examples/Torsvik_2012/results.dat", Dict("data" => data, - "params" => params, - "results"=>results)) -else - # Read results - res = deserialize("resuls.dat") - results = res["results"] -end + +init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) +init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) + +# Customized neural network to similate weighted moving window in L +U = Lux.Chain( + Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), + Lux.Dense(200,10, gelu), + Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false) +) + +results = train(data, params, rng, nothing, U) +results_dict = convert2dict(data, results) + +# JLD2.@save "examples/Torsvik_2012/results/data.jld2" data +# JLD2.@save "examples/Torsvik_2012/results/results.jld2" results +JLD2.@save "examples/Torsvik_2012/results/results_dict.jld2" results_dict + plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation") plot_L(data, results, saveas="examples/Torsvik_2012/plots/plot_L.pdf", title="Double rotation") diff --git a/src/SphereUDE.jl b/src/SphereUDE.jl index 43621df..6003420 100644 --- a/src/SphereUDE.jl +++ b/src/SphereUDE.jl @@ -13,10 +13,7 @@ using Optimization, OptimizationOptimisers, OptimizationOptimJL using OptimizationPolyalgorithms, LineSearches using ComponentArrays using PyPlot, PyCall -using PrettyTables - -# Testing double-differentiation -# using BatchedRoutines +using PrettyTables, Printf # Debugging using Infiltrator diff --git a/src/plot.jl b/src/plot.jl index c33df4e..ddfa06d 100644 --- a/src/plot.jl +++ 
b/src/plot.jl
@@ -40,7 +40,7 @@ function plot_sphere(# ax::PyCall.PyObject,
         end
     end
 
-    # ax.coastlines()
+    ax.coastlines()
     ax.gridlines()
     ax.set_global()
 
diff --git a/src/train.jl b/src/train.jl
index d02322d..23f120e 100644
--- a/src/train.jl
+++ b/src/train.jl
@@ -46,7 +46,7 @@ function train(data::AD,
                params::AP,
                rng,
                θ_trained=[],
-               model::Chain=nothing) where {AD <: AbstractData, AP <: AbstractParameters}
+               model::Union{Chain, Nothing}=nothing) where {AD <: AbstractData, AP <: AbstractParameters}
 
     # Raise warnings
     raise_warnings(data::AD, params::AP)
@@ -171,10 +171,18 @@ function train(data::AD,
                     end
 
                     # Compute the final aggregation to the loss
-                    for j in 1:n_nodes
-                        l_ += weights[j] * norm(Jac[:,1,j])^reg.power
-                    end
+                    l_ += sum([weights[j] * norm(Jac[:,1,j])^reg.power for j in 1:n_nodes])
 
+                    # Every few iterations, test that AD is working properly
+                    if rand(Bernoulli(0.001))
+                        l_AD = sum([weights[j] * norm(Jac[:,1,j])^reg.power for j in 1:n_nodes])
+                        l_FD = quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, 1e-5))^reg.power, params.tmin, params.tmax, n_nodes)
+                        if abs(l_AD - l_FD) > 1e-2 * l_FD
+                            @warn "[SphereUDE] Nested AD is giving significantly different results than Finite Differences."
+                            @printf "[SphereUDE] Regularization with AD: %.9f vs %.9f using Finite Differences \n" l_AD l_FD
+                        end
+                    end
+
                 elseif typeof(reg.diff_mode) <: FiniteDifferences
                     # Finite differences
                     l_ += quadrature(t -> norm(central_fdm(τ -> predict_L(τ, U, θ, st), t, reg.diff_mode.ϵ))^reg.power, params.tmin, params.tmax, n_nodes)
@@ -252,7 +260,7 @@ function train(data::AD,
     losses = Float64[]
     callback = function (p, l)
         push!(losses, l)
-        if length(losses) % 20 == 0
+        if length(losses) % 50 == 0
             @printf "Iteration: [%5d / %5d] \t Loss: %.9f \n" length(losses) (params.niter_ADAM+params.niter_LBFGS) losses[end]
             # println("Current loss after $(length(losses)) iterations: $(losses[end])")
         end
@@ -321,13 +329,15 @@ function train(data::AD,
     end
 
     # Final Fit 
-    fit_times = collect(range(params.tmin,params.tmax, length=length(data.times)))
+    fit_times = collect(range(params.tmin,params.tmax, length=1000))
     fit_directions, _ = predict(β_trained, T=fit_times)
+    fit_rotations = reduce(hcat, (t -> U([t], θ_trained, st)[1]).(fit_times))
 
     # Recover final balance between different terms involved in the loss function to assess hyperparameter selection.
_, loss_dict = loss(β_trained) pretty_table(loss_dict, sortkeys=true, header=["Loss term", "Value"]) return Results(θ=θ_trained, u0=u0_trained, U=U, st=st, - fit_times=fit_times, fit_directions=fit_directions) + fit_times=fit_times, fit_directions=fit_directions, + fit_rotations=fit_rotations, losses=losses) end diff --git a/src/utils.jl b/src/utils.jl index 4bcce72..47b86bc 100644 --- a/src/utils.jl +++ b/src/utils.jl @@ -6,6 +6,7 @@ export AbstractNoise, FisherNoise export quadrature, central_fdm, complex_step_differentiation export raise_warnings export isL1reg +export convert2dict # Import activation function for complex extension import Lux: relu, gelu @@ -248,4 +249,18 @@ function isL1reg(regs::Union{Vector{R}, Nothing}) where {R <: AbstractRegulariza end end return false +end + +function convert2dict(data::SphereData, results::Results) + _dict = Dict() + _dict["times"] = data.times + _dict["directions"] = data.directions + _dict["kappas"] = data.kappas + _dict["u0"] = results.u0 + _dict["fit_times"] = results.fit_times + _dict["fit_directions"] = results.fit_directions + _dict["fit_rotations"] = results.fit_rotations + _dict["losses"] = results.losses + + return _dict end \ No newline at end of file From 296ee0486e188ccf7d021b5703ef690b5b71bf88 Mon Sep 17 00:00:00 2001 From: Facundo Sapienza Date: Fri, 11 Oct 2024 12:02:26 -0700 Subject: [PATCH 28/29] small changes in types --- examples/benchmark.jl | 1 + examples/double_rotation/double_rotation.jl | 99 ++++++++++++++++----- src/types.jl | 2 + 3 files changed, 81 insertions(+), 21 deletions(-) diff --git a/examples/benchmark.jl b/examples/benchmark.jl index e375095..d224657 100644 --- a/examples/benchmark.jl +++ b/examples/benchmark.jl @@ -45,6 +45,7 @@ function benchmark() #solvers = [BS5(), OwrenZen5(), OwrenZen3(), BS3(), Tsit5()] solvers = [BS5(), Tsit5()] + # Different sensealgs available at: https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))] #sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)), InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))] diff --git a/examples/double_rotation/double_rotation.jl b/examples/double_rotation/double_rotation.jl index 4890388..0e4c5cb 100644 --- a/examples/double_rotation/double_rotation.jl +++ b/examples/double_rotation/double_rotation.jl @@ -5,6 +5,7 @@ using LinearAlgebra, Statistics, Distributions using SciMLSensitivity using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5 using Optimization, OptimizationOptimisers, OptimizationOptimJL +using Lux using SphereUDE @@ -37,8 +38,8 @@ L0 = ω₀ .* [1.0, 0.0, 0.0] L1 = 0.6ω₀ .* [0.0, 1/sqrt(2), 1/sqrt(2)] # Solver tolerances -reltol = 1e-12 -abstol = 1e-12 +reltol = 1e-6 +abstol = 1e-6 function L_true(t::Float64; τ₀=τ₀, p=[L0, L1]) if t < τ₀ @@ -71,10 +72,19 @@ params = SphereParameters(tmin = tspan[1], tmax = tspan[2], train_initial_condition = false, multiple_shooting = false, u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol, - niter_ADAM = 2000, niter_LBFGS = 2000, + niter_ADAM = 5000, niter_LBFGS = 5000, sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true))) -results = train(data, params, rng, nothing) +init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims) +init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims) + +U = Lux.Chain( + Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true), + Lux.Dense(200,10, gelu), + Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, 
params.ωmax), use_bias=false)
+)
+
+results = train(data, params, rng, nothing, U)
 
 ##############################################################
 ######################  PyCall Plots  #########################
@@ -87,28 +97,75 @@ end # run
 
 # Run different experiments
 
-ϵ = 1e-5
-λ₀ = 0.1
-λ₁ = 0.01
+### Finite differences
+
+# run(; kappa = 50.,
+#     regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=FiniteDifferences(1e-5)),
+#             Regularization(order=0, power=2.0, λ=10.0)],
+#     title = "plots/FD_plot_50")
+
+# run(; kappa = 200.,
+#     regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)),
+#             Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/FD_plot_200")
+
+
+# run(; kappa = 1000.,
+#     regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)),
+#             Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/FD_plot_1000")
+
+
+# Complex Step Method
+
+# run(; kappa = 50.,
+#     regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=ComplexStepDifferentiation(1e-5)),
+#             Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/CS_plot_50")
+
+# run(; kappa = 200.,
+#     regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)),
+#             Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/CS_plot_200")
 
-run(; kappa = 50.,
-    regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)),
-            Regularization(order=0, power=2.0, λ=λ₀)],
-    title = "plots/plot_50_lambda$(λ₁)")
 
-λ₀ = 0.1
-λ₁ = 0.1
+# run(; kappa = 1000.,
+#     regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)),
+#             Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/CS_plot_1000")
+
+
+
+### AD
+
+run(; kappa = 50.,
+    regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=LuxNestedAD())],
+    # Regularization(order=0, power=2.0, λ=0.1)],
+    title = "plots/AD_plot_50")
 
 run(; kappa = 200.,
-    regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)),
-            Regularization(order=0, power=2.0, λ=λ₀)],
-    title = "plots/plot_200_lambda$(λ₁)")
+    regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()),
+            Regularization(order=0, power=2.0, λ=0.1)],
+    title = "plots/AD_plot_200")
+
+run(; kappa = 1000.,
+    regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()),
+            Regularization(order=0, power=2.0, λ=0.1)],
+    title = "plots/AD_plot_1000")
+
+
+### no first-order regularization
+
+# run(; kappa = 50.,
+#     regs = [Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/None_plot_50")
+
 
-λ₀ = 0.1
-λ₁ = 0.1
+# run(; kappa = 200.,
+#     regs = [Regularization(order=0, power=2.0, λ=0.1)],
+#     title = "plots/None_plot_200")
 
 run(; kappa = 1000.,
-    regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode=FiniteDifferences(ϵ)),
-            Regularization(order=0, power=2.0, λ=λ₀)],
-    title = "plots/plot_1000_lambda$(λ₁)")
\ No newline at end of file
+    regs = nothing,
+    title = "plots/_None_plot_1000")
\ No newline at end of file
diff --git a/src/types.jl b/src/types.jl
index 7cc9ce2..5d5ed8c 100644
--- a/src/types.jl
+++ b/src/types.jl
@@ -50,6 +50,8 @@ Final results
     st::NamedTuple
     fit_times::Vector{F}
    fit_directions::Matrix{F}
+    fit_rotations::Matrix{F}
+    losses::Vector{F}
 end
 
 """

From 368278e1b714c5adcd8df0f03a35410511912580 Mon Sep 17 00:00:00 2001
From: Facundo Sapienza
Date: Fri, 11 Oct 2024 15:16:42 -0700
Subject: [PATCH 29/29] Fix typo in Lux deps

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/Project.toml b/Project.toml index 4a6b99f..98d68a1 100644 --- a/Project.toml +++ b/Project.toml @@ -18,7 +18,7 @@ Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b" JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819" LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" -Lux = "b210s8857-7c20-44ae-9111-449ecde12c47" +Lux = "b2108857-7c20-44ae-9111-449ecde12c47" Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba" OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e" OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
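
Note on the derivative consistency check used in the regularization term: the patches above add a sanity test in src/train.jl that compares a nested-AD estimate of the regularization against one built from central finite differences, using the central_fdm and complex_step_differentiation helpers defined in src/utils.jl. The snippet below is a minimal, self-contained sketch of that same idea applied pointwise to a scalar function; the helper one-liners mirror the signatures from src/utils.jl, while the test function L, the evaluation point, the step sizes, and the 1% tolerance are illustrative assumptions rather than values taken from the package.

    # Minimal sketch in plain Julia (no package dependencies).
    # Redefinitions follow the signatures used in src/utils.jl so the snippet runs on its own.
    central_fdm(f::Function, x::Float64, ϵ::Float64) = (f(x + ϵ) - f(x - ϵ)) / (2ϵ)
    complex_step_differentiation(f::Function, x::Float64, ϵ::Float64) = imag(f(x + ϵ * im)) / ϵ

    # Illustrative stand-in for one component of the rotation L(t);
    # in the examples this role is played by the trained neural network.
    L(t) = sin(0.1t) * exp(-0.01t)

    t₀ = 100.0
    d_fd = central_fdm(L, t₀, 1e-5)
    d_cs = complex_step_differentiation(L, t₀, 1e-10)

    # Warn when the two estimates disagree by more than 1% in relative terms,
    # analogous to the check added inside the regularization loop of train().
    if abs(d_fd - d_cs) > 1e-2 * abs(d_cs)
        @warn "Derivative estimates disagree" d_fd d_cs
    end

For the Torsvik example, the results dictionary saved with JLD2.@save can be read back the same way it was written, e.g.

    using JLD2
    JLD2.@load "examples/Torsvik_2012/results/results_dict.jld2" results_dict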