feat: update to support Lux 1.0 (#94)
* feat: update to support Lux 1.0

* ci: up to 1.10

* fix: make the activations more Lux friendly
avik-pal authored Sep 28, 2024
1 parent 003ef37 commit fce5757
Showing 4 changed files with 11 additions and 8 deletions.
6 changes: 3 additions & 3 deletions src/train.jl
@@ -13,15 +13,15 @@ function get_NN(params, rng, θ_trained)
             # Lux.Dense(1, 5, relu_cap),
             # Lux.Dense(5, 10, relu_cap),
             # Lux.Dense(10, 5, relu_cap),
-            Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
-            # Lux.Dense(5, 3, x -> relu_cap(x; ω₀=params.ωmax))
+            Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax))
+            # Lux.Dense(5, 3, Base.Fix2(relu_cap, params.ωmax))
         )
     else
         U = Lux.Chain(
             Lux.Dense(1, 5, sigmoid),
             Lux.Dense(5, 10, sigmoid),
             Lux.Dense(10, 5, sigmoid),
-            Lux.Dense(5, 3, x -> sigmoid_cap(x; ω₀=params.ωmax))
+            Lux.Dense(5, 3, Base.Fix2(sigmoid_cap, params.ωmax))
         )
     end
     θ, st = Lux.setup(rng, U)
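Note (not part of the commit): swapping the anonymous closure for Base.Fix2 keeps the same math but gives the activation a concrete, named callable type instead of a fresh anonymous-function type each time the chain is built, which is likely the "more Lux friendly" part of the commit message. A minimal standalone sketch, assuming the positional sigmoid_cap(x, ω₀) method added in src/utils.jl below and an illustrative ωmax value:

    using NNlib: sigmoid

    # Positional-cap variant, mirroring src/utils.jl after this commit:
    # rescales sigmoid(x) from (0, 1) into (-ω₀, ω₀).
    sigmoid_cap(x, ω₀) = -ω₀ + 2ω₀ * sigmoid(x)

    ωmax = 1.5                              # illustrative value standing in for params.ωmax
    act_old = x -> sigmoid_cap(x, ωmax)     # anonymous closure: a new type every time it is defined
    act_new = Base.Fix2(sigmoid_cap, ωmax)  # named callable: Base.Fix2{typeof(sigmoid_cap), Float64}

    act_new(0.3) ≈ act_old(0.3)             # true, since Base.Fix2(f, y)(x) calls f(x, y)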
8 changes: 6 additions & 2 deletions src/utils.jl
@@ -11,7 +11,9 @@ export isL1reg
 Normalization of the neural network last layer
 """
-function sigmoid_cap(x; ω₀=1.0)
+sigmoid_cap(x; ω₀=1.0) = sigmoid_cap(x, ω₀)
+
+function sigmoid_cap(x, ω₀)
     min_value = - ω₀
     max_value = + ω₀
     return min_value + (max_value - min_value) * sigmoid(x)
@@ -38,7 +40,9 @@ end
 """
     relu_cap(x; ω₀=1.0)
 """
-function relu_cap(x; ω₀=1.0)
+relu_cap(x; ω₀=1.0) = relu_cap(x, ω₀)
+
+function relu_cap(x, ω₀)
     min_value = - ω₀
     max_value = + ω₀
     return relu_cap(x, min_value, max_value)
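Aside (not part of the diff): the split keeps the old keyword API while exposing a positional method that Base.Fix2 can partially apply, since Fix2 fixes a positional second argument rather than a keyword. A self-contained sketch of the pattern, using a stand-in clamp body for illustration rather than the package's actual three-argument relu_cap:

    # Keyword entry point forwards to a positional method, so existing call sites keep working.
    relu_cap(x; ω₀=1.0) = relu_cap(x, ω₀)
    # Stand-in body for illustration only (the real method delegates to relu_cap(x, min_value, max_value)).
    relu_cap(x, ω₀) = clamp(x, -ω₀, ω₀)

    relu_cap(2.0)                   # 1.0 — keyword default ω₀ = 1.0
    relu_cap(2.0; ω₀=0.5)           # 0.5
    Base.Fix2(relu_cap, 0.5)(2.0)   # 0.5 — the form used as a Lux activation in src/train.jl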
3 changes: 1 addition & 2 deletions test/rotation.jl
@@ -1,5 +1,4 @@
-using LinearAlgebra, Statistics, Distributions
-using OrdinaryDiffEqCore, OrdinaryDiffEqTsit5
+using LinearAlgebra, Statistics, Distributions
 using SciMLSensitivity
 using Optimization, OptimizationOptimisers, OptimizationOptimJL
2 changes: 1 addition & 1 deletion test/runtests.jl
@@ -19,4 +19,4 @@ end

 @testset "Inversion" begin
     test_single_rotation()
-end
+end