Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Updating Lux with Nested AD #111

Merged
merged 36 commits into from
Oct 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
36 commits
Select commit Hold shift + click to select a range
8f88211
Update Project with new dependencies
facusapienza21 Apr 10, 2024
ad54394
remove FiniteDifferences from dependencies
facusapienza21 Apr 17, 2024
a4790d3
Example of APWP fit based on Jupp1987
facusapienza21 Apr 17, 2024
29b8c78
add complex-step method
facusapienza21 Apr 19, 2024
a3852d1
Double differentiation working with complex-step method
facusapienza21 Apr 19, 2024
cc5ce31
Initial condition u0 fitting implemented
facusapienza21 Apr 19, 2024
92f00d6
Projected gradient descent working for u0. Some more tests
facusapienza21 Apr 20, 2024
c69a059
Testing activation functions with complex-step
facusapienza21 Apr 29, 2024
6ddd80a
Merge branch 'main' into main
facusapienza21 Apr 29, 2024
ba41e79
Merge branch 'ODINN-SciML:main' into main
facusapienza21 Apr 30, 2024
3e30160
predict function, return multiple losses
facusapienza21 Apr 30, 2024
9413263
Merge branch 'main' of https://github.com/facusapienza21/SphereUDE.jl…
facusapienza21 Apr 30, 2024
10ae8cd
Co-authored-by: Jordi Bolibar <[email protected]>
facusapienza21 Apr 30, 2024
684621f
Merge branch 'main' of https://github.com/facusapienza21/SphereUDE.jl…
facusapienza21 May 3, 2024
3d52691
Multiple shooting working once sensealg specified
facusapienza21 May 3, 2024
9fa3354
Double rotation example with small changes in src
facusapienza21 Jul 20, 2024
e2116c3
feat: update to support Lux 1.0
avik-pal Sep 21, 2024
f4eee7a
Example with double rotation working with non-updated Lux
facusapienza21 Sep 26, 2024
8a01ca5
Integration test of inversion
facusapienza21 Sep 26, 2024
8ee17ed
Added Random as test dependency
facusapienza21 Sep 26, 2024
b92e58a
Merge branch 'main' of https://github.com/facusapienza21/SphereUDE.jl…
facusapienza21 Sep 27, 2024
12a692d
bring changes from @avik-pal branch
facusapienza21 Sep 27, 2024
aad95b4
Fix Lux version
facusapienza21 Sep 27, 2024
9711584
[WIP] Working around `Lux=1` (#108)
facusapienza21 Sep 27, 2024
d1a1fcd
CI on `up-lux` branch
facusapienza21 Sep 27, 2024
62848d9
Update CI.yml - Update CI version
facusapienza21 Sep 27, 2024
003ef37
Remove OrdinaryDiffEq from test
facusapienza21 Sep 28, 2024
fce5757
feat: update to support Lux 1.0 (#94)
avik-pal Sep 28, 2024
e7ddd95
Merge branch 'up-lux' of https://github.com/ODINN-SciML/SphereUDE.jl …
facusapienza21 Sep 29, 2024
a481959
Define abstract types for different differentiation modes
facusapienza21 Sep 30, 2024
a3a9348
Implemented regularization with Lux nested AD
facusapienza21 Oct 1, 2024
c6d33df
[WIP] Torsvik + working on complex-step with Lux
facusapienza21 Oct 3, 2024
f4f6063
Example working with Torsvik data
facusapienza21 Oct 8, 2024
ecfeef9
Improvements in desing of NN and Torskvik example
facusapienza21 Oct 11, 2024
296ee04
small changes in types
facusapienza21 Oct 11, 2024
368278e
Fix typo in Lux deps
facusapienza21 Oct 11, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@ on:
push:
branches:
- main
- up-lux
tags: ['*']
pull_request:
branches:
- main
- up-lux
workflow_dispatch:
concurrency:
# Skip intermediate builds: always.
Expand All @@ -24,7 +26,7 @@ jobs:
fail-fast: false
matrix:
version:
- '1.9'
- '1'
# - 'nightly'
python:
- 3.9
Expand Down
22 changes: 15 additions & 7 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SphereUDE"
uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d"
authors = ["Facundo Sapienza <[email protected]>", "Jordi Bolibar <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Expand All @@ -12,44 +12,52 @@ DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
DiffEqFlux = "aae7a2af-3d4f-5e19-a356-7da93b79d9d0"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
FastGaussQuadrature = "442a2c76-b920-505d-bb47-c5924d526838"
FileIO = "5789e2e9-d7fb-5bc7-8068-2c6fae9b9549"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
Infiltrator = "5903a43b-9cc3-4c30-8d17-598619ec4e9b"
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
LineSearches = "d3d80556-e9d4-5f37-9878-2ab0fcc64255"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BenchmarkTools = "1"
ComponentArrays = "0.15"
DiffEqFlux = "4"
Distributions = "0.25"
Infiltrator = "1.2"
Lux = "<0.5.49"
Lux = "1.0"
Optimization = "3.12"
OptimizationOptimJL = "0.1.5"
OptimizationOptimisers = "0.1.2"
OrdinaryDiffEq = "5, 6"
OrdinaryDiffEqCore = "1.6.0"
OrdinaryDiffEqTsit5 = "1.1.0"
PyCall = "1.9"
PyPlot = "2.11"
Revise = "3.1"
SciMLSensitivity = "7.20"
Statistics = "1"
Zygote = "0.6"
julia = "1.7"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["Test", "Random"]
83 changes: 83 additions & 0 deletions examples/Torsvik_2012/APWP-Torsvik.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
using Pkg; Pkg.activate(".")
using Revise
using Lux

using LinearAlgebra, Statistics, Distributions
using SciMLSensitivity
# using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5
using Optimization, OptimizationOptimisers, OptimizationOptimJL

using SphereUDE

# Random seed
using Random
rng = Random.default_rng()
Random.seed!(rng, 613)

using DataFrames, CSV
using Serialization, JLD2

df = CSV.read("./examples/Torsvik_2012/Torsvik-etal-2012_dataset.csv", DataFrame, delim=",")

# Filter the plates that were once part of the supercontinent Gondwana

Gondwana = ["Amazonia", "Parana", "Colorado", "Southern_Africa",
"East_Antarctica", "Madagascar", "Patagonia", "Northeast_Africa",
"Northwest_Africa", "Somalia", "Arabia", "East_Gondwana"]

df = filter(row -> row.Plate ∈ Gondwana, df)
df.Times = df.Age .+= rand(sampler(Normal(0,0.1)), nrow(df)) # Needs to fix this!

df = sort(df, :Times)
times = df.Times

# Fill missing values
df.RLat .= coalesce.(df.RLat, df.Lat)
df.RLon .= coalesce.(df.RLon, df.Lon)

X = sph2cart(Matrix(df[:,["RLat","RLon"]])'; radians=false)

# Retrieve uncertanties from poles and convert α95 into κ
kappas = (140.0 ./ df.a95).^2

data = SphereData(times=times, directions=X, kappas=kappas, L=nothing)

# Training

# Expected maximum angular deviation in one unit of time (degrees)
Δω₀ = 1.5
# Angular velocity
ω₀ = Δω₀ * π / 180.0

tspan = [times[begin], times[end]]

params = SphereParameters(tmin = tspan[1], tmax = tspan[2],
reg = [Regularization(order=1, power=2.0, λ=1e5, diff_mode=FiniteDifferences(1e-4))],
# reg = nothing,
pretrain = false,
u0 = [0.0, 0.0, -1.0], ωmax = ω₀,
reltol = 1e-6, abstol = 1e-6,
niter_ADAM = 5000, niter_LBFGS = 5000,
sensealg = InterpolatingAdjoint(autojacvec = ReverseDiffVJP(true)))


init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims)
init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims)

# Customized neural network to similate weighted moving window in L
U = Lux.Chain(
Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true),
Lux.Dense(200,10, gelu),
Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false)
)

results = train(data, params, rng, nothing, U)
results_dict = convert2dict(data, results)

# JLD2.@save "examples/Torsvik_2012/results/data.jld2" data
# JLD2.@save "examples/Torsvik_2012/results/results.jld2" results
JLD2.@save "examples/Torsvik_2012/results/results_dict.jld2" results_dict


plot_sphere(data, results, -30., 0., saveas="examples/Torsvik_2012/plots/plot_sphere.pdf", title="Double rotation")
plot_L(data, results, saveas="examples/Torsvik_2012/plots/plot_L.pdf", title="Double rotation")
1 change: 1 addition & 0 deletions examples/benchmark.jl
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ function benchmark()

#solvers = [BS5(), OwrenZen5(), OwrenZen3(), BS3(), Tsit5()]
solvers = [BS5(), Tsit5()]
# Different sensealgs available at: https://docs.sciml.ai/SciMLSensitivity/stable/manual/differential_equation_sensitivities
sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))]
#sensealgs = [QuadratureAdjoint(autojacvec=ReverseDiffVJP(true)), InterpolatingAdjoint(autojacvec=ReverseDiffVJP(true))]

Expand Down
101 changes: 80 additions & 21 deletions examples/double_rotation/double_rotation.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@ using Pkg; Pkg.activate(".")
using Revise

using LinearAlgebra, Statistics, Distributions
using OrdinaryDiffEq
using SciMLSensitivity
using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using Lux

using SphereUDE

Expand Down Expand Up @@ -37,8 +38,8 @@ L0 = ω₀ .* [1.0, 0.0, 0.0]
L1 = 0.6ω₀ .* [0.0, 1/sqrt(2), 1/sqrt(2)]

# Solver tolerances
reltol = 1e-12
abstol = 1e-12
reltol = 1e-6
abstol = 1e-6

function L_true(t::Float64; τ₀=τ₀, p=[L0, L1])
if t < τ₀
Expand Down Expand Up @@ -71,10 +72,19 @@ params = SphereParameters(tmin = tspan[1], tmax = tspan[2],
train_initial_condition = false,
multiple_shooting = false,
u0 = [0.0, 0.0, -1.0], ωmax = ω₀, reltol = reltol, abstol = abstol,
niter_ADAM = 2000, niter_LBFGS = 1000,
niter_ADAM = 5000, niter_LBFGS = 5000,
sensealg = GaussAdjoint(autojacvec = ReverseDiffVJP(true)))

results = train(data, params, rng, nothing)
init_bias(rng, in_dims) = LinRange(tspan[1], tspan[2], in_dims)
init_weight(rng, out_dims, in_dims) = 0.1 * ones(out_dims, in_dims)

U = Lux.Chain(
Lux.Dense(1, 200, rbf, init_bias=init_bias, init_weight=init_weight, use_bias=true),
Lux.Dense(200,10, gelu),
Lux.Dense(10, 3, Base.Fix2(sigmoid_cap, params.ωmax), use_bias=false)
)

results = train(data, params, rng, nothing, U)

##############################################################
###################### PyCall Plots #########################
Expand All @@ -87,26 +97,75 @@ end # run

# Run different experiments

λ₀ = 0.1
λ₁ = 0.001

run(; kappa = 50.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="FD"),
Regularization(order=0, power=2.0, λ=λ₀, diff_mode=nothing)],
title = "plots/plot_50_lambda$(λ₁)")
### Finite differeces

# run(; kappa = 50.,
# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=FiniteDifferences(1e-5)),
# Regularization(order=0, power=2.0, λ=10.0)],
# title = "plots/FD_plot_50")

# run(; kappa = 200.,
# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)),
# Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/FD_plot_200")


# run(; kappa = 1000.,
# regs = [Regularization(order=1, power=1.0, λ=1.0, diff_mode=FiniteDifferences(1e-5)),
# Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/FD_plot_1000")


# Complex Step Method

# run(; kappa = 50.,
# regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=ComplexStepDifferentiation(1e-5)),
# Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/CS_plot_50")

# run(; kappa = 200.,
# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)),
# Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/CS_plot_200")


λ₀ = 0.1
λ₁ = 0.1
# run(; kappa = 1000.,
# regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=ComplexStepDifferentiation(1e-5)),
# Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/CS_plot_1000")



### AD

run(; kappa = 50.,
regs = [Regularization(order=1, power=1.0, λ=0.01, diff_mode=LuxNestedAD())], %,
# Regularization(order=0, power=2.0, λ=0.1)],
title = "plots/AD_plot_50")

run(; kappa = 200.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=λ₀)],
title = "plots/plot_200_lambda$(λ₁)")
regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()),
Regularization(order=0, power=2.0, λ=0.1)],
title = "plots/AD_plot_200")

run(; kappa = 1000.,
regs = [Regularization(order=1, power=1.0, λ=0.1, diff_mode=LuxNestedAD()),
Regularization(order=0, power=2.0, λ=0.1)],
title = "plots/AD_plot_1000")


### no first-order regularization

# run(; kappa = 50.,
# regs = [Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/None_plot_50")


λ₀ = 0.1
λ₁ = 0.1
# run(; kappa = 200.,
# regs = [Regularization(order=0, power=2.0, λ=0.1)],
# title = "plots/None_plot_200")

run(; kappa = 1000.,
regs = [Regularization(order=1, power=1.0, λ=λ₁, diff_mode="CS"),
Regularization(order=0, power=2.0, λ=λ₀)],
title = "plots/plot_1000_lambda$(λ₁)")
regs = nothing,
title = "plots/_None_plot_1000")
14 changes: 5 additions & 9 deletions src/SphereUDE.jl
Original file line number Diff line number Diff line change
@@ -1,23 +1,19 @@
__precompile__()
module SphereUDE

# types
using Base: @kwdef
# utils
# training
using LinearAlgebra, Statistics, Distributions
using FastGaussQuadrature
using Lux, Zygote, DiffEqFlux
using ChainRules: @ignore_derivatives
using OrdinaryDiffEq
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms
using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5
using SciMLSensitivity, ForwardDiff
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using OptimizationPolyalgorithms, LineSearches
using ComponentArrays
using PyPlot, PyCall
using PrettyTables

# Testing double-differentiation
# using BatchedRoutines
using PrettyTables, Printf

# Debugging
using Infiltrator
Expand Down
2 changes: 1 addition & 1 deletion src/plot.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ function plot_sphere(# ax::PyCall.PyObject,
end
end

# ax.coastlines()
ax.coastlines()
ax.gridlines()
ax.set_global()

Expand Down
Loading
Loading