Skip to content

Commit

Permalink
Merge pull request #34 from SciML/gpu_support
Browse files Browse the repository at this point in the history
Add test of GPU support
  • Loading branch information
nicholaskl97 authored Feb 5, 2025
2 parents fd285a3 + 89bb9ba commit f273837
Show file tree
Hide file tree
Showing 3 changed files with 163 additions and 1 deletion.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ ForwardDiff = "0.10"
JuMP = "1"
Lux = "1"
LuxCore = "1.1.0"
LuxCUDA = "0.3.3"
ModelingToolkit = "9.51"
NLopt = "1"
NeuralPDE = "5.17"
Expand All @@ -34,6 +35,7 @@ Boltz = "4544d5e4-abc5-4dea-817f-29e4c205d9c8"
CSDP = "0a46da34-8e4b-519e-b418-48813639ff34"
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
Lux = "b2108857-7c20-44ae-9111-449ecde12c47"
LuxCUDA = "d0bbae9a-e099-4d5b-a835-1c6931763bda"
NLopt = "76087f3c-5699-56af-9a33-bf431cd00edd"
NeuralPDE = "315f7962-48a3-4962-8226-d0f33b1235f0"
Optimization = "7f7a1694-90dd-40f0-9382-eb1efda571ba"
Expand All @@ -44,4 +46,4 @@ SafeTestsets = "1bc83da4-3b8d-516f-aca4-4fe02f6d838f"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

[targets]
test = ["SafeTestsets", "Test", "Lux", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "NLopt", "Random", "NeuralPDE", "CSDP", "Boltz", "ComponentArrays"]
test = ["SafeTestsets", "Test", "Lux", "Optimization", "OptimizationOptimJL", "OptimizationOptimisers", "NLopt", "Random", "NeuralPDE", "CSDP", "Boltz", "ComponentArrays", "LuxCUDA"]
154 changes: 154 additions & 0 deletions test/damped_sho_CUDA.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
using NeuralPDE, NeuralLyapunov
import Optimization, OptimizationOptimisers, OptimizationOptimJL
using Random
using Lux, LuxCUDA, ComponentArrays
using Test, LinearAlgebra, ForwardDiff

Random.seed!(200)

println("Damped Simple Harmonic Oscillator")

######################### Define dynamics and domain ##########################

"Simple Harmonic Oscillator Dynamics"
function f(state, p, t)
pos = state[1]
vel = state[2]
vcat(vel, -vel - pos)
end
lb = [-2.0, -2.0];
ub = [2.0, 2.0];
fixed_point = [0.0, 0.0];
dynamics = ODEFunction(f; sys = SciMLBase.SymbolCache([:x, :v]))

####################### Specify neural Lyapunov problem #######################

# Define neural network discretization
dim_state = length(lb)
dim_hidden = 20
chain = Chain(
Dense(dim_state, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, dim_hidden, tanh),
Dense(dim_hidden, 1)
)
const gpud = gpu_device()
ps = Lux.initialparameters(Random.default_rng(), chain) |> ComponentArray |> gpud |> f32

# Define training strategy
strategy = QuasiRandomTraining(2500)
discretization = PhysicsInformedNN(chain, strategy; init_params = ps)

# Define neural Lyapunov structure
structure = UnstructuredNeuralLyapunov()
minimization_condition = StrictlyPositiveDefinite(C = 0.1)

# Define Lyapunov decrease condition
# This damped SHO has exponential decrease at a rate of k = 0.5, so we train to certify that
decrease_condition = ExponentialStability(0.5)

# Construct neural Lyapunov specification
spec = NeuralLyapunovSpecification(
structure,
minimization_condition,
decrease_condition
)

############################# Construct PDESystem #############################

@named pde_system = NeuralLyapunovPDESystem(
dynamics,
lb,
ub,
spec;
)

######################## Construct OptimizationProblem ########################

prob = discretize(pde_system, discretization)
sym_prob = symbolic_discretize(pde_system, discretization)

########################## Solve OptimizationProblem ##########################

res = Optimization.solve(prob, OptimizationOptimisers.Adam(0.01); maxiters = 300)
prob = Optimization.remake(prob, u0 = res.u)
res = Optimization.solve(prob, OptimizationOptimisers.Adam(); maxiters = 300)
prob = Optimization.remake(prob, u0 = res.u)
res = Optimization.solve(prob, OptimizationOptimJL.BFGS(); maxiters = 300)

###################### Get numerical numerical functions ######################
V, V̇ = get_numerical_lyapunov_function(
discretization.phi,
(; φ1 = res.u),
structure,
f,
fixed_point
)

################################## Simulate ###################################
Δx = (ub[1] - lb[1]) / 100
Δv = (ub[2] - lb[2]) / 100
xs = lb[1]:Δx:ub[1]
vs = lb[2]:Δv:ub[2]
states = Iterators.map(collect, Iterators.product(xs, vs))
V_samples_gpu = vec(V(hcat(states...)))
V̇_samples_gpu = vec((hcat(states...)))

const cpud = cpu_device()
V_samples = V_samples_gpu |> cpud
V̇_samples = V̇_samples_gpu |> cpud

#################################### Tests ####################################

# Network structure should enforce nonegativeness of V
V0 = (V(fixed_point) |> cpud)[]
V_min, i_min = findmin(V_samples)
state_min = collect(states)[i_min]
V_min, state_min = if V0 V_min
V0, fixed_point
else
V_min, state_min
end
@test V_min -1e-2

# Trained for V's minimum to be near the fixed point
@test all(abs.(state_min .- fixed_point) .≤ 10 * [Δx, Δv])

# Check local negative semidefiniteness of V̇ at fixed point
@test ((fixed_point) |> cpud)[] == 0.0
@test maximum(abs, ForwardDiff.gradient(x -> ((x) |> cpud)[], fixed_point)) < 0.1
@test_broken maximum(eigvals(ForwardDiff.hessian(x -> ((x) |> cpud)[], fixed_point))) 0.0

# V̇ should be negative almost everywhere
@test sum(V̇_samples .> 0) / length(V̇_samples) < 5e-3

#=
# Print statistics
println("V(0.,0.) = ", V(fixed_point))
println("V ∋ [", V_min, ", ", maximum(V_samples), "]")
println("Minimial sample of V is at ", state_min)
println(
"V̇ ∋ [",
minimum(V̇_samples),
", ",
max((V̇(fixed_point) |> cpud)[], maximum(V̇_samples)),
"]",
)
# Plot results
using Plots
p1 = plot(xs, vs, V_samples, linetype = :contourf, title = "V", xlabel = "x", ylabel = "ẋ");
p1 = scatter!([0], [0], label = "Equilibrium");
p2 = plot(
xs,
vs,
V̇_samples,
linetype = :contourf,
title = "dV/dt",
xlabel = "x",
ylabel = "ẋ",
);
p2 = scatter!([0], [0], label = "Equilibrium");
plot(p1, p2)
=#
6 changes: 6 additions & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,12 @@ const GROUP = lowercase(get(ENV, "GROUP", "all"))
end
end

if GROUP == "gpu"
@time @safetestset "CUDA test - Damped SHO" begin
include("damped_sho_CUDA.jl")
end
end

if GROUP == "all" || GROUP == "unimplemented"
@time @safetestset "Errors for partially-implemented extensions" begin
include("unimplemented.jl")
Expand Down

0 comments on commit f273837

Please sign in to comment.