Skip to content

Commit

Permalink
feat: update to support Lux 1.0
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Sep 21, 2024
1 parent 05314e2 commit e2116c3
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 10 deletions.
13 changes: 8 additions & 5 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SphereUDE"
uuid = "d7416ba7-148a-4110-b27d-9087fcebab2d"
authors = ["Facundo Sapienza <[email protected]>", "Jordi Bolibar <[email protected]>"]
version = "0.1.1"
version = "0.1.2"

[deps]
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
Expand All @@ -19,32 +19,35 @@ Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
OptimizationOptimJL = "36348300-93cb-4f02-beb5-3c3902f8871e"
OptimizationOptimisers = "42dfb2eb-d2b4-4451-abcd-913932933ac1"
OptimizationPolyalgorithms = "500b13db-7e66-49ce-bda4-eed966be6282"
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
OrdinaryDiffEqCore = "bbf590c4-e513-4bbe-9b18-05decba2e5d8"
OrdinaryDiffEqTsit5 = "b1df2697-797e-41e3-8120-5422d3b24e4a"
PrettyTables = "08abe8d2-0d0c-5749-adfa-8a2ac140af0d"
PyCall = "438e738f-606a-5dbb-bf0a-cddfbfd45ab0"
PyPlot = "d330b81b-6aea-500a-939a-2ce795aea3ee"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
SciMLSensitivity = "1ed8b502-d754-442c-8d5d-10ac956f44a1"
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
TruncatedStacktraces = "781d530d-4396-4725-bb49-402e4bee1e77"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[compat]
BenchmarkTools = "1"
ComponentArrays = "0.15"
DiffEqFlux = "4"
Distributions = "0.25"
Infiltrator = "1.2"
Lux = "1"
Optimization = "3.12"
OptimizationOptimJL = "0.1.5"
OptimizationOptimisers = "0.1.2"
OrdinaryDiffEq = "5, 6"
OrdinaryDiffEqCore = "1.6.0"
OrdinaryDiffEqTsit5 = "1.1.0"
PyCall = "1.9"
PyPlot = "2.11"
Revise = "3.1"
SciMLSensitivity = "7.20"
Statistics = "1"
Zygote = "0.6"
julia = "1.7"
julia = "1.10"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Expand Down
4 changes: 1 addition & 3 deletions src/SphereUDE.jl
Original file line number Diff line number Diff line change
@@ -1,15 +1,13 @@
__precompile__()
module SphereUDE

# types
using Base: @kwdef
# utils
# training
using LinearAlgebra, Statistics, Distributions
using FastGaussQuadrature
using Lux, Zygote, DiffEqFlux
using ChainRules: @ignore_derivatives
using OrdinaryDiffEq
using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5
using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL, OptimizationPolyalgorithms
using ComponentArrays
Expand Down
4 changes: 2 additions & 2 deletions src/types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ Training parameters
niter_LBFGS::I
reltol::F
abstol::F
solver::OrdinaryDiffEq.OrdinaryDiffEqAlgorithm = Tsit5()
solver::OrdinaryDiffEqCore.OrdinaryDiffEqAlgorithm = Tsit5()
sensealg::SciMLBase.AbstractAdjointSensitivityAlgorithm = QuadratureAdjoint(autojacvec=ReverseDiffVJP(true))
end

Expand All @@ -43,7 +43,7 @@ Final results
@kwdef struct Results{F <: AbstractFloat} <: AbstractResult
θ::ComponentVector
u0::Vector{F}
U::Lux.Chain
U::Lux.AbstractLuxLayer
st::NamedTuple
fit_times::Vector{F}
fit_directions::Matrix{F}
Expand Down

0 comments on commit e2116c3

Please sign in to comment.