Skip to content

Commit

Permalink
Organized results with PyCall plots
Browse files Browse the repository at this point in the history
  • Loading branch information
facusapienza21 committed Dec 24, 2023
1 parent a4b59a4 commit b63f917
Show file tree
Hide file tree
Showing 7 changed files with 344 additions and 11 deletions.
4 changes: 3 additions & 1 deletion environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,6 @@ dependencies:
- pandas
- seaborn
- pip:
- pmagpy==4.2.106
- pmagpy==4.2.106
# GUI interface for VSCode
- PyQt5
252 changes: 252 additions & 0 deletions examples/2rotations.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,252 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Example of fir with two rotations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m\u001b[1m Activating\u001b[22m\u001b[39m project at `~/.julia/dev/SphereFit`\n"
]
}
],
"source": [
"using Pkg; Pkg.activate(\"../.\")\n",
"using Revise \n",
"\n",
"using LinearAlgebra, Statistics, Distributions \n",
"using OrdinaryDiffEq\n",
"using SciMLSensitivity\n",
"using Optimization, OptimizationOptimisers, OptimizationOptimJL\n",
"\n",
"using SphereFit"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"200"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"using Random\n",
"rng = Random.default_rng()\n",
"Random.seed!(rng, 000666)\n",
"# Fisher concentration parameter on observations (small = more dispersion)\n",
"κ = 200 "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's create a simple example consisting in two solid rotations around the globe with Fisher noise on top. "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1.0e-7"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Total time simulation\n",
"tspan = [0, 130.0]\n",
"# Number of sample points\n",
"N_samples = 50\n",
"# Times where we sample points\n",
"times_samples = sort(rand(sampler(Uniform(tspan[1], tspan[2])), N_samples))\n",
"\n",
"# Expected maximum angular deviation in one unit of time (degrees)\n",
"Δω₀ = 1.0 \n",
"# Angular velocity \n",
"ω₀ = Δω₀ * π / 180.0\n",
"# Change point\n",
"τ₀ = 65.0\n",
"# Angular momentum\n",
"L0 = ω₀ .* [1.0, 0.0, 0.0]\n",
"L1 = 0.5ω₀ .* [0.0, sqrt(2), sqrt(2)]\n",
"\n",
"# Solver tolerances \n",
"reltol = 1e-7\n",
"abstol = 1e-7"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×50 Matrix{Float64}:\n",
" -0.0385817 -0.102203 0.0130063 … -0.801308 -0.804544 -0.777151\n",
" 0.0229773 0.0624114 0.134454 0.593017 0.591397 0.625861\n",
" -0.998991 -0.992804 -0.990834 -0.0789676 -0.0543948 0.065833"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"function true_rotation!(du, u, p, t)\n",
" if t < τ₀\n",
" L = p[1]\n",
" else \n",
" L = p[2]\n",
" end\n",
" du .= cross(L, u)\n",
"end\n",
"\n",
"prob = ODEProblem(true_rotation!, [0.0, 0.0, -1.0], tspan, [L0, L1])\n",
"true_sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol, saveat=times_samples)\n",
"\n",
"# Add Fisher noise to true solution \n",
"X_noiseless = Array(true_sol)\n",
"X_true = mapslices(x -> rand(sampler(VonMisesFisher(x/norm(x), κ)), 1), X_noiseless, dims=1)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Let's make a plot of this using `PyCall` to call `cartopy` and `matplotlib`. "
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"3×50 Matrix{Float64}:\n",
" -0.0385817 -0.102203 0.0130063 … -0.801308 -0.804544 -0.777151\n",
" 0.0229773 0.0624114 0.134454 0.593017 0.591397 0.625861\n",
" -0.998991 -0.992804 -0.990834 -0.0789676 -0.0543948 0.065833"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_true"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"2×50 Matrix{Float64}:\n",
" 149.224 148.589 84.4747 104.938 … 143.496 143.681 141.155\n",
" -87.4262 -83.1222 -82.2367 -78.8478 -4.52923 -3.11813 3.77468"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"X_true_sph = cart2sph(X_true, radians=false)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Python plots"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [],
"source": [
"using PyPlot, PyCall\n",
"\n",
"mpl_colors = pyimport(\"matplotlib.colors\")\n",
"mpl_colormap = pyimport(\"matplotlib.cm\")\n",
"sns = pyimport(\"seaborn\")\n",
"ccrs = pyimport(\"cartopy.crs\")\n",
"feature = pyimport(\"cartopy.feature\")\n",
"\n",
"plt.figure(figsize=(10,10))\n",
"ax = plt.axes(projection=ccrs.Orthographic(central_latitude=-20, central_longitude=150))\n",
"\n",
"# ax.coastlines()\n",
"ax.gridlines()\n",
"ax.set_global()\n",
"\n",
"cmap = mpl_colormap.get_cmap(\"viridis\")\n",
"\n",
"sns.scatterplot(ax=ax, x = X_true_sph[:,1], y=X_true_sph[:,2], \n",
" # hue = df_data['time'], s=50,\n",
" # palette=\"viridis\",\n",
" transform = ccrs.PlateCarree());\n",
"\n",
"plt.savefig(\"testing.pdf\", format=\"pdf\")\n",
"# plt.show()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Julia 1.9.4",
"language": "julia",
"name": "julia-1.9"
},
"language_info": {
"file_extension": ".jl",
"mimetype": "application/julia",
"name": "julia",
"version": "1.9.4"
},
"orig_nbformat": 4
},
"nbformat": 4,
"nbformat_minor": 2
}
56 changes: 53 additions & 3 deletions examples/2rotations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@ using Optimization, OptimizationOptimisers, OptimizationOptimJL

using SphereFit

##############################################################
############### Simulation of Simple Example ################
##############################################################

# Random seed
using Random
rng = Random.default_rng()
Expand Down Expand Up @@ -52,8 +56,54 @@ true_sol = solve(prob, Tsit5(), reltol=reltol, abstol=abstol, saveat=times_samp
X_noiseless = Array(true_sol)
X_true = mapslices(x -> rand(sampler(VonMisesFisher(x/norm(x), κ)), 1), X_noiseless, dims=1)

### Training example
##############################################################
####################### Training ###########################
##############################################################

data = SphereData(times=times_samples, directions=X_true, kappas=nothing)
params = SphereParameters(tmin=tspan[1], tmax=tspan[2], u0=[0.0, 0.0, -1.0], ωmax=2*ω₀, reltol=reltol, abstol=abstol)
params = SphereParameters(tmin=tspan[1], tmax=tspan[2],
u0=[0.0, 0.0, -1.0], ωmax=2*ω₀, reltol=reltol, abstol=abstol,
niter_ADAM=1000, niter_LBFGS=300)

results = train_sphere(data, params, rng, nothing)


##############################################################
###################### PyCall Plots #########################
##############################################################

using PyPlot, PyCall

X_true_sph = cart2sph(X_true, radians=false)

mpl_colors = pyimport("matplotlib.colors")
mpl_colormap = pyimport("matplotlib.cm")
sns = pyimport("seaborn")
ccrs = pyimport("cartopy.crs")
feature = pyimport("cartopy.feature")

plt.figure(figsize=(10,10))
ax = plt.axes(projection=ccrs.Orthographic(central_latitude=-20, central_longitude=150))

# ax.coastlines()
ax.gridlines()
ax.set_global()

cmap = mpl_colormap.get_cmap("viridis")
# norm = mpl_colors.Normalize(results.fit_times[1], results.fit_times[end])

sns.scatterplot(ax=ax, x = X_true_sph[1,:], y=X_true_sph[2, :],
hue = times_samples, s=50,
palette="viridis",
transform = ccrs.PlateCarree());

X_fit_sph = cart2sph(results.fit_directions, radians=false)

for i in 1:(length(results.fit_times)-1)
plt.plot([X_fit_sph[1,i], X_fit_sph[1,i+1]],
[X_fit_sph[2,i], X_fit_sph[2,i+1]],
linewidth=2, color="black",#cmap(norm(results.fit_times[i])),
transform = ccrs.Geodetic())
end

θ_trained, U, st = train_sphere(data, params, rng, nothing)
plt.savefig("examples/plot.pdf", format="pdf")
3 changes: 2 additions & 1 deletion src/SphereFit.jl
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,9 @@ using SciMLSensitivity
using Optimization, OptimizationOptimisers, OptimizationOptimJL
using ComponentArrays: ComponentVector

export SphereParameters, SphereData
export SphereParameters, SphereData, cart2sph
export train_sphere
export Results

include("utils.jl")
include("types.jl")
Expand Down
15 changes: 12 additions & 3 deletions src/train.jl
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,9 @@ function train_sphere(data::AbstractData,
# Empirical error
l_ = mean(abs2, u_ .- data.directions)
return l_
# TO DO: add regularization
# ...
# ...
end

losses = Float64[]
Expand All @@ -67,16 +70,22 @@ function train_sphere(data::AbstractData,
optf = Optimization.OptimizationFunction((x, θ) -> loss(x), adtype)
optprob = Optimization.OptimizationProblem(optf, ComponentVector{Float64}(θ))

res1 = Optimization.solve(optprob, ADAM(0.001), callback=callback, maxiters=1000)
res1 = Optimization.solve(optprob, ADAM(0.001), callback=callback, maxiters=params.niter_ADAM)
println("Training loss after $(length(losses)) iterations: $(losses[end])")

optprob2 = Optimization.OptimizationProblem(optf, res1.u)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=300)
res2 = Optimization.solve(optprob2, Optim.LBFGS(), callback=callback, maxiters=params.niter_LBFGS)
println("Final training loss after $(length(losses)) iterations: $(losses[end])")

# Optimized NN parameters
θ_trained = res2.u

return θ_trained, U, st
# Final Fit
fit_times = collect(params.tmin:0.1:params.tmax)
fit_directions = predict(θ_trained, T=fit_times)

return Results(θ_trained=θ_trained, U=U, st=st,
fit_times=fit_times, fit_directions=fit_directions)
end


Expand Down
Loading

0 comments on commit b63f917

Please sign in to comment.