Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding test of GPU support #34

Merged
merged 15 commits into from
Feb 5, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ForwardDiff = "0.10"
JuMP = "1"
Lux = "1"
LuxCore = "1.1.0"
LuxCUDA = "0.3.3"
ModelingToolkit = "9.51"
NLopt = "1"
NeuralPDE = "5.17"
Expand All @@ -34,6 +35,7 @@ Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
CSDP = "0a46da34-8e4b-519e-b418-48813639ff34"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -44,4 +46,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SafeTestsets", "Test", "Lux", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "NLopt", "Random", "NeuralPDE", "CSDP", "Boltz", "ComponentArrays"]
test = ["SafeTestsets", "Test", "Lux", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "NLopt", "Random", "NeuralPDE", "CSDP", "Boltz", "ComponentArrays", "LuxCUDA"]
154 changes: 154 additions & 0 deletions test/damped_sho_CUDA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
using NeuralPDE, NeuralLyapunov
import Optimization, OptimizationOptimisers, OptimizationOptimJL
using Random
using Lux, LuxCUDA, ComponentArrays
using Test, LinearAlgebra, ForwardDiff

Random.seed!(200)

println("Damped Simple Harmonic Oscillator")

######################### Define dynamics and domain ##########################

"Simple Harmonic Oscillator Dynamics"
function f(state, p, t)
pos = state[1]
vel = state[2]
vcat(vel, -vel - pos)
end
lb = [-2.0, -2.0];
ub = [2.0, 2.0];
fixed_point = [0.0, 0.0];
dynamics = ODEFunction(f; sys = SciMLBase.SymbolCache([:x, :v]))

####################### Specify neural Lyapunov problem #######################

# Define neural network discretization
dim_state = length(lb)
dim_hidden = 20
chain = Chain(
Dense(dim_state, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1)
)
const gpud = gpu_device()
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f32

# Define training strategy
strategy = QuasiRandomTraining(2500)
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

# Define neural Lyapunov structure
structure = UnstructuredNeuralLyapunov()
minimization_condition = StrictlyPositiveDefinite(C = 0.1)

# Define Lyapunov decrease condition
# This damped SHO has exponential decrease at a rate of k = 0.5, so we train to certify that
decrease_condition = ExponentialStability(0.5)

# Construct neural Lyapunov specification
spec = NeuralLyapunovSpecification(
structure,
minimization_condition,
decrease_condition
)

############################# Construct PDESystem #############################

@named pde_system = NeuralLyapunovPDESystem(
dynamics,
lb,
ub,
spec;
)

######################## Construct OptimizationProblem ########################

prob = discretize(pde_system, discretization)
sym_prob = symbolic_discretize(pde_system, discretization)

########################## Solve OptimizationProblem ##########################

res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.01); maxiters = 300)
prob = Optimization.remake(prob, u0 = res.u)
res = Optimization.solve(prob, OptimizationOptimisers.Adam(); maxiters = 300)
prob = Optimization.remake(prob, u0 = res.u)
res = Optimization.solve(prob, OptimizationOptimJL.BFGS(); maxiters = 300)

###################### Get numerical numerical functions ######################
V, V̇ = get_numerical_lyapunov_function(
discretization.phi,
(; φ1 = res.u),
structure,
f,
fixed_point
)

################################## Simulate ###################################
Δx = (ub[1] - lb[1]) / 100
Δv = (ub[2] - lb[2]) / 100
xs = lb[1]:Δx:ub[1]
vs = lb[2]:Δv:ub[2]
states = Iterators.map(collect, Iterators.product(xs, vs))
V_samples_gpu = vec(V(hcat(states...)))
V̇_samples_gpu = vec((hcat(states...)))

const cpud = cpu_device()
V_samples = V_samples_gpu |> cpud
V̇_samples = V̇_samples_gpu |> cpud

#################################### Tests ####################################

# Network structure should enforce nonegativeness of V
V0 = (V(fixed_point) |> cpud)[]
V_min, i_min = findmin(V_samples)
state_min = collect(states)[i_min]
V_min, state_min = if V0 V_min
V0, fixed_point
else
V_min, state_min
end
@test V_min -1e-2

# Trained for V's minimum to be near the fixed point
@test all(abs.(state_min .- fixed_point) .≤ 10 * [Δx, Δv])

# Check local negative semidefiniteness of V̇ at fixed point
@test ((fixed_point) |> cpud)[] == 0.0
@test maximum(abs, ForwardDiff.gradient(x -> ((x) |> cpud)[], fixed_point)) < 0.1
@test_broken maximum(eigvals(ForwardDiff.hessian(x -> ((x) |> cpud)[], fixed_point))) 0.0

# V̇ should be negative almost everywhere
@test sum(V̇_samples .> 0) / length(V̇_samples) < 5e-3

#=
# Print statistics
println("V(0.,0.) = ", V(fixed_point))
println("V ∋ [", V_min, ", ", maximum(V_samples), "]")
println("Minimial sample of V is at ", state_min)
println(
"V̇ ∋ [",
minimum(V̇_samples),
", ",
max((V̇(fixed_point) |> cpud)[], maximum(V̇_samples)),
"]",
)
# Plot results
using Plots
p1 = plot(xs, vs, V_samples, linetype = :contourf, title = "V", xlabel = "x", ylabel = "ẋ");
p1 = scatter!([0], [0], label = "Equilibrium");
p2 = plot(
xs,
vs,
V̇_samples,
linetype = :contourf,
title = "dV/dt",
xlabel = "x",
ylabel = "ẋ",
);
p2 = scatter!([0], [0], label = "Equilibrium");
plot(p1, p2)
=#
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ const GROUP = lowercase(get(ENV, "GROUP", "all"))
end
end

if GROUP == "gpu"
@time @safetestset "CUDA test - Damped SHO" begin
include("damped_sho_CUDA.jl")
end
end

if GROUP == "all" || GROUP == "unimplemented"
@time @safetestset "Errors for partially-implemented extensions" begin
include("unimplemented.jl")
Expand Down
Loading