Skip to content

Commit

Permalink
Merge pull request #10 from ThibeauWouters/haukekoehn-main
Browse files Browse the repository at this point in the history
Haukekoehn main
  • Loading branch information
ThibeauWouters authored Oct 16, 2024
2 parents e52fb84 + 0c8b864 commit 3a80e3a
Show file tree
Hide file tree
Showing 32 changed files with 949 additions and 173 deletions.
9 changes: 9 additions & 0 deletions examples/data/GRB170817A.dat
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@
2018-03-22T09:48:53.000 radio-6GHz 19.922441241029667 0.27245838304876924
2018-05-17T03:20:05.000 radio-6GHz 20.009168497662955 0.2229989391775272
2018-06-02T01:53:41.000 radio-6GHz 20.039740523765573 0.23031103721382334
2017-12-06T00:27:17.000 bessellv 26.296373347364188 0.195421332140389060
2018-01-01T13:39:17.000 bessellv 26.589805035001397 0.249368462576732230
2018-01-29T17:15:17.000 bessellv 26.502457030827735 0.199120032069081440
2018-02-05T17:44:05.000 bessellv 26.576389632505034 0.2301226302973234
2018-03-14T15:05:41.000 bessellv 26.615365470633698 0.29252054534460403
2018-03-23T21:34:29.000 bessellv 26.901845323758742 0.359604853459607130
2018-06-10T07:53:41.000 bessellv 27.290652917195565 0.418413346202412440
2018-07-11T17:58:29.000 bessellv 27.57132597132611 0.42869202946981405
2018-08-14T20:22:29.000 bessellv 27.82141644869465 0.3285331705557674
2017-08-26T17:29:41.000 X-ray-1keV 32.88596422178184 0.6732995583434384
2017-09-02T04:03:17.000 X-ray-1keV 31.836985316766835 0.32594225959167356
2017-12-05T01:10:29.000 X-ray-1keV 30.645685060022323 0.13242989652457504
Expand Down
203 changes: 203 additions & 0 deletions examples/injection_tophat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,203 @@
"""Test run on GRB170817A data."""

import os
import jax
print(f"GPU found? {jax.devices()}")
import jax.numpy as jnp
jax.config.update("jax_enable_x64", True)
import numpy as np
import matplotlib.pyplot as plt
import corner

from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
from fiesta.inference.injection import InjectionRecoveryAfterglowpy
from fiesta.inference.likelihood import EMLikelihood
from fiesta.inference.prior import Uniform, Composite
from fiesta.inference.fiesta import Fiesta
from fiesta.utils import load_event_data

import time
start_time = time.time()

################
### Preamble ###
################

jax.config.update("jax_enable_x64", True)

params = {"axes.grid": True,
"text.usetex" : True,
"font.family" : "serif",
"ytick.color" : "black",
"xtick.color" : "black",
"axes.labelcolor" : "black",
"axes.edgecolor" : "black",
"font.serif" : ["Computer Modern Serif"],
"xtick.labelsize": 16,
"ytick.labelsize": 16,
"axes.labelsize": 16,
"legend.fontsize": 16,
"legend.title_fontsize": 16,
"figure.titlesize": 16}

plt.rcParams.update(params)

default_corner_kwargs = dict(bins=40,
smooth=1.,
label_kwargs=dict(fontsize=16),
title_kwargs=dict(fontsize=16),
color="blue",
# quantiles=[],
# levels=[0.9],
plot_density=True,
plot_datapoints=False,
fill_contours=True,
max_n_ticks=4,
min_n_ticks=3,
save=False,
truth_color="red")

#############
### SETUP ###
#############



##############
### MODEL ###
##############

name = "tophat"
model_dir = f"../trained_models/afterglowpy/{name}/"
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]

model = AfterglowpyLightcurvemodel(name,
model_dir,
filters = FILTERS)


###################
### INJECT ###
### AFTERGLOWPY ###
###################


injection_dict = {"inclination_EM": 1.3, "log10_E0": 52, "thetaCore": 0.2, "p": 2.5, "log10_n0": -1., "log10_epsilon_e": -1., "log10_epsilon_B": -4., "luminosity_distance": 40.0}
injection = InjectionRecoveryAfterglowpy(injection_dict, filters = FILTERS, N_datapoints = 50, error_budget = 0.5, tmin = 8, tmax = 800)
injection.create_injection()
data = injection.data


#############################
### PRIORS AND LIKELIHOOD ###
#############################

inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
log10_E0 = Uniform(xmin=46.0, xmax=55.0, naming=['log10_E0'])
thetaCore = Uniform(xmin=0.01, xmax=np.pi/10, naming=['thetaCore'])
log10_n0 = Uniform(xmin=-7.0, xmax=1.0, naming=['log10_n0'])
p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
log10_epsilon_e = Uniform(xmin=-5.0, xmax=0.0, naming=['log10_epsilon_e'])
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])

# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])

prior_list = [inclination_EM,
log10_E0,
thetaCore,
log10_n0,
p,
log10_epsilon_e,
log10_epsilon_B
# luminosity_distance
]

prior = Composite(prior_list)

detection_limit = None
likelihood = EMLikelihood(model,
data,
FILTERS,
tmax = 500.0,
trigger_time=0,
detection_limit = detection_limit,
fixed_params={"luminosity_distance": 40.0}
)

##############
### FIESTA ###
##############

mass_matrix = jnp.eye(prior.n_dim)
eps = 5e-3
local_sampler_arg = {"step_size": mass_matrix * eps}

# Save for postprocessing
outdir = f"./injection_tophat/"
if not os.path.exists(outdir):
os.makedirs(outdir)

fiesta = Fiesta(likelihood,
prior,
n_chains = 1_000,
n_loop_training = 7,
n_loop_production = 3,
num_layers = 4,
hidden_size = [64, 64],
n_epochs = 20,
n_local_steps = 50,
n_global_steps = 200,
local_sampler_arg=local_sampler_arg,
outdir = outdir)

fiesta.sample(jax.random.PRNGKey(42))

fiesta.print_summary()

name = outdir + f'results_training.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, log_prob=log_prob, local_accs=local_accs,
global_accs=global_accs, loss_vals=loss_vals)

# - production phase
name = outdir + f'results_production.npz'
print(f"Saving samples to {name}")
state = fiesta.Sampler.get_sampler_state(training=False)
chains, log_prob, local_accs, global_accs = state["chains"], state[
"log_prob"], state["local_accs"], state["global_accs"]
local_accs = jnp.mean(local_accs, axis=0)
global_accs = jnp.mean(global_accs, axis=0)
np.savez(name, chains=chains, log_prob=log_prob,
local_accs=local_accs, global_accs=global_accs)

################
### PLOTTING ###
################
# Fixed names: do not include them in the plotting, as will break corner
parameter_names = prior.naming
truths = [injection_dict[key] for key in parameter_names]

n_chains, n_steps, n_dim = np.shape(chains)
samples = np.reshape(chains, (n_chains * n_steps, n_dim))
samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption

corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs)
plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight')
plt.close()

end_time = time.time()
runtime_seconds = end_time - start_time
number_of_minutes = runtime_seconds // 60
number_of_seconds = np.round(runtime_seconds % 60, 2)
print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s")

print("Plotting lightcurves")
fiesta.plot_lightcurves()
print("Plotting lightcurves . . . done")

print("DONE")
Binary file added examples/injection_tophat/corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/injection_tophat/lightcurves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added examples/injection_tophat/results_production.npz
Binary file not shown.
Binary file added examples/injection_tophat/results_training.npz
Binary file not shown.
Binary file modified examples/outdir_GRB170817_tophat/corner.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/outdir_GRB170817_tophat/lightcurves.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified examples/outdir_GRB170817_tophat/results_production.npz
Binary file not shown.
Binary file modified examples/outdir_GRB170817_tophat/results_training.npz
Binary file not shown.
15 changes: 8 additions & 7 deletions examples/run_GRB170817_tophat.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,29 +68,30 @@

name = "tophat"
model_dir = f"../trained_models/afterglowpy/{name}/"
FILTERS = ["radio-3GHz"] # TODO: add more filters here and tackle the full problem
FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]

model = AfterglowpyLightcurvemodel(name,
model_dir,
filters = FILTERS)


############
### DATA ###
############

data = load_event_data("./data/GRB170817A_toy.dat") # only one filter of the GRB170817A data
data = load_event_data("./data/GRB170817A.dat") # only one filter of the GRB170817A data

#############################
### PRIORS AND LIKELIHOOD ###
#############################

inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0'])
log10_E0 = Uniform(xmin=46.0, xmax=55.0, naming=['log10_E0'])
thetaCore = Uniform(xmin=0.01, xmax=np.pi/10, naming=['thetaCore'])
log10_n0 = Uniform(xmin=-6.0, xmax=3.0, naming=['log10_n0'])
log10_n0 = Uniform(xmin=-7.0, xmax=1.0, naming=['log10_n0'])
p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
log10_epsilon_e = Uniform(xmin=-5.0, xmax=0.0, naming=['log10_epsilon_e'])
log10_epsilon_B = Uniform(xmin=-10.0, xmax=0.0, naming=['log10_epsilon_B'])
log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])

# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])

Expand All @@ -110,7 +111,7 @@
likelihood = EMLikelihood(model,
data,
FILTERS,
tmax = 300.0,
tmax = 500.0,
trigger_time=trigger_time,
detection_limit = detection_limit,
fixed_params={"luminosity_distance": 40.0}
Expand Down Expand Up @@ -142,7 +143,7 @@
local_sampler_arg=local_sampler_arg,
outdir = outdir)

fiesta.sample(jax.random.PRNGKey(0))
fiesta.sample(jax.random.PRNGKey(42))

fiesta.print_summary()

Expand Down
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ install_requires =
astropy
sncosmo
flowMC
joblib

python_requires = >=3.10

Expand Down
41 changes: 39 additions & 2 deletions src/fiesta.egg-info/PKG-INFO
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,43 @@ Description-Content-Type: text/markdown
License-File: LICENSE
Requires-Dist: jax>=0.4.24
Requires-Dist: jaxlib>=0.4.24
Requires-Dist: numpy<2.0.0
Requires-Dist: pandas<2.0.0
Requires-Dist: jaxtyping
Requires-Dist: beartype
Requires-Dist: tqdm
Requires-Dist: scipy<=1.14.0
Requires-Dist: ml_collections
Requires-Dist: astropy
Requires-Dist: sncosmo
Requires-Dist: flowMC
Requires-Dist: joblib

# fiesta
fiesta: Fast inference of electromagnetic signals with JAX
# fiesta 🎉

`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x

![fiesta logo](docs/fiesta_logo.jpeg)

**NOTE:** `fiesta` is currently under development -- stay tuned!

## Installation

pip installation is currently work in progress. Install from source by cloning this Github repository and running
```
pip install -e .
```

NOTE: This is using an older and custom version of `flowMC`. Install by cloning the `flowMC` version at [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`).

## Training surrogate models

To train your own surrogate models, have a look at some of the example scripts in the repository for inspiration, under `trained_models`

- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model.
- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure.

## Examples

- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model.
- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon.
11 changes: 11 additions & 0 deletions src/fiesta.egg-info/requires.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,13 @@
jax>=0.4.24
jaxlib>=0.4.24
numpy<2.0.0
pandas<2.0.0
jaxtyping
beartype
tqdm
scipy<=1.14.0
ml_collections
astropy
sncosmo
flowMC
joblib
Loading

0 comments on commit 3a80e3a

Please sign in to comment.