diff --git a/examples/data/GRB170817A.dat b/examples/data/GRB170817A.dat
index ed55e3f..f166ef0 100755
--- a/examples/data/GRB170817A.dat
+++ b/examples/data/GRB170817A.dat
@@ -28,6 +28,15 @@
 2018-03-22T09:48:53.000 radio-6GHz 19.922441241029667 0.27245838304876924
 2018-05-17T03:20:05.000 radio-6GHz 20.009168497662955 0.2229989391775272
 2018-06-02T01:53:41.000 radio-6GHz 20.039740523765573 0.23031103721382334
+2017-12-06T00:27:17.000 bessellv 26.296373347364188 0.195421332140389060
+2018-01-01T13:39:17.000 bessellv 26.589805035001397 0.249368462576732230
+2018-01-29T17:15:17.000 bessellv 26.502457030827735 0.199120032069081440
+2018-02-05T17:44:05.000 bessellv 26.576389632505034 0.2301226302973234
+2018-03-14T15:05:41.000 bessellv 26.615365470633698 0.29252054534460403
+2018-03-23T21:34:29.000 bessellv 26.901845323758742 0.359604853459607130
+2018-06-10T07:53:41.000 bessellv 27.290652917195565 0.418413346202412440
+2018-07-11T17:58:29.000 bessellv 27.57132597132611 0.42869202946981405
+2018-08-14T20:22:29.000 bessellv 27.82141644869465 0.3285331705557674
 2017-08-26T17:29:41.000 X-ray-1keV 32.88596422178184 0.6732995583434384
 2017-09-02T04:03:17.000 X-ray-1keV 31.836985316766835 0.32594225959167356
 2017-12-05T01:10:29.000 X-ray-1keV 30.645685060022323 0.13242989652457504
diff --git a/examples/injection_tophat.py b/examples/injection_tophat.py
new file mode 100644
index 0000000..fd17ed0
--- /dev/null
+++ b/examples/injection_tophat.py
@@ -0,0 +1,203 @@
+"""Injection-recovery run with the afterglowpy tophat surrogate."""
+
+import os
+import jax
+print(f"GPU found? {jax.devices()}")
+import jax.numpy as jnp
+jax.config.update("jax_enable_x64", True)
+import numpy as np
+import matplotlib.pyplot as plt
+import corner
+
+from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
+from fiesta.inference.injection import InjectionRecoveryAfterglowpy
+from fiesta.inference.likelihood import EMLikelihood
+from fiesta.inference.prior import Uniform, Composite
+from fiesta.inference.fiesta import Fiesta
+
+import time
+start_time = time.time()
+
+################
+### Preamble ###
+################
+
+params = {"axes.grid": True,
+          "text.usetex": True,
+          "font.family": "serif",
+          "ytick.color": "black",
+          "xtick.color": "black",
+          "axes.labelcolor": "black",
+          "axes.edgecolor": "black",
+          "font.serif": ["Computer Modern Serif"],
+          "xtick.labelsize": 16,
+          "ytick.labelsize": 16,
+          "axes.labelsize": 16,
+          "legend.fontsize": 16,
+          "legend.title_fontsize": 16,
+          "figure.titlesize": 16}
+
+plt.rcParams.update(params)
+
+default_corner_kwargs = dict(bins=40,
+                             smooth=1.,
+                             label_kwargs=dict(fontsize=16),
+                             title_kwargs=dict(fontsize=16),
+                             color="blue",
+                             # quantiles=[],
+                             # levels=[0.9],
+                             plot_density=True,
+                             plot_datapoints=False,
+                             fill_contours=True,
+                             max_n_ticks=4,
+                             min_n_ticks=3,
+                             save=False,
+                             truth_color="red")
+
+##############
+### MODEL ###
+##############
+
+name = "tophat"
+model_dir = f"../trained_models/afterglowpy/{name}/"
+FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]
+
+model = AfterglowpyLightcurvemodel(name,
+                                   model_dir,
+                                   filters = FILTERS)
+
+###################
+### INJECT ###
+### AFTERGLOWPY ###
+###################
+
+injection_dict = {"inclination_EM": 1.3, "log10_E0": 52, "thetaCore": 0.2, "p": 2.5, "log10_n0": -1., "log10_epsilon_e": -1., "log10_epsilon_B": -4., "luminosity_distance": 40.0}
+injection = InjectionRecoveryAfterglowpy(injection_dict, filters = FILTERS, N_datapoints = 50, error_budget = 0.5, tmin = 8, tmax = 800)
+injection.create_injection()
+data = injection.data
+
+
+#############################
+### PRIORS AND LIKELIHOOD ###
+#############################
+
+inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
+log10_E0 = Uniform(xmin=46.0, xmax=55.0, naming=['log10_E0'])
+thetaCore = Uniform(xmin=0.01, xmax=np.pi/10, naming=['thetaCore'])
+log10_n0 = Uniform(xmin=-7.0, xmax=1.0, naming=['log10_n0'])
+p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
+log10_epsilon_e = Uniform(xmin=-5.0, xmax=0.0, naming=['log10_epsilon_e'])
+log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])
+
+# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])
+
+prior_list = [inclination_EM,
+              log10_E0,
+              thetaCore,
+              log10_n0,
+              p,
+              log10_epsilon_e,
+              log10_epsilon_B
+              # luminosity_distance
+]
+
+prior = Composite(prior_list)
+
+detection_limit = None
+likelihood = EMLikelihood(model,
+                          data,
+                          FILTERS,
+                          tmax = 500.0,
+                          trigger_time=0,
+                          detection_limit = detection_limit,
+                          fixed_params={"luminosity_distance": 40.0}
+)
+
+##############
+### FIESTA ###
+##############
+
+mass_matrix = jnp.eye(prior.n_dim)
+eps = 5e-3
+local_sampler_arg = {"step_size": mass_matrix * eps}
+
+# Save for postprocessing
+outdir = "./injection_tophat/"
+if not os.path.exists(outdir):
+    os.makedirs(outdir)
+
+fiesta = Fiesta(likelihood,
+                prior,
+                n_chains = 1_000,
+                n_loop_training = 7,
+                n_loop_production = 3,
+                num_layers = 4,
+                hidden_size = [64, 64],
+                n_epochs = 20,
+                n_local_steps = 50,
+                n_global_steps = 200,
+                local_sampler_arg=local_sampler_arg,
+                outdir = outdir)
+
+fiesta.sample(jax.random.PRNGKey(42))
+
+fiesta.print_summary()
+
+name = outdir + 'results_training.npz'
+print(f"Saving samples to {name}")
+state = fiesta.Sampler.get_sampler_state(training=True)
+chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state["log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"]
+local_accs = jnp.mean(local_accs, axis=0)
+global_accs = jnp.mean(global_accs, axis=0)
+np.savez(name, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs, loss_vals=loss_vals)
+
+# - production phase
+name = outdir + 'results_production.npz'
+print(f"Saving samples to {name}")
+state = fiesta.Sampler.get_sampler_state(training=False)
+chains, log_prob, local_accs, global_accs = state["chains"], state["log_prob"], state["local_accs"], state["global_accs"]
+local_accs = jnp.mean(local_accs, axis=0)
+global_accs = jnp.mean(global_accs, axis=0)
+np.savez(name, chains=chains, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs)
+
+################
+### PLOTTING ###
+################
+
+# Fixed parameters are excluded from the plot, since they would break corner
+parameter_names = prior.naming
+truths = [injection_dict[key] for key in parameter_names]
+
+n_chains, n_steps, n_dim = np.shape(chains)
+samples = np.reshape(chains, (n_chains * n_steps, n_dim))
+samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption
+
+corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs)
+plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight')
+plt.close()
+
+end_time = time.time()
+runtime_seconds = end_time - start_time
+number_of_minutes = int(runtime_seconds // 60)
+number_of_seconds = np.round(runtime_seconds % 60, 2)
+print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s")
+
+print("Plotting lightcurves")
+fiesta.plot_lightcurves()
+print("Plotting lightcurves . . . done")
+
+print("DONE")
\ No newline at end of file
diff --git a/examples/injection_tophat/corner.png b/examples/injection_tophat/corner.png
new file mode 100644
index 0000000..9b4aa30
Binary files /dev/null and b/examples/injection_tophat/corner.png differ
diff --git a/examples/injection_tophat/lightcurves.png b/examples/injection_tophat/lightcurves.png
new file mode 100644
index 0000000..3df8a4f
Binary files /dev/null and b/examples/injection_tophat/lightcurves.png differ
diff --git a/examples/injection_tophat/results_production.npz b/examples/injection_tophat/results_production.npz
new file mode 100644
index 0000000..73ece35
Binary files /dev/null and b/examples/injection_tophat/results_production.npz differ
diff --git a/examples/injection_tophat/results_training.npz b/examples/injection_tophat/results_training.npz
new file mode 100644
index 0000000..8557168
Binary files /dev/null and b/examples/injection_tophat/results_training.npz differ
diff --git a/examples/outdir_GRB170817_tophat/corner.png b/examples/outdir_GRB170817_tophat/corner.png
index 799d2fd..86a071a 100644
Binary files a/examples/outdir_GRB170817_tophat/corner.png and b/examples/outdir_GRB170817_tophat/corner.png differ
diff --git a/examples/outdir_GRB170817_tophat/lightcurves.png b/examples/outdir_GRB170817_tophat/lightcurves.png
index 58372b3..602f510 100644
Binary files a/examples/outdir_GRB170817_tophat/lightcurves.png and b/examples/outdir_GRB170817_tophat/lightcurves.png differ
diff --git a/examples/outdir_GRB170817_tophat/results_production.npz b/examples/outdir_GRB170817_tophat/results_production.npz
index 1808c3b..e224859 100644
Binary files a/examples/outdir_GRB170817_tophat/results_production.npz and b/examples/outdir_GRB170817_tophat/results_production.npz differ
diff --git a/examples/outdir_GRB170817_tophat/results_training.npz b/examples/outdir_GRB170817_tophat/results_training.npz
index 735016c..6b629ee 100644
Binary files a/examples/outdir_GRB170817_tophat/results_training.npz and b/examples/outdir_GRB170817_tophat/results_training.npz differ
diff --git a/examples/run_GRB170817_tophat.py b/examples/run_GRB170817_tophat.py
index f7d0bcc..4da6301 100644
--- a/examples/run_GRB170817_tophat.py
+++ b/examples/run_GRB170817_tophat.py
@@ -68,29 +68,30 @@
 name = "tophat"
 model_dir = f"../trained_models/afterglowpy/{name}/"
-FILTERS = ["radio-3GHz"] # TODO: add more filters here and tackle the full problem
+FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"]
 
 model = AfterglowpyLightcurvemodel(name,
                                    model_dir,
                                    filters = FILTERS)
+
 ############
 ### DATA ###
 ############
 
-data = load_event_data("./data/GRB170817A_toy.dat") # only one filter of the GRB170817A data
+data = load_event_data("./data/GRB170817A.dat") # GRB170817A data in all four filters
 
 #############################
 ### PRIORS AND LIKELIHOOD ###
 #############################
 
 inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM'])
-log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0'])
+log10_E0 = Uniform(xmin=46.0, xmax=55.0, naming=['log10_E0'])
 thetaCore = Uniform(xmin=0.01, xmax=np.pi/10, naming=['thetaCore'])
-log10_n0 = Uniform(xmin=-6.0, xmax=3.0, naming=['log10_n0'])
+log10_n0 = Uniform(xmin=-7.0, xmax=1.0, naming=['log10_n0'])
 p = Uniform(xmin=2.01, xmax=3.0, naming=['p'])
 log10_epsilon_e = Uniform(xmin=-5.0, xmax=0.0, naming=['log10_epsilon_e'])
-log10_epsilon_B = Uniform(xmin=-10.0, xmax=0.0, naming=['log10_epsilon_B'])
+log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B'])
 
 # luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance'])
@@ -110,7 +111,7 @@
 likelihood = EMLikelihood(model,
                           data,
                           FILTERS,
-                          tmax = 300.0,
+                          tmax = 500.0,
                           trigger_time=trigger_time,
                           detection_limit = detection_limit,
                           fixed_params={"luminosity_distance": 40.0}
@@ -142,7 +143,7 @@
                 local_sampler_arg=local_sampler_arg,
                 outdir = outdir)
 
-fiesta.sample(jax.random.PRNGKey(0))
+fiesta.sample(jax.random.PRNGKey(42))
 
 fiesta.print_summary()
diff --git a/setup.cfg b/setup.cfg
index 3951e3a..2eb0bea 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -27,6 +27,7 @@ install_requires =
     astropy
     sncosmo
    flowMC
+    joblib
 
 python_requires = >=3.10
diff --git a/src/fiesta.egg-info/PKG-INFO b/src/fiesta.egg-info/PKG-INFO
index b963d60..05d5959 100644
--- a/src/fiesta.egg-info/PKG-INFO
+++ b/src/fiesta.egg-info/PKG-INFO
@@ -12,6 +12,43 @@
 Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: jax>=0.4.24
 Requires-Dist: jaxlib>=0.4.24
+Requires-Dist: numpy<2.0.0
+Requires-Dist: pandas<2.0.0
+Requires-Dist: jaxtyping
+Requires-Dist: beartype
+Requires-Dist: tqdm
+Requires-Dist: scipy<=1.14.0
+Requires-Dist: ml_collections
+Requires-Dist: astropy
+Requires-Dist: sncosmo
+Requires-Dist: flowMC
+Requires-Dist: joblib
 
-# fiesta
-fiesta: Fast inference of electromagnetic signals with JAX
+# fiesta 🎉
+
+`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x
+
+![fiesta logo](docs/fiesta_logo.jpeg)
+
+**NOTE:** `fiesta` is currently under development -- stay tuned!
+
+## Installation
+
+A pip release is work in progress. For now, install from source by cloning this GitHub repository and running
+```
+pip install -e .
+```
+
+**NOTE:** `fiesta` currently relies on an older, customized version of `flowMC`. Install it from [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`), e.g. via `git clone -b fiesta https://github.com/ThibeauWouters/flowMC.git && pip install -e ./flowMC`.
+
+## Training surrogate models
+
+To train your own surrogate models, have a look at the example scripts under `trained_models`:
+
+- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model.
+- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure.
+
+## Examples
+
+- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model.
+- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet, using the radio-3GHz, radio-6GHz, X-ray-1keV and bessellv filters.
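+
+## Quick start
+
+A minimal sketch of loading a trained surrogate and evaluating a light curve (the parameter values below are illustrative; see the example scripts above for full inference runs):
+
+```python
+from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
+
+# load the trained tophat surrogate shipped under trained_models/
+model = AfterglowpyLightcurvemodel(name="tophat",
+                                   directory="trained_models/afterglowpy/tophat/",
+                                   filters=["radio-3GHz"])
+
+# predict returns a dict of magnitudes per filter, evaluated on model.times
+params = {"inclination_EM": 0.1,
+          "log10_E0": 52.0,
+          "thetaCore": 0.05,
+          "log10_n0": -2.0,
+          "p": 2.3,
+          "log10_epsilon_e": -1.0,
+          "log10_epsilon_B": -4.0}
+mag = model.predict(params)["radio-3GHz"]
+```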
diff --git a/src/fiesta.egg-info/requires.txt b/src/fiesta.egg-info/requires.txt
index 683585e..283afba 100644
--- a/src/fiesta.egg-info/requires.txt
+++ b/src/fiesta.egg-info/requires.txt
@@ -1,2 +1,13 @@
 jax>=0.4.24
 jaxlib>=0.4.24
+numpy<2.0.0
+pandas<2.0.0
+jaxtyping
+beartype
+tqdm
+scipy<=1.14.0
+ml_collections
+astropy
+sncosmo
+flowMC
+joblib
diff --git a/src/fiesta/inference/injection.py b/src/fiesta/inference/injection.py
index 235c37f..0977dc1 100644
--- a/src/fiesta/inference/injection.py
+++ b/src/fiesta/inference/injection.py
@@ -10,6 +10,11 @@
 from fiesta.inference.lightcurve_model import LightcurveModel
 from fiesta.utils import mag_app_from_mag_abs
+from fiesta.utils import Filter
+from fiesta.constants import days_to_seconds, c
+from fiesta import conversions
+
+import afterglowpy as grb
 
 # TODO: get the parser going
 def get_parser(**kwargs):
@@ -36,7 +41,6 @@ def __init__(self,
         self.model = model
         
         # Ensure given filters are also in the trained model
-        
         if filters is None:
             filters = model.filters
         else:
@@ -44,7 +48,7 @@
             if filt not in model.filters:
                 print(f"Filter {filt} not in model filters. Removing from list")
                 filters.remove(filt)
-        
+        print(f"Creating injection with filters: {filters}")
         self.filters = filters
         self.injection_dict = injection_dict
@@ -82,4 +86,95 @@ def create_timegrid(self):
         """Create a time grid for the injection."""
         
         # TODO: create more interesting grids than uniform and same accross all filters?
-        return np.linspace(self.tmin, self.tmax, self.N_datapoints)
\ No newline at end of file
+        return np.linspace(self.tmin, self.tmax, self.N_datapoints)
+
+
+
+class InjectionRecoveryAfterglowpy:
+    
+    def __init__(self,
+                 injection_dict: dict[str, Float],
+                 filters: list[str],
+                 jet_type = -1,
+                 tmin: Float = 0.1,
+                 tmax: Float = 1000.0,
+                 N_datapoints: int = 10,
+                 error_budget: Float = 1.0,
+                 randomize_nondetections: bool = False,
+                 randomize_nondetections_fraction: Float = 0.2):
+        
+        self.jet_type = jet_type
+        # afterglowpy is called directly here (no surrogate model), so the filters must be given explicitly
+        if filters is None:
+            raise ValueError("filters must be provided for an afterglowpy injection")
+        
+        self.filters = [Filter(filt) for filt in filters]
+        print(f"Creating injection with filters: {filters}")
+        self.injection_dict = injection_dict
+        self.tmin = tmin
+        self.tmax = tmax
+        self.N_datapoints = N_datapoints
+        self.error_budget = error_budget
+        self.randomize_nondetections = randomize_nondetections
+        self.randomize_nondetections_fraction = randomize_nondetections_fraction
+    
+    def create_injection(self):
+        """Create a synthetic injection from the given model and parameters."""
+        
+        points = np.random.multinomial(self.N_datapoints, [1/len(self.filters)]*len(self.filters)) # random number of datapoints in each filter
+        self.data = {}
+        
+        for npoints, filt in zip(points, self.filters):
+            self.injection_dict["nu"] = filt.nu
+            times = self.create_timegrid(npoints)
+            mJys = self._call_afterglowpy(times*days_to_seconds, self.injection_dict)
+            magnitudes = conversions.mJys_to_mag_np(mJys)
+            mag_err = self.error_budget * np.ones_like(times)
+            self.data[filt.name] = np.array([times, magnitudes, mag_err]).T
+    
+    def _call_afterglowpy(self,
+                          times_afterglowpy: Array,
+                          params_dict: dict[str, Float]) -> Float[Array, "n_times"]:
+        """
+        Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that params_dict should contain all the parameters that the model requires, as well as the nu value.
+        The output will be a set of mJys.
+
+        Args:
+            times_afterglowpy (Array): Times in seconds at which afterglowpy evaluates the flux density.
+            params_dict (dict[str, Float]): Dictionary with the model parameters and the frequency nu.
+
+        Returns:
+            Float[Array, "n_times"]: The flux density in mJys at the given times.
+        """
+
+        # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z
+        Z = {}
+
+        Z["jetType"]  = params_dict.get("jetType", self.jet_type)
+        Z["specType"] = params_dict.get("specType", 0)
+        Z["z"] = params_dict.get("z", 0.0)
+        Z["xi_N"] = params_dict.get("xi_N", 1.0)
+
+        Z["E0"] = 10 ** params_dict["log10_E0"]
+        Z["thetaCore"] = params_dict["thetaCore"]
+        Z["n0"] = 10 ** params_dict["log10_n0"]
+        Z["p"] = params_dict["p"]
+        Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"]
+        Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"]
+        Z["d_L"] = params_dict.get("luminosity_distance", 1e-5)*1e6*3.086e18 # convert Mpc to cm
+        if "inclination_EM" in list(params_dict.keys()):
+            Z["thetaObs"] = params_dict["inclination_EM"]
+        else:
+            Z["thetaObs"] = params_dict["thetaObs"]
+        if self.jet_type == 1 or self.jet_type == 4:
+            Z["b"] = params_dict["b"]
+        if "thetaWing" in list(params_dict.keys()):
+            Z["thetaWing"] = params_dict["thetaWing"]
+
+        # Afterglowpy returns flux in mJys
+        mJys = grb.fluxDensity(times_afterglowpy, params_dict["nu"], **Z)
+        return mJys
+
+    def create_timegrid(self, npoints):
+        """Create a time grid for the injection."""
+
+        return np.linspace(self.tmin, self.tmax, npoints)
\ No newline at end of file
diff --git a/src/fiesta/inference/lightcurve_model.py b/src/fiesta/inference/lightcurve_model.py
index 779a527..4dfd7a0 100644
--- a/src/fiesta/inference/lightcurve_model.py
+++ b/src/fiesta/inference/lightcurve_model.py
@@ -9,7 +9,7 @@
 from functools import partial
 from beartype import beartype as typechecker
 from flax.training.train_state import TrainState
-import joblib
+import pickle
 
 import fiesta.train.neuralnets as fiesta_nn
 from fiesta.utils import MinMaxScalerJax, inverse_svd_transform
@@ -140,9 +140,11 @@ def __init__(self,
         self.load_networks()
         
     def load_metadata(self) -> None:
-        self.metadata_filename = os.path.join(self.directory, f"{self.name}.joblib")
+        self.metadata_filename = os.path.join(self.directory, f"{self.name}_metadata.pkl")
         assert os.path.exists(self.metadata_filename), f"Metadata file {self.metadata_filename} not found - check the directory {self.directory}"
-        self.metadata = joblib.load(self.metadata_filename)
+        meta_file = open(self.metadata_filename, "rb")
+        self.metadata = pickle.load(meta_file)
+        meta_file.close()
 
     def load_filters(self, filters: list[str] = None) -> None:
         # Save those filters that were given and that were trained and store here already
@@ -163,18 +165,17 @@ def load_filters(self, filters: list[str] = None) -> None:
         
         print(f"Loaded SurrogateLightcurveModel with filters {filters}")
         
     def load_scalers(self):
-        min_val, max_val = self.metadata["X_scaler_min"], self.metadata["X_scaler_max"]
-        self.X_scaler = MinMaxScalerJax(min_val=min_val, max_val=max_val)
-        
-        min_val, max_val = self.metadata["y_scaler_min"], self.metadata["y_scaler_max"]
-        self.y_scaler = {}
-        for filt in self.filters:
-            self.y_scaler[filt] = MinMaxScalerJax(min_val=min_val[filt], max_val=max_val[filt])
+        self.X_scaler, self.y_scaler = {}, {}
+        for filt in self.filters:
+            self.X_scaler[filt] = MinMaxScalerJax(min_val=self.metadata[filt]["X_scaler_min"], max_val=self.metadata[filt]["X_scaler_max"])
+            self.y_scaler[filt] = self.metadata[filt]["y_scaler"]
+        
     def load_times(self, times: Array = None) -> None:
-        # TODO: check for time range and trained model time range
         if times is None:
             times = jnp.array(self.metadata["times"])
+        if times.min() < self.metadata["times"].min() or times.max() > self.metadata["times"].max():
+            times = jnp.array(self.metadata["times"])
= jnp.array(self.metadata["times"]) self.times = times self.tmin = jnp.min(times) self.tmax = jnp.max(times) @@ -200,7 +201,7 @@ def project_input(self, x: Array) -> dict[str, Array]: Returns: dict[str, Array]: Transformed input array """ - x_tilde = {filter: self.X_scaler.transform(x) for filter in self.filters} + x_tilde = {filter: self.X_scaler[filter].transform(x) for filter in self.filters} return x_tilde def compute_output(self, x: dict[str, Array]) -> dict[str, Array]: @@ -245,8 +246,8 @@ def __init__(self, """ super().__init__(name=name, directory=directory, times=times, filters=filters) - self.VA = self.metadata["VA"] - self.svd_ncoeff = self.metadata["svd_ncoeff"] + self.VA = {filt: self.metadata[filt]["VA"] for filt in filters} + self.svd_ncoeff = {filt: self.metadata[filt]["svd_ncoeff"] for filt in filters} def load_parameter_names(self) -> None: raise NotImplementedError @@ -261,7 +262,7 @@ def project_output(self, y: dict[str, Array]) -> dict[str, Array]: Returns: dict[str, Array]: _description_ """ - output = {filter: inverse_svd_transform(y[filter], self.VA[filter], self.svd_ncoeff) for filter in self.filters} + output = {filter: inverse_svd_transform(y[filter], self.VA[filter], self.svd_ncoeff[filter]) for filter in self.filters} return super().project_output(output) class BullaLightcurveModel(SVDSurrogateLightcurveModel): diff --git a/src/fiesta/train/Benchmarker.py b/src/fiesta/train/Benchmarker.py new file mode 100644 index 0000000..ea77781 --- /dev/null +++ b/src/fiesta/train/Benchmarker.py @@ -0,0 +1,236 @@ +from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +import afterglowpy as grb +from fiesta.constants import days_to_seconds, c +from fiesta import conversions +from fiesta import utils + +from jaxtyping import Array, Float, Int + +import tqdm +import os +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as colors +from matplotlib.cm import ScalarMappable + +from scipy.integrate import trapezoid + +class Benchmarker: + + def __init__(self, + name: str, + model_dir: str, + filters: list[str], + n_test_data: int = 3000, + remake_test_data: bool = False, + metric_name: str = "$\\mathcal{L}_\\inf$", + jet_type = -1 + ) -> None: + + self.name = name + self.model_dir = model_dir + self.load_filters(filters) + self.model = AfterglowpyLightcurvemodel(name = self.name, + directory = self.model_dir, + filters = filters) + self.times = self.model.times + self._times_afterglowpy = self.times * days_to_seconds + self.jet_type = jet_type + + self.parameter_names = self.model.metadata["parameter_names"] + train_X_raw = np.load(self.model_dir+"raw_data_training.npz")["X_raw"] + self.parameter_boundaries = np.array([np.min(train_X_raw, axis = 0), np.max(train_X_raw, axis = 0)]) + + if os.path.exists(self.model_dir+"/raw_data_test.npz") and not remake_test_data: + self.load_test_data() + else: + self.get_test_data(n_test_data) + + self.metric_name = metric_name + mask = np.logical_and(self.times>8, self.times<800) + if metric_name == "$\\mathcal{L}_2$": + self.metric = lambda y: np.sqrt(trapezoid(x= self.times[mask],y=y[mask]**2)) + else: + self.metric = lambda y: np.max(np.abs(y[mask])) + + + + + def __repr__(self) -> str: + return f"Surrogte_Benchmarker(name={self.name}, model_dir={self.model_dir})" + + def load_filters(self, filters: list[str]): + self.filters = [] + for filter in filters: + try: + self.filters.append(utils.Filter(filter)) + except: + raise Exception(f"Filter {filter} not available.") + + def get_test_data(self, 
+        test_X_raw = np.empty((n_test_data, len(self.parameter_names)))
+        test_y_raw = {filter.name: np.empty((n_test_data, len(self.times))) for filter in self.filters}
+        prediction_y_raw = {filter.name: np.empty((n_test_data, len(self.times))) for filter in self.filters}
+
+        print(f"Determining test data for {n_test_data} random points within parameter grid.")
+        for j in tqdm.tqdm(range(n_test_data)):
+            test_X_raw[j] = np.random.uniform(low = self.parameter_boundaries[0], high = self.parameter_boundaries[1])
+            param_dict = {name: x for name, x in zip(self.parameter_names, test_X_raw[j])}
+
+            prediction = self.model.predict(param_dict)
+
+            for filt in self.filters:
+                param_dict["nu"] = filt.nu
+                prediction_y_raw[filt.name][j] = prediction[filt.name]
+                mJys = self._call_afterglowpy(param_dict)
+                test_y_raw[filt.name][j] = conversions.mJys_to_mag_np(mJys)
+
+        self.test_X_raw = test_X_raw
+        self.test_y_raw = test_y_raw
+        self.prediction_y_raw = prediction_y_raw
+        self.n_test_data = n_test_data
+
+        # save the test set for later reuse
+        test_saver = {"test_"+key: test_y_raw[key] for key in test_y_raw.keys()}
+        np.savez(os.path.join(self.model_dir, "raw_data_test.npz"), X = test_X_raw, **test_saver)
+    
+    def load_test_data(self):
+        
+        test_data = np.load(self.model_dir+"/raw_data_test.npz")
+        self.test_X_raw = test_data["X"]
+        self.test_y_raw = {filt.name: test_data["test_"+filt.name] for filt in self.filters}
+        self.n_test_data = len(self.test_X_raw)
+        
+        self.prediction_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters}
+        for j, X in enumerate(self.test_X_raw):
+            param_dict = {name: x for name, x in zip(self.parameter_names, X)}
+            prediction = self.model.predict(param_dict)
+            for filt in self.filters:
+                self.prediction_y_raw[filt.name][j] = prediction[filt.name]
+    
+    def _call_afterglowpy(self,
+                          params_dict: dict[str, Float]) -> Float[Array, "n_times"]:
+        """
+        Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that params_dict should contain all the parameters that the model requires, as well as the nu value.
+        The output will be a set of mJys.
+
+        Args:
+            params_dict (dict[str, Float]): Dictionary with the model parameters and the frequency nu.
+
+        Returns:
+            Float[Array, "n_times"]: The flux density in mJys at the given times.
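+
+        Note:
+            The luminosity distance is fixed at 10 pc (d_L = 3.086e19 cm), so the converted AB magnitudes are absolute magnitudes.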
+ """ + + # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z + Z = {} + + Z["jetType"] = params_dict.get("jetType", self.jet_type) + Z["specType"] = params_dict.get("specType", 0) + Z["z"] = params_dict.get("z", 0.0) + Z["xi_N"] = params_dict.get("xi_N", 1.0) + + Z["E0"] = 10 ** params_dict["log10_E0"] + Z["thetaCore"] = params_dict["thetaCore"] + Z["n0"] = 10 ** params_dict["log10_n0"] + Z["p"] = params_dict["p"] + Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] + Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"] + Z["d_L"] = 3.086e19 # fix at 10 pc, so that AB magnitude equals absolute magnitude + if "inclination_EM" in list(params_dict.keys()): + Z["thetaObs"] = params_dict["inclination_EM"] + else: + Z["thetaObs"] = params_dict["thetaObs"] + if self.jet_type == 1 or self.jet_type == 4: + Z["b"] = params_dict["b"] + if "thetaWing" in list(params_dict.keys()): + Z["thetaWing"] = params_dict["thetaWing"] + + # Afterglowpy returns flux in mJys + mJys = grb.fluxDensity(self._times_afterglowpy, params_dict["nu"], **Z) + return mJys + + def calculate_mismatch(self, filter): + mismatch = np.empty(self.n_test_data) + + for j in range(self.n_test_data): + mismatch[j] = self.metric(self.prediction_y_raw[filter][j] - self.test_y_raw[filter][j]) + + return mismatch + + + + def plot_lightcurves_mismatch(self, + filter: str, + parameter_labels: list[str] = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"] + ): + if self.metric_name == "$\\mathcal{L}_2$": + bins = np.arange(0, 100, 5) + vmin, vmax = 0, 50 + vline = np.sqrt(trapezoid(x = self.times, y = np.ones(len(self.times)))) + else: + bins = np.arange(0, 3, 0.5) + vmin, vmax = 0, 3 + vline = 1. 
+
+        mismatch = self.calculate_mismatch(filter)
+
+        cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")])
+        colored_mismatch = cmap(mismatch/vmax)
+
+        label_dic = {p: label for p, label in zip(self.parameter_names, parameter_labels)}
+
+        fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1)
+        fig.suptitle(f"{filter}: {self.metric_name} norm")
+
+        for j, p in enumerate(self.parameter_names[1:]):
+            for k, pp in enumerate(self.parameter_names[:j+1]):
+                sort = np.argsort(mismatch)
+
+                ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1)
+
+                ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max()))
+                ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max()))
+
+                if k!=0:
+                    ax[j,k].set_yticklabels([])
+
+                if j!=len(self.parameter_names)-2:
+                    ax[j,k].set_xticklabels([])
+
+                ax[-1,k].set_xlabel(label_dic[pp])
+            ax[j,0].set_ylabel(label_dic[p])
+
+            for cax in ax[j, j+1:]:
+                cax.set_axis_off()
+
+        ax[0,-1].set_axis_on()
+        ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,)
+        ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed")
+        ax[0,-1].set_yticks([])
+
+        fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1])
+        return fig, ax
+
+    def print_correlations(self,
+                           filter: str):
+
+        mismatch = self.calculate_mismatch(filter)
+
+        print(f"\n \n \nCorrelations for filter {filter}:\n")
+        for j, p in enumerate(self.parameter_names):
+            print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], mismatch)[0,1]}")
+
+    def plot_worst_lightcurve(self, filter):
+
+        mismatch = self.calculate_mismatch(filter)
+        ind = np.argsort(mismatch)[-1]
+
+        fig, ax = plt.subplots(1,1)
+        fig.suptitle(f"{filter}")
+        ax.plot(self.times, self.prediction_y_raw[filter][ind], color = "blue")
+        ax.fill_between(self.times, self.prediction_y_raw[filter][ind]-1, self.prediction_y_raw[filter][ind]+1, color = "blue", alpha = 0.2)
+        ax.plot(self.times, self.test_y_raw[filter][ind], color = "red")
+        plt.gca().invert_yaxis()
+        ax.set(xlabel = "$t$ in days", ylabel = "mag")
+
+        return fig, ax
diff --git a/src/fiesta/train/SurrogateTrainer.py b/src/fiesta/train/SurrogateTrainer.py
index 96a63b5..6c93ca9 100644
--- a/src/fiesta/train/SurrogateTrainer.py
+++ b/src/fiesta/train/SurrogateTrainer.py
@@ -2,22 +2,22 @@
 import os
 import numpy as np
+
 import jax
 import jax.numpy as jnp
 from jaxtyping import Array, Float, Int
-from typing import Callable
-from beartype import beartype as typechecker
-import tqdm
 
 from fiesta.utils import MinMaxScalerJax
-from sklearn.model_selection import train_test_split
-
 from fiesta import utils
 from fiesta import conversions
 from fiesta.constants import days_to_seconds, c
 from fiesta import models_utilities
 import fiesta.train.neuralnets as fiesta_nn
+
 import matplotlib.pyplot as plt
-import joblib
+import pickle
+from typing import Callable
+import tqdm
+
 import afterglowpy as grb
 
 class SurrogateTrainer:
@@ -60,19 +60,14 @@
         self.parameter_names = None
         self.validation_fraction = validation_fraction
         
-        self.preprocessing_metadata = {"X_scaler_min": {}, 
-                                       "X_scaler_max": {}, 
-                                       "y_scaler_min": {}, 
-                                       "y_scaler_max": {}}
+        self.preprocessing_metadata = {}
         
         self.X_raw = None
         self.y_raw = None
         
         self.X = None
         self.y = None
-        
-        self.trained_states = None
-        
+    
     def __repr__(self) -> str:
f"SurrogateTrainer(name={self.name})" @@ -86,14 +81,14 @@ def preprocess(self): self.y = {} for filt in self.filters: y_scaler = MinMaxScalerJax() - self.y[filt] = y_scaler.fit_transform(self.y_raw[filt]) + self.y[filt.name] = y_scaler.fit_transform(self.y_raw[filt.name]) self.y_scalers[filt] = y_scaler # Save the metadata self.preprocessing_metadata["X_scaler_min"] = self.X_scaler.min_val self.preprocessing_metadata["X_scaler_max"] = self.X_scaler.max_val - self.preprocessing_metadata["y_scaler_min"] = {filt: self.y_scalers[filt].min_val for filt in self.filters} - self.preprocessing_metadata["y_scaler_max"] = {filt: self.y_scalers[filt].max_val for filt in self.filters} + self.preprocessing_metadata["y_scaler_min"] = {filt.name: self.y_scalers[filt.name].min_val for filt in self.filters} + self.preprocessing_metadata["y_scaler_max"] = {filt.name: self.y_scalers[filt.name].max_val for filt in self.filters} print("Preprocessing data . . . done") def fit(self, @@ -113,20 +108,17 @@ def fit(self, self.config = config trained_states = {} - X = jnp.array(self.X) + input_ndim = len(self.parameter_names) for filt in self.filters: - # Fetch the output data of this filter, and perform train-validation split on it - y = jnp.array(self.y[filt]) - train_X, val_X, train_y, val_y = train_test_split(X, y, test_size=self.validation_fraction) - + # Create neural network and initialize the state net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) key, subkey = jax.random.split(key) state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) # Perform training loop - state, train_losses, val_losses = fiesta_nn.train_loop(state, config, train_X, train_y, val_X, val_y, verbose=verbose) + state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) # Plot and save the plot if so desired if self.plots_dir is not None: @@ -140,10 +132,10 @@ def fit(self, plt.ylabel("MSE loss") plt.yscale('log') plt.title("Learning curves") - plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{filt}.png")) + plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{filt.name}.png")) plt.close() - trained_states[filt] = state + trained_states[filt.name] = state self.trained_states = trained_states @@ -151,25 +143,43 @@ def save(self): """ Save the trained model and all the used metadata to the outdir. 
""" - if not os.path.exists(self.outdir): os.makedirs(self.outdir) - + + meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl") + + if os.path.exists(meta_filename): + with open(meta_filename, "rb") as meta_file: + save = pickle.load(meta_file) + if not np.array_equal(save["times"], self.times): # check whether the metadata from previously trained filters agrees + raise Exception(f"The time array needs to coincide with the time array for previous filters: {save['times']}") + if not np.array_equal(save["parameter_names"], self.parameter_names): + raise Exception(f"The parameters need to coincide with the parameters for previous filters: {save['parameter_names']}") + else: + save = {} + + save["times"] = self.times + save["parameter_names"] = self.parameter_names + # TODO: see if we can save the jet_type here somewhat more self-consistently + for filt in self.filters: - model = self.trained_states[filt] - fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt}.pkl") - - # TODO: improve saving of the scalers: saving the objects is not the best way to do it and breaks pickle - joblib.dump(self.preprocessing_metadata, self.outdir + f"{self.name}.joblib") + model = self.trained_states[filt.name] + fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl") + save[filt.name] = self.preprocessing_metadata[filt.name] + + with open(meta_filename, "wb") as meta_file: + pickle.dump(save, meta_file) def _save_raw_data(self): print("Saving raw data . . .") - np.savez(os.path.join(self.outdir, "raw_data.npz"), X_raw=self.X_raw, **self.y_raw) + np.savez(os.path.join(self.outdir, "raw_data_training.npz"), X_raw=self.train_X_raw, **self.train_y_raw) + np.savez(os.path.join(self.outdir, "raw_data_validation.npz"), X_raw=self.val_X_raw, **self.val_y_raw) print("Saving raw data . . . done") def _save_preprocessed_data(self): print("Saving preprocessed data . . .") - np.savez(os.path.join(self.outdir, "preprocessed_data.npz"), X=self.X, **self.y) + np.savez(os.path.join(self.outdir, "preprocessed_data_training.npz"), X=self.train_X, **self.train_y) + np.savez(os.path.join(self.outdir, "preprocessed_data_validation.npz"), X=self.val_X, **self.val_y) print("Saving preprocessed data . . . done") class SVDSurrogateTrainer(SurrogateTrainer): @@ -208,12 +218,14 @@ def __init__(self, lc_dir (list[str]): Directory where all the raw light curve files, to be read and processed into a surrogate model. outdir (str): Directory where the trained surrogate model has to be saved. filters (list[str], optional): List of all the filters used in the light curve files and for which surrogate has to be trained. If None, all the filters will be used. Defaults to None. + svd_ncoeff: int : Number of SVD coefficients to use in data reduction during training. Defaults to 10. validation_fraction (Float, optional): Fraction of the data to be used for validation. Defaults to 0.2. - tmin (Float, optional): Minimum time of the light curve, all data before is discarded. Defaults to 0.05. - tmax (Float, optional): Maximum time of the light curve, all data after is discarded. Defaults to 14.0. + tmin (Float, optional): Minimum time in days of the light curve, all data before is discarded. Defaults to 0.05. + tmax (Float, optional): Maximum time in days of the light curve, all data after is discarded. Defaults to 14.0. dt (Float, optional): Time step in the light curve. Defaults to 0.1. plots_dir (str, optional): Directory where the plots of the training process will be saved. 
        save_raw_data (bool, optional): If True, the raw data will be saved in the outdir. Defaults to False.
+        save_preprocessed_data (bool, optional): If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False.
        """
        
        super().__init__(name=name, outdir=outdir, validation_fraction=validation_fraction)
@@ -234,24 +246,14 @@
        self.load_times()
        self.load_parameter_names()
-        self.initialize_metadata()
        self.load_raw_data()
+        self.preprocess()
-        
+        if save_preprocessed_data:
+            self._save_preprocessed_data()
        if save_raw_data:
            self._save_raw_data()
-    
-    # TODO: more elegant way to do this?
-    def initialize_metadata(self):
-        self.preprocessing_metadata = {"X_scaler_min": {}, 
-                                       "X_scaler_max": {}, 
-                                       "y_scaler_min": {}, 
-                                       "y_scaler_max": {},
-                                       "times": self.times,
-                                       "VA": {},
-                                       "svd_ncoeff": self.svd_ncoeff}
    
    def load_parameter_names(self):
        raise NotImplementedError
@@ -264,34 +266,32 @@ def load_filters(self, filters: list[str] = None):
    
    def load_raw_data(self):
        raise NotImplementedError
-    
+
    def preprocess(self):
        """
        Preprocess the data. This includes scaling the inputs and outputs, performing SVD decomposition, and saving the necessary metadata for later use.
        """
        # Scale inputs
        X_scaler = MinMaxScalerJax()
-        X = X_scaler.fit_transform(self.X_raw)
-        
-        self.preprocessing_metadata["X_scaler_min"] = X_scaler.min_val
-        self.preprocessing_metadata["X_scaler_max"] = X_scaler.max_val
-        
+        self.train_X = X_scaler.fit_transform(self.train_X_raw) # fit the scaler to the training data
+        self.val_X = X_scaler.transform(self.val_X_raw) # transform the val data
+        
        # Scale outputs, do SVD and save into y
-        y = {filt: [] for filt in self.filters}
+        self.train_y = {filt.name: [] for filt in self.filters}
+        self.val_y = {filt.name: [] for filt in self.filters}
+        
+        print(f"Rescaling the training and validation data for filters {[filter.name for filter in self.filters]}")
        for filt in tqdm.tqdm(self.filters):
-            data_scaler = MinMaxScalerJax()
-            data = data_scaler.fit_transform(self.y_raw[filt])
-            self.preprocessing_metadata["y_scaler_min"][filt] = data_scaler.min_val
-            self.preprocessing_metadata["y_scaler_max"][filt] = data_scaler.max_val
+            y_scaler = MinMaxScalerJax()
+            data = y_scaler.fit_transform(self.train_y_raw[filt.name])
            
-            # Do SVD decomposition
+            # Do SVD decomposition on the training data
            UA, _, VA = np.linalg.svd(data, full_matrices=True)
            VA = VA.T
            
            n, n = UA.shape
            m, m = VA.shape
            
-            # This is taken over from NMMA
            cAmat = np.zeros((self.svd_ncoeff, n))
            cAvar = np.zeros((self.svd_ncoeff, n))
@@ -307,13 +307,22 @@
                        np.dot(np.diag(np.power(errors, 2.0)), VA[:, : self.svd_ncoeff]),
                    )
                )
-            self.preprocessing_metadata["VA"][filt] = VA
-            
-            # Transpose to get the shape (n_batch, n_svd_coeff)
-            y[filt] = cAmat.T
+            
+            self.train_y[filt.name] = cAmat.T # Transpose to get the shape (n_batch, n_svd_coeff)
+            
+            # Do SVD decomposition on the validation data
+            val_data = y_scaler.transform(self.val_y_raw[filt.name])
+            cAmat = np.zeros((self.svd_ncoeff, self.n_val_data))
+            for i in range(self.n_val_data):
+                cAmat[:,i] = np.dot(
+                    val_data[i,:], VA[:, : self.svd_ncoeff]
+                )
        
-        self.X = X
-        self.y = y
+            self.val_y[filt.name] = cAmat.T # Transpose to get the shape (n_val, n_svd_coeff)
+            
+            # Save the scalers
+            self.preprocessing_metadata[filt.name] = {"VA": VA, "X_scaler_max": X_scaler.max_val, "X_scaler_min": X_scaler.min_val, "y_scaler": y_scaler, "svd_ncoeff": self.svd_ncoeff}
+    
    def __repr__(self) -> str:
f"SVDSurrogateTrainer(name={self.name}, lc_dir={self.lc_dir}, outdir={self.outdir}, filters={self.filters})" @@ -430,16 +439,30 @@ def _read_files(self) -> tuple[dict[str, Float[Array, " n_batch n_params"]], Flo def load_raw_data(self): print("Reading data files and interpolating NaNs . . .") - self.X_raw, y = self._read_files() - self.y_raw = utils.interpolate_nans(y, self._times_grid, self.times) - if self.save_raw_data: - np.savez(os.path.join(self.outdir, "raw_data.npz"), X_raw=self.X_raw, times=self.times, times_grid=self._times_grid, **self.y_raw) + X_raw, y = self._read_files() + y_raw = utils.interpolate_nans(y, self._times_grid, self.times) + ###if self.save_raw_data: + ### np.savez(os.path.join(self.outdir, "raw_data.npz"), X_raw=X_raw, times=self.times, times_grid=self._times_grid, **y_raw) + + # split here into training and validating data + self.n_val_data = int(self.validation_fraction*len(X_raw)) + self.n_training_data = len(X_raw) - self.n_val_data + mask = np.zeros(len(X_raw) ,dtype = bool) + mask[np.random.choice(len(X_raw), self.n_val, replace = False)] = True + + self.train_X_raw, self.val_X_raw = X_raw[~mask], X_raw[mask] + self.train_y_raw, self.val_y_raw = {}, {} + + for filters in self.filters: + self.train_y_raw[filt.name] = y_raw[filt.name][~mask] + self.val_y_raw[filt.name] = y_raw[filt.name][mask] + # TODO: perhaps rename to *_1D, since it is only for 1D light curves, and we might want to get support for 2D by incorporating the frequencies... Unsure about the approach here class AfterglowpyTrainer(SVDSurrogateTrainer): - prior_ranges: dict[str, list[Float, Float]] + parameter_grid: dict[str, list[Float]] n_training_data: Int fixed_parameters: dict[str, Float] jet_type: Int @@ -452,8 +475,8 @@ def __init__(self, name: str, outdir: str, filters: list[str], - prior_ranges: dict[str, list[Float, Float]], - n_training_data: Int = 10_000, + parameter_grid: dict[str, list[float]], + n_training_data: Int = 5000, jet_type: Int = -1, fixed_parameters: dict[str, Float] = {}, tmin: Float = 0.1, @@ -465,29 +488,45 @@ def __init__(self, svd_ncoeff: Int = 10, save_raw_data: bool = False, save_preprocessed_data: bool = False, + remake_training_data = False, ): """ - TODO: add documentation Initialize the surrogate model trainer. The initialization also takes care of reading data and preprocessing it, but does not automatically fit the model. Users may want to inspect the data before fitting the model. Args: name (str): Name given to the model outdir (str): Output directory to save the trained model - prior_ranges (dict[str, list[Float, Float]]): Dictionary containing the prior ranges for each parameter, i.e., the range on which the surrogate must be trained. The keys should be the parameter names and the values should be a list containing the minimum and maximum values of the prior range. NOTE: frequency (nu) should be included, or given to fixed parameters. + parameter_grid (dict[str, list[Float]]): Dictionary containing the grid points for each parameter, i.e., the parameter values on which the surrogate will be trained. The keys should be the parameter names and the values should be a list.. + jet_type (Int): Type of jet for the afterglowpy, -1 is tophat, 0 is Gaussian, 4 is PowerLaw + fixed_parameters (dict[str, Float]) : values of the afterglowpy parameters that should be kept fixed for the surrogate model + tmin (Float, optional): Minimum time in days of the light curve, all data before is discarded. Defaults to 0.1. 
+            tmax (Float, optional): Maximum time in days of the light curve, all data after is discarded. Defaults to 1000.
+            n_times (Int, optional): Number of time nodes for the training light curve data.
+            use_log_spacing (bool, optional): Whether the time nodes of the training light curve data should be log10 spaced.
+            validation_fraction (Float, optional): Fraction of the data to be used for validation. Defaults to 0.2.
+            plots_dir (str, optional): Output directory for the plots.
+            svd_ncoeff (Int, optional): Number of SVD coefficients to use in data reduction during training. Defaults to 10.
+            save_raw_data (bool, optional): If True, the raw data will be saved in the outdir. Defaults to False.
+            save_preprocessed_data (bool, optional): If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False.
+        """
+        
        self.n_times = n_times
        dt = (tmax - tmin) / n_times
-        self.prior_ranges = prior_ranges
-        self.n_training_data = n_training_data
+        self.parameter_grid = parameter_grid
        self.fixed_parameters = fixed_parameters
        self.use_log_spacing = use_log_spacing
        
-        # Check jet type before saving # TODO: need to check if supported ids are correct?
-        supported_jet_types = [-1, 0, 1, 4]
+        # Check jet type before saving
+        supported_jet_types = [-1, 0, 4]
        if jet_type not in supported_jet_types:
            raise ValueError(f"Jet type {jet_type} is not supported. Supported jet types are: {supported_jet_types}")
        self.jet_type = jet_type
+        self.remake_training_data = remake_training_data
+        
+        self.n_training_data = n_training_data
+        self.validation_fraction = validation_fraction
+        self.n_val_data = int(self.n_training_data * self.validation_fraction/(1-self.validation_fraction))
        
        super().__init__(name=name,
                         outdir=outdir,
@@ -501,26 +540,14 @@
                         save_raw_data=save_raw_data,
                         save_preprocessed_data=save_preprocessed_data)
        
-    def initialize_metadata(self):
-        self.preprocessing_metadata = {"X_scaler_min": {}, 
-                                       "X_scaler_max": {}, 
-                                       "y_scaler_min": {}, 
-                                       "y_scaler_max": {},
-                                       "times": self.times,
-                                       "VA": {},
-                                       "svd_ncoeff": self.svd_ncoeff,
-                                       "nus": self.nus,
-                                       "jet_type": self.jet_type,
-                                       "parameter_names": self.parameter_names}
    
    def load_filters(self, filters: list[str]):
-        self.filters = filters
-        filts, lambdas = utils.get_default_filts_lambdas(filters)
-        self.filters = filts
-        nus = c / lambdas
-        self.nus = dict(zip(filts, nus))
-        
+        self.filters = []
+        for filter in filters:
+            try:
+                self.filters.append(utils.Filter(filter))
+            except Exception:
+                raise Exception(f"Filter {filter} not available.")
+    
    def load_times(self):
        if self.use_log_spacing:
            times = np.logspace(np.log10(self.tmin), np.log10(self.tmax), num=self.n_times)
@@ -530,38 +557,83 @@
        self._times_afterglowpy = self.times * days_to_seconds # afterglowpy takes seconds as input
        
    def load_parameter_names(self):
-        self.parameter_names = list(self.prior_ranges.keys())
+        self.parameter_names = list(self.parameter_grid.keys())
    
    def load_raw_data(self):
+        data_files_exist = os.path.exists(self.outdir+"/raw_data_training.npz") and os.path.exists(self.outdir+"/raw_data_validation.npz")
+        if data_files_exist and not self.remake_training_data:
+            self.train_X_raw, self.train_y_raw, self.val_X_raw, self.val_y_raw = self._read_files()
+        else:
+            self.create_raw_data()
+    
+    def create_raw_data(self):
        """
-        Create a grid of training data with specified settings and generate the output files for them.
+        Create a grid of training data with specified settings and generate the output files for them.
+        TODO: for now we train per filter, but best to change this!
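+        Training inputs are sampled per parameter (with replacement) from the values given in parameter_grid; validation inputs are drawn uniformly at random between each parameter's grid boundaries.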
""" - - # Initialize the output values + # Create training data X_raw = np.empty((self.n_training_data, len(self.parameter_names))) - y_raw = {filt: np.empty((self.n_training_data, len(self.times))) for filt in self.filters} + y_raw = {filt.name: np.empty((self.n_training_data, len(self.times))) for filt in self.filters} - parameter_names = list(self.prior_ranges.keys()) - - print("Creating the afterglowpy dataset . . .") + for j, key in enumerate(self.parameter_grid.keys()): + X_raw[:,j] = np.random.choice(self.parameter_grid[key], size = self.n_training_data, replace = True) + + + print(f"Creating the afterglowpy training dataset on grid with {self.n_training_data} points.") for i in tqdm.tqdm(range(self.n_training_data)): - # Generate "intrinsic" parameter values by random sampling: - param_values = [np.random.uniform(self.prior_ranges[p][0], self.prior_ranges[p][1]) for p in parameter_names] - X_raw[i] = param_values - - # Add nu per filter before calling afterglowpy for filt in self.filters: - param_dict = dict(zip(parameter_names, param_values)) + param_dict = dict(zip(self.parameter_names, X_raw[i])) param_dict.update(self.fixed_parameters) - param_dict["nu"] = self.nus[filt] + param_dict["nu"] = filt.nu # Add nu per filter before calling afterglowpy # Create and save output mJys = self._call_afterglowpy(param_dict) - y_raw[filt][i] = conversions.mJys_to_mag_np(mJys) + y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys) + + self.train_X_raw = X_raw + self.train_y_raw = y_raw + + + # Create validation data + X_raw = np.empty((self.n_val_data, len(self.parameter_names))) + y_raw = {filt.name: np.empty((self.n_val_data, len(self.times))) for filt in self.filters} + + print(f"Creating the afterglowpy validation dataset on {self.n_val_data} random points within grid.") + for i in tqdm.tqdm(range(self.n_val_data)): + X_raw[i] = [np.random.uniform(self.parameter_grid[p][0], self.parameter_grid[p][-1]) for p in self.parameter_names] + + for filt in self.filters: + param_dict = dict(zip(self.parameter_names, X_raw[i])) + param_dict.update(self.fixed_parameters) + param_dict["nu"] = filt.nu # Add nu per filter before calling afterglowpy - self.X_raw = X_raw - self.y_raw = y_raw + # Create and save output + mJys = self._call_afterglowpy(param_dict) + y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys) + + self.val_X_raw = X_raw + self.val_y_raw = y_raw + + + + def _read_files(self,): + raw_data_train = np.load(self.outdir+"/raw_data_training.npz") + raw_data_validation = np.load(self.outdir+'/raw_data_validation.npz') + + self.n_training_data = 4000 + self.n_val_data = 1000 + + training_y_raw = {} + val_y_raw = {} + + select = np.random.choice(range(0, len(raw_data_train["X_raw"])), size = self.n_training_data, replace = False) + select_val = np.random.choice(range(0, len(raw_data_validation["X_raw"])), size = self.n_val_data, replace = False) + + for filt in self.filters: + training_y_raw[filt.name] = raw_data_train[filt.name][select] + val_y_raw[filt.name] = raw_data_validation[filt.name][select_val] + return raw_data_train["X_raw"][select], training_y_raw, raw_data_validation["X_raw"][select_val], val_y_raw def _call_afterglowpy(self, diff --git a/src/fiesta/utils.py b/src/fiesta/utils.py index b177555..643adc3 100644 --- a/src/fiesta/utils.py +++ b/src/fiesta/utils.py @@ -7,7 +7,6 @@ import scipy.interpolate as interp import copy import re -from sncosmo.bandpasses import _BANDPASSES from astropy.time import Time from fiesta.constants import pc_to_cm import astropy @@ -233,6 +232,41 
     return data
 
+#########################
+### Filters ###
+#########################
+
+class Filter:
+    
+    def __init__(self,
+                 name: str):
+        self.name = name
+        if (self.name, None) in _BANDPASSES._primary_loaders:
+            bandpass = sncosmo.get_bandpass(self.name)
+            self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10) # wave_eff is in angstrom
+        elif (self.name, None) in _BANDPASS_INTERPOLATORS._primary_loaders:
+            bandpass = sncosmo.get_bandpass(self.name, 3)
+            self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10)
+        elif self.name.endswith("GHz"):
+            freq = re.findall(r"[-+]?(?:\d*\.*\d+)", self.name.replace("-",""))
+            freq = float(freq[-1])
+            self.nu = freq*1e9
+        elif self.name.endswith("keV"):
+            energy = re.findall(r"[-+]?(?:\d*\.*\d+)", self.name.replace("-",""))
+            energy = float(energy[-1])
+            self.nu = energy*1000*scipy.constants.eV / scipy.constants.h
+        else:
+            raise Exception(f"Filter {self.name} not available.")
+        
+        self.wavelength = scipy.constants.c/self.nu
+
+
def get_all_bandpass_metadata():
    # TODO: taken over from NMMA, improve
    """
@@ -310,6 +344,7 @@ def get_default_filts_lambdas(filters: list[str]=None):
 
     if filters is not None:
         filts_slice = []
         lambdas_slice = []
+        transmittance_slice = []
 
         for filt in filters:
             if filt.startswith("radio") and filt not in filts:
@@ -324,7 +359,9 @@
                 freq = freq.to("Hz").value
                 # adding to the list
                 filts_slice.append(filt)
-                lambdas_slice.append(scipy.constants.c / freq)
+                lambdas_slice.append([scipy.constants.c / freq])
+                transmittance_slice.append([1])
+
             elif filt.startswith("X-ray-") and filt not in filts:
                 # for additional X-ray filters that not in the list
                 # calculate the lambdas based on the filter name
@@ -337,12 +374,14 @@
                 freq = energy.to("eV").value * scipy.constants.eV / scipy.constants.h
                 # adding to the list
                 filts_slice.append(filt)
-                lambdas_slice.append(scipy.constants.c / freq)
+                lambdas_slice.append([scipy.constants.c / freq])
+                transmittance_slice.append([1])
+
             else:
                 try:
                     ii = filts.index(filt)
                     filts_slice.append(filts[ii])
-                    lambdas_slice.append(lambdas[ii])
+                    lambdas_slice.append([lambdas[ii]])
                 except ValueError:
                     ii = filts.index(filt.replace("_", ":"))
                     filts_slice.append(filts[ii].replace(":", "_"))
@@ -351,7 +390,9 @@
         filts = filts_slice
         lambdas = np.array(lambdas_slice)
+        transmittance = transmittance_slice # collect the transmittance for the selected filters
 
-    return filts, lambdas
+    return filts, lambdas, transmittance
 
 def mJys_to_mag():
     pass
\ No newline at end of file
diff --git a/trained_models/afterglowpy/tophat/X-ray-1keV.pkl b/trained_models/afterglowpy/tophat/X-ray-1keV.pkl
new file mode 100644
index 0000000..4356bf1
Binary files /dev/null and b/trained_models/afterglowpy/tophat/X-ray-1keV.pkl differ
diff --git a/trained_models/afterglowpy/tophat/bessellv.pkl b/trained_models/afterglowpy/tophat/bessellv.pkl
new file mode 100644
index 0000000..79ef2ee
Binary files /dev/null and b/trained_models/afterglowpy/tophat/bessellv.pkl differ
diff --git a/trained_models/afterglowpy/tophat/preprocessed_data_training.npz b/trained_models/afterglowpy/tophat/preprocessed_data_training.npz
new file mode 100644
index 0000000..55d016a
Binary files /dev/null and b/trained_models/afterglowpy/tophat/preprocessed_data_training.npz differ
diff --git a/trained_models/afterglowpy/tophat/preprocessed_data_validation.npz b/trained_models/afterglowpy/tophat/preprocessed_data_validation.npz
new file mode 100644
index 0000000..94759e7
Binary files /dev/null and b/trained_models/afterglowpy/tophat/preprocessed_data_validation.npz differ
diff --git a/trained_models/afterglowpy/tophat/radio-3GHz.pkl b/trained_models/afterglowpy/tophat/radio-3GHz.pkl
index 95a76ef..3863868 100644
Binary files a/trained_models/afterglowpy/tophat/radio-3GHz.pkl and b/trained_models/afterglowpy/tophat/radio-3GHz.pkl differ
diff --git a/trained_models/afterglowpy/tophat/radio-6GHz.pkl b/trained_models/afterglowpy/tophat/radio-6GHz.pkl
new file mode 100644
index 0000000..02caac0
Binary files /dev/null and b/trained_models/afterglowpy/tophat/radio-6GHz.pkl differ
diff --git a/trained_models/afterglowpy/tophat/raw_data_test.npz b/trained_models/afterglowpy/tophat/raw_data_test.npz
new file mode 100644
index 0000000..47f0d81
Binary files /dev/null and b/trained_models/afterglowpy/tophat/raw_data_test.npz differ
diff --git a/trained_models/afterglowpy/tophat/raw_data_training.npz b/trained_models/afterglowpy/tophat/raw_data_training.npz
new file mode 100644
index 0000000..3e736f7
Binary files /dev/null and b/trained_models/afterglowpy/tophat/raw_data_training.npz differ
diff --git a/trained_models/afterglowpy/tophat/raw_data_validation.npz b/trained_models/afterglowpy/tophat/raw_data_validation.npz
new file mode 100644
index 0000000..2a33658
Binary files /dev/null and b/trained_models/afterglowpy/tophat/raw_data_validation.npz differ
diff --git a/trained_models/afterglowpy/tophat/tophat.joblib b/trained_models/afterglowpy/tophat/tophat.joblib
deleted file mode 100644
index 2346594..0000000
Binary files a/trained_models/afterglowpy/tophat/tophat.joblib and /dev/null differ
diff --git a/trained_models/afterglowpy/tophat/tophat_metadata.pkl b/trained_models/afterglowpy/tophat/tophat_metadata.pkl
new file mode 100644
index 0000000..f30e6b8
Binary files /dev/null and b/trained_models/afterglowpy/tophat/tophat_metadata.pkl differ
diff --git a/trained_models/benchmark_afterglowpy_tophat.py b/trained_models/benchmark_afterglowpy_tophat.py
new file mode 100644
index 0000000..7cab84a
--- /dev/null
+++ b/trained_models/benchmark_afterglowpy_tophat.py
@@ -0,0 +1,42 @@
+import numpy as np
+import matplotlib.pyplot as plt
+
+from fiesta.train.Benchmarker import Benchmarker
+from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
+from fiesta.utils import Filter
+
+
+name = "tophat"
+model_dir = f"./afterglowpy/{name}/"
+FILTERS = ["radio-6GHz", "radio-3GHz"] # ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"]
+
+for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]:
+    if metric_name == "$\\mathcal{L}_2$":
+        file_ending = "L2"
+    else:
+        file_ending = "Linf"
+
+    B = Benchmarker(name = "tophat",
+                    model_dir = model_dir,
+                    filters = FILTERS,
+                    n_test_data = 2000,
+                    metric_name = metric_name,
+                    remake_test_data = True,
+                    jet_type = -1,
+                    )
+
+    for filt in FILTERS:
+
+        fig, ax = B.plot_lightcurves_mismatch(filter = filt)
+        fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200)
+
+        B.print_correlations(filter = filt)
+
+        fig, ax = B.plot_worst_lightcurve(filter = filt)
+        fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200)
+
diff --git a/trained_models/train_afterglowpy_tophat.py b/trained_models/train_afterglowpy_tophat.py
index 6e86678..3af5a8b 100644
--- a/trained_models/train_afterglowpy_tophat.py
+++ b/trained_models/train_afterglowpy_tophat.py
@@ -4,31 +4,58 @@
 from fiesta.train.SurrogateTrainer import AfterglowpyTrainer
 from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
 from fiesta.train.neuralnets import NeuralnetConfig
-from fiesta.utils import get_default_filts_lambdas
+from fiesta.utils import Filter
 
 #############
 ### SETUP ###
 #############
 
-FILTERS = ["radio-3GHz"] # TODO: add the filters [radio-6GHz, X-ray-1keV]
-FILTERS, lambdas = get_default_filts_lambdas(FILTERS)
-nus = c / lambdas
-print(FILTERS)
-print(nus)
+FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"]
+for filter in FILTERS:
+    filter = Filter(filter)
+    print(filter.name, filter.nu)
 
 tmin = 1
 tmax = 1000
 
-prior_ranges = {
-    'inclination_EM': [0.0, np.pi/2],
-    'log10_E0': [47.0, 57.0],
-    'thetaCore': [0.01, np.pi/10],
-    'log10_n0': [-6, 3.0],
-    'p': [2.01, 3.0],
-    'log10_epsilon_e': [-5, 0],
-    'log10_epsilon_B': [-10, 0]
+"""
+#grid for radio-6GHz and radio-3GHz
+parameter_grid = {
+    'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2],
+    'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.],
+    'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10],
+    'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0],
+    'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0],
+    'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0],
+    'log10_epsilon_B': [-8, -6, -4, -2., -1., 0]
 }
 
+#grid for X-ray-1keV and bessellv
+parameter_grid = {
+    'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2],
+    'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.],
+    'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10],
+    'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0],
+    'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0],
+    'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0],
+    'log10_epsilon_B': [-8, -6, -4, -2., -1., 0]
+}
+
+
+"""
+
+FILTERS = ["radio-3GHz", "radio-6GHz"]
+parameter_grid = {
+    'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2],
+    'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.],
+    'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10],
+    'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0],
+    'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0],
+    'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0],
+    'log10_epsilon_B': [-8, -6, -4, -2., -1., 0]
+}
+
 
 jet_name = "tophat"
 jet_conversion = {"tophat": -1,
                   "gaussian": 0,
@@ -46,15 +73,16 @@
 trainer = AfterglowpyTrainer(name,
                              outdir,
                              FILTERS,
-                             prior_ranges,
-                             n_training_data= 2_000,
+                             parameter_grid,
                              jet_type = jet_conversion[jet_name],
                              tmin = tmin,
                              tmax = tmax,
                              plots_dir="./figures/",
-                             svd_ncoeff=10,
+                             svd_ncoeff=40,
                              save_raw_data=True,
-                             save_preprocessed_data=True
+                             save_preprocessed_data=True,
+                             remake_training_data = True,
+                             n_training_data = 7000
                              )
 
@@ -62,8 +90,9 @@
 ###############
 
 config = NeuralnetConfig(output_size=trainer.svd_ncoeff,
-                         nb_epochs=10_000,
-                         hidden_layer_sizes = [128, 256, 128])
+                         nb_epochs=50_000,
+                         hidden_layer_sizes = [64, 128, 64],
+                         learning_rate = 8e-3)
 
 trainer.fit(config=config)
 trainer.save()
@@ -79,8 +108,8 @@
                                     filters = FILTERS)
 
 for filt in lc_model.filters:
     
-    X_example = trainer.X_raw[0]
-    y_raw = trainer.y_raw[filt][0]
+    X_example = trainer.val_X_raw[0]
+    y_raw = trainer.val_y_raw[filt][0]
     
     # Turn into a dict: this is how the model expects the input
     X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)}
@@ -99,5 +128,4 @@
     plt.gca().invert_yaxis()
plt.savefig(f"./figures/afterglowpy_{name}_{filt}_example.png") - plt.close() - break # only show first filter \ No newline at end of file + plt.close() \ No newline at end of file