diff --git a/benchmarks/GRB/benchmark_afterglowpy_tophat.py b/benchmarks/GRB/benchmark_afterglowpy_tophat.py index 68137ea..f2521ff 100644 --- a/benchmarks/GRB/benchmark_afterglowpy_tophat.py +++ b/benchmarks/GRB/benchmark_afterglowpy_tophat.py @@ -10,29 +10,40 @@ model_dir = f"../../trained_models/GRB/afterglowpy/{name}/" FILTERS = ["radio-6GHz", "radio-3GHz"]#["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: if metric_name == "$\\mathcal{L}_2$": file_ending = "L2" else: file_ending = "Linf" - - B = Benchmarker(name = "tophat", + B = Benchmarker(name = name, + parameter_grid = parameter_grid, model_dir = model_dir, + MODEL = AfterglowpyLightcurvemodel, filters = FILTERS, n_test_data = 2000, metric_name = metric_name, - remake_test_data = True, + remake_test_data = False, jet_type = -1, ) + + fig, ax = B.plot_error_distribution("radio-6GHz") + for filt in FILTERS: - fig, ax = B.plot_lightcurves_mismatch(filter =filt) - fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200) + fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{core}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) + fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200) B.print_correlations(filter = filt) + + + if metric_name == "$\\mathcal{L}_\infty$": + fig, ax = B.plot_error_distribution(filt) + fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200) fig, ax = B.plot_worst_lightcurve(filter = filt) fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200) + diff --git a/benchmarks/KN/benchmark_Bu2019lm.py b/benchmarks/KN/benchmark_Bu2019lm.py index 612a89f..e7e925c 100644 --- a/benchmarks/KN/benchmark_Bu2019lm.py +++ b/benchmarks/KN/benchmark_Bu2019lm.py @@ -29,11 +29,14 @@ fig.savefig(f"./figures/benchmark_{filt}_{file_ending}.pdf", dpi = 200) B.print_correlations(filter = filt) - - - fig, ax = B.plot_worst_lightcurve(filter = filt) - fig.savefig(f"./figures/worst_lightcurve_{filt}_{file_ending}.pdf", dpi = 200) + if metric_name == "$\\mathcal{L}_\infty$": + fig, ax = B.plot_error_distribution(filt) + fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200) + + + fig, ax = B.plot_worst_lightcurves() + fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200) diff --git a/examples/GRB/injection_gaussian.py b/examples/GRB/injection_gaussian.py new file mode 100644 index 0000000..1ade1f8 --- /dev/null +++ b/examples/GRB/injection_gaussian.py @@ -0,0 +1,213 @@ +"""Injection runs with afterglowpy gaussian""" + +import os +import jax +print(f"GPU found? {jax.devices()}") +import jax.numpy as jnp +jax.config.update("jax_enable_x64", True) +import numpy as np +import matplotlib.pyplot as plt +import corner + +from fiesta.inference.lightcurve_model import AfterglowpyPCA, PCALightcurveModel +from fiesta.inference.injection import InjectionRecoveryAfterglowpy +from fiesta.inference.likelihood import EMLikelihood +from fiesta.inference.prior import Uniform, CompositePrior, Constraint +from fiesta.inference.prior_dict import ConstrainedPrior +from fiesta.inference.fiesta import Fiesta +from fiesta.utils import load_event_data, write_event_data + +import time +start_time = time.time() + +################ +### Preamble ### +################ + +jax.config.update("jax_enable_x64", True) + +params = {"axes.grid": True, + "text.usetex" : True, + "font.family" : "serif", + "ytick.color" : "black", + "xtick.color" : "black", + "axes.labelcolor" : "black", + "axes.edgecolor" : "black", + "font.serif" : ["Computer Modern Serif"], + "xtick.labelsize": 16, + "ytick.labelsize": 16, + "axes.labelsize": 16, + "legend.fontsize": 16, + "legend.title_fontsize": 16, + "figure.titlesize": 16} + +plt.rcParams.update(params) + +default_corner_kwargs = dict(bins=40, + smooth=1., + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + color="blue", + # quantiles=[], + # levels=[0.9], + plot_density=True, + plot_datapoints=False, + fill_contours=True, + max_n_ticks=4, + min_n_ticks=3, + save=False, + truth_color="red") + + +############## +### MODEL ### +############## + +name = "gaussian" +model_dir = f"../../flux_models/afterglowpy_{name}/model" +FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"] + +model = AfterglowpyPCA(name, + model_dir, + filters = FILTERS) + + +################### +### INJECT ### +### AFTERGLOWPY ### +################### + +trigger_time = 58849 # 01-01-2020 in mjd +remake_injection = False +injection_dict = {"inclination_EM": 0.174, "log10_E0": 54.4, "thetaCore": 0.14, "alphaWing": 3, "p": 2.6, "log10_n0": -2, "log10_epsilon_e": -2.06, "log10_epsilon_B": -4.2, "luminosity_distance": 40.0} + +if remake_injection: + injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = 0, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time) + injection.create_injection() + data = injection.data + write_event_data("./injection_gaussian/injection_gaussian.dat", data) + +data = load_event_data("./injection_gaussian/injection_gaussian.dat") +############################# +### PRIORS AND LIKELIHOOD ### +############################# + +inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM']) +log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0']) +thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore']) +alphaWing = Uniform(xmin = 0.2, xmax = 3.5, naming= ["alphaWing"]) +thetaWing = Constraint(xmin = 0, xmax = np.pi/2, naming = ["thetaWing"]) +log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0']) +p = Uniform(xmin=2.01, xmax=3.0, naming=['p']) +log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e']) +log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B']) +epsilon_tot = Constraint(xmin = 0, xmax = 1, naming = ["epsilon_tot"]) + +# luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance']) +def conversion_function(sample): + converted_sample = sample + converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"] + converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"]) + return converted_sample + +prior_list = [inclination_EM, + log10_E0, + thetaCore, + alphaWing, + log10_n0, + p, + log10_epsilon_e, + log10_epsilon_B, + thetaWing, + epsilon_tot] + +prior = ConstrainedPrior(prior_list, conversion_function) + +detection_limit = None +likelihood = EMLikelihood(model, + data, + FILTERS, + tmax = 2000.0, + trigger_time=trigger_time, + detection_limit = detection_limit, + fixed_params={"luminosity_distance": 40.0}, + error_budget = 1e-5) + + +############## +### FIESTA ### +############## + +mass_matrix = jnp.eye(prior.n_dim) +eps = 5e-3 +local_sampler_arg = {"step_size": mass_matrix * eps} + +# Save for postprocessing +outdir = f"./injection_{name}/" +if not os.path.exists(outdir): + os.makedirs(outdir) + +fiesta = Fiesta(likelihood, + prior, + n_chains = 1_000, + n_loop_training = 7, + n_loop_production = 3, + num_layers = 4, + hidden_size = [64, 64], + n_epochs = 20, + n_local_steps = 50, + n_global_steps = 200, + local_sampler_arg=local_sampler_arg, + outdir = outdir) + +fiesta.sample(jax.random.PRNGKey(42)) + +fiesta.print_summary() + +name = outdir + f'results_training.npz' +print(f"Saving samples to {name}") +state = fiesta.Sampler.get_sampler_state(training=True) +chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state[ +"log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, log_prob=log_prob, local_accs=local_accs, + global_accs=global_accs, loss_vals=loss_vals) + +# - production phase +name = outdir + f'results_production.npz' +print(f"Saving samples to {name}") +state = fiesta.Sampler.get_sampler_state(training=False) +chains, log_prob, local_accs, global_accs = state["chains"], state[ + "log_prob"], state["local_accs"], state["global_accs"] +local_accs = jnp.mean(local_accs, axis=0) +global_accs = jnp.mean(global_accs, axis=0) +np.savez(name, chains=chains, log_prob=log_prob, + local_accs=local_accs, global_accs=global_accs) + +################ +### PLOTTING ### +################ +# Fixed names: do not include them in the plotting, as will break corner +parameter_names = prior.naming +truths = [injection_dict[key] for key in parameter_names] + +n_chains, n_steps, n_dim = np.shape(chains) +samples = np.reshape(chains, (n_chains * n_steps, n_dim)) +samples = np.asarray(samples) # convert from jax.numpy array to numpy array for corner consumption + +corner.corner(samples, labels = parameter_names, hist_kwargs={'density': True}, truths = truths, **default_corner_kwargs) +plt.savefig(os.path.join(outdir, "corner.png"), bbox_inches = 'tight') +plt.close() + +end_time = time.time() +runtime_seconds = end_time - start_time +number_of_minutes = runtime_seconds // 60 +number_of_seconds = np.round(runtime_seconds % 60, 2) +print(f"Total runtime: {number_of_minutes} m {number_of_seconds} s") + +print("Plotting lightcurves") +fiesta.plot_lightcurves() +print("Plotting lightcurves . . . done") + +print("DONE") \ No newline at end of file diff --git a/examples/GRB/injection_gaussian/corner.png b/examples/GRB/injection_gaussian/corner.png new file mode 100644 index 0000000..04f8e90 Binary files /dev/null and b/examples/GRB/injection_gaussian/corner.png differ diff --git a/examples/GRB/injection_gaussian/injection_gaussian.dat b/examples/GRB/injection_gaussian/injection_gaussian.dat new file mode 100644 index 0000000..0b6aba0 --- /dev/null +++ b/examples/GRB/injection_gaussian/injection_gaussian.dat @@ -0,0 +1,70 @@ +2020-01-02T00:00:00.000 radio-3GHz 6.473406 0.500000 +2020-04-10T22:48:00.000 radio-3GHz 12.246090 0.500000 +2020-07-19T21:36:00.000 radio-3GHz 13.860922 0.500000 +2020-10-27T20:24:00.000 radio-3GHz 15.145347 0.500000 +2021-02-04T19:12:00.000 radio-3GHz 16.200114 0.500000 +2021-05-15T18:00:00.000 radio-3GHz 17.037075 0.500000 +2021-08-23T16:48:00.000 radio-3GHz 17.709738 0.500000 +2021-12-01T15:36:00.000 radio-3GHz 18.265443 0.500000 +2022-03-11T14:24:00.000 radio-3GHz 18.738014 0.500000 +2022-06-19T13:12:00.000 radio-3GHz 19.142404 0.500000 +2022-09-27T12:00:00.000 radio-3GHz 19.496765 0.500000 +2023-01-05T10:48:00.000 radio-3GHz 19.810069 0.500000 +2023-04-15T09:36:00.000 radio-3GHz 20.091715 0.500000 +2023-07-24T08:24:00.000 radio-3GHz 20.346828 0.500000 +2023-11-01T07:12:00.000 radio-3GHz 20.578875 0.500000 +2024-02-09T06:00:00.000 radio-3GHz 20.791826 0.500000 +2024-05-19T04:48:00.000 radio-3GHz 20.989639 0.500000 +2024-08-27T03:36:00.000 radio-3GHz 21.171644 0.500000 +2024-12-05T02:24:00.000 radio-3GHz 21.341979 0.500000 +2025-03-15T01:12:00.000 radio-3GHz 21.500904 0.500000 +2025-06-23T00:00:00.000 radio-3GHz 21.652456 0.500000 +2020-01-02T00:00:00.000 radio-6GHz 6.642397 0.500000 +2020-05-05T22:30:00.000 radio-6GHz 13.306995 0.500000 +2020-09-07T21:00:00.000 radio-6GHz 15.133077 0.500000 +2021-01-10T19:30:00.000 radio-6GHz 16.559615 0.500000 +2021-05-15T18:00:00.000 radio-6GHz 17.639135 0.500000 +2021-09-17T16:30:00.000 radio-6GHz 18.460232 0.500000 +2022-01-20T15:00:00.000 radio-6GHz 19.112092 0.500000 +2022-05-25T13:30:00.000 radio-6GHz 19.646921 0.500000 +2022-09-27T12:00:00.000 radio-6GHz 20.098825 0.500000 +2023-01-30T10:30:00.000 radio-6GHz 20.487108 0.500000 +2023-06-04T09:00:00.000 radio-6GHz 20.824446 0.500000 +2023-10-07T07:30:00.000 radio-6GHz 21.125451 0.500000 +2024-02-09T06:00:00.000 radio-6GHz 21.393886 0.500000 +2024-06-13T04:30:00.000 radio-6GHz 21.637560 0.500000 +2024-10-16T03:00:00.000 radio-6GHz 21.859776 0.500000 +2025-02-18T01:30:00.000 radio-6GHz 22.064226 0.500000 +2025-06-23T00:00:00.000 radio-6GHz 22.254516 0.500000 +2020-01-02T00:00:00.000 X-ray-1keV 21.209144 0.500000 +2020-04-06T04:34:17.143 X-ray-1keV 27.962903 0.500000 +2020-07-10T09:08:34.286 X-ray-1keV 29.537845 0.500000 +2020-10-13T13:42:51.429 X-ray-1keV 30.789609 0.500000 +2021-01-16T18:17:08.571 X-ray-1keV 31.828374 0.500000 +2021-04-21T22:51:25.714 X-ray-1keV 32.666384 0.500000 +2021-07-26T03:25:42.857 X-ray-1keV 33.344842 0.500000 +2021-10-29T08:00:00.000 X-ray-1keV 33.904345 0.500000 +2022-02-01T12:34:17.143 X-ray-1keV 34.378401 0.500000 +2022-05-07T17:08:34.286 X-ray-1keV 34.787830 0.500000 +2022-08-10T21:42:51.429 X-ray-1keV 35.144792 0.500000 +2022-11-14T02:17:08.571 X-ray-1keV 35.464131 0.500000 +2023-02-17T06:51:25.714 X-ray-1keV 35.747110 0.500000 +2023-05-23T11:25:42.857 X-ray-1keV 36.004242 0.500000 +2023-08-26T16:00:00.000 X-ray-1keV 36.238919 0.500000 +2023-11-29T20:34:17.143 X-ray-1keV 36.454925 0.500000 +2024-03-04T01:08:34.286 X-ray-1keV 36.654020 0.500000 +2024-06-07T05:42:51.429 X-ray-1keV 36.837070 0.500000 +2024-09-10T10:17:08.571 X-ray-1keV 37.010718 0.500000 +2024-12-14T14:51:25.714 X-ray-1keV 37.169814 0.500000 +2025-03-19T19:25:42.857 X-ray-1keV 37.321108 0.500000 +2025-06-23T00:00:00.000 X-ray-1keV 37.465122 0.500000 +2020-01-02T00:00:00.000 bessellv 15.916906 0.500000 +2020-08-11T02:40:00.000 bessellv 24.685939 0.500000 +2021-03-21T05:20:00.000 bessellv 27.115049 0.500000 +2021-10-29T08:00:00.000 bessellv 28.612108 0.500000 +2022-06-08T10:40:00.000 bessellv 29.619683 0.500000 +2023-01-16T13:20:00.000 bessellv 30.363185 0.500000 +2023-08-26T16:00:00.000 bessellv 30.946682 0.500000 +2024-04-04T18:40:00.000 bessellv 31.423505 0.500000 +2024-11-12T21:20:00.000 bessellv 31.826433 0.500000 +2025-06-23T00:00:00.000 bessellv 32.172885 0.500000 diff --git a/examples/GRB/injection_gaussian/lightcurves.png b/examples/GRB/injection_gaussian/lightcurves.png new file mode 100644 index 0000000..4fa798c Binary files /dev/null and b/examples/GRB/injection_gaussian/lightcurves.png differ diff --git a/examples/GRB/injection_gaussian/results_production.npz b/examples/GRB/injection_gaussian/results_production.npz new file mode 100644 index 0000000..5f59524 Binary files /dev/null and b/examples/GRB/injection_gaussian/results_production.npz differ diff --git a/examples/GRB/injection_gaussian/results_training.npz b/examples/GRB/injection_gaussian/results_training.npz new file mode 100644 index 0000000..00be9b3 Binary files /dev/null and b/examples/GRB/injection_gaussian/results_training.npz differ diff --git a/examples/GRB/injection_tophat.py b/examples/GRB/injection_tophat.py index fd17ed0..202fb2e 100644 --- a/examples/GRB/injection_tophat.py +++ b/examples/GRB/injection_tophat.py @@ -1,7 +1,7 @@ -"""Test run on GRB170817A data.""" +"""Injection runs with afterglowpy tophat""" import os -import jax +import jax print(f"GPU found? {jax.devices()}") import jax.numpy as jnp jax.config.update("jax_enable_x64", True) @@ -9,12 +9,12 @@ import matplotlib.pyplot as plt import corner -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +from fiesta.inference.lightcurve_model import AfterglowpyPCA from fiesta.inference.injection import InjectionRecoveryAfterglowpy from fiesta.inference.likelihood import EMLikelihood -from fiesta.inference.prior import Uniform, Composite +from fiesta.inference.prior import Uniform, CompositePrior from fiesta.inference.fiesta import Fiesta -from fiesta.utils import load_event_data +from fiesta.utils import load_event_data, write_event_data import time start_time = time.time() @@ -57,23 +57,17 @@ save=False, truth_color="red") -############# -### SETUP ### -############# - - - ############## ### MODEL ### ############## name = "tophat" -model_dir = f"../trained_models/afterglowpy/{name}/" +model_dir = f"../../flux_models/afterglowpy_{name}/model" FILTERS = ["radio-3GHz", "radio-6GHz", "X-ray-1keV", "bessellv"] -model = AfterglowpyLightcurvemodel(name, - model_dir, - filters = FILTERS) +model = AfterglowpyPCA(name, + model_dir, + filters = FILTERS) ################### @@ -81,11 +75,17 @@ ### AFTERGLOWPY ### ################### +trigger_time = 58849 # 01-01-2020 in mjd +remake_injection = False +injection_dict = {"inclination_EM": 0.32, "log10_E0": 53.79, "thetaCore": 0.1, "p": 2.47, "log10_n0": -2.1, "log10_epsilon_e": -1.326, "log10_epsilon_B": -3.89, "luminosity_distance": 40.0} + +if remake_injection: + injection = InjectionRecoveryAfterglowpy(injection_dict, jet_type = -1, filters = FILTERS, N_datapoints = 70, error_budget = 0.5, tmin = 1, tmax = 2000, trigger_time = trigger_time) + injection.create_injection() + data = injection.data + write_event_data("./injection_tophat/injection_tophat.dat", data) -injection_dict = {"inclination_EM": 1.3, "log10_E0": 52, "thetaCore": 0.2, "p": 2.5, "log10_n0": -1., "log10_epsilon_e": -1., "log10_epsilon_B": -4., "luminosity_distance": 40.0} -injection = InjectionRecoveryAfterglowpy(injection_dict, filters = FILTERS, N_datapoints = 50, error_budget = 0.5, tmin = 8, tmax = 800) -injection.create_injection() -data = injection.data +data = load_event_data("./injection_tophat/injection_tophat.dat") ############################# @@ -93,18 +93,18 @@ ############################# inclination_EM = Uniform(xmin=0.0, xmax=np.pi/2, naming=['inclination_EM']) -log10_E0 = Uniform(xmin=46.0, xmax=55.0, naming=['log10_E0']) -thetaCore = Uniform(xmin=0.01, xmax=np.pi/10, naming=['thetaCore']) -log10_n0 = Uniform(xmin=-7.0, xmax=1.0, naming=['log10_n0']) +log10_E0 = Uniform(xmin=47.0, xmax=57.0, naming=['log10_E0']) +thetaCore = Uniform(xmin=0.01, xmax=np.pi/5, naming=['thetaCore']) +log10_n0 = Uniform(xmin=-6.0, xmax=2.0, naming=['log10_n0']) p = Uniform(xmin=2.01, xmax=3.0, naming=['p']) -log10_epsilon_e = Uniform(xmin=-5.0, xmax=0.0, naming=['log10_epsilon_e']) +log10_epsilon_e = Uniform(xmin=-4.0, xmax=0.0, naming=['log10_epsilon_e']) log10_epsilon_B = Uniform(xmin=-8.0, xmax=0.0, naming=['log10_epsilon_B']) # luminosity_distance = Uniform(xmin=30.0, xmax=50.0, naming=['luminosity_distance']) prior_list = [inclination_EM, log10_E0, - thetaCore, + thetaCore, log10_n0, p, log10_epsilon_e, @@ -112,17 +112,17 @@ # luminosity_distance ] -prior = Composite(prior_list) +prior = CompositePrior(prior_list) detection_limit = None likelihood = EMLikelihood(model, data, FILTERS, - tmax = 500.0, - trigger_time=0, + tmax = 2000.0, + trigger_time=trigger_time, detection_limit = detection_limit, - fixed_params={"luminosity_distance": 40.0} -) + fixed_params={"luminosity_distance": 40.0}, + error_budget = 1e-5) ############## ### FIESTA ### @@ -133,7 +133,7 @@ local_sampler_arg = {"step_size": mass_matrix * eps} # Save for postprocessing -outdir = f"./injection_tophat/" +outdir = f"./injection_{name}/" if not os.path.exists(outdir): os.makedirs(outdir) diff --git a/examples/GRB/injection_tophat/corner.png b/examples/GRB/injection_tophat/corner.png index 9b4aa30..e8eb1bb 100644 Binary files a/examples/GRB/injection_tophat/corner.png and b/examples/GRB/injection_tophat/corner.png differ diff --git a/examples/GRB/injection_tophat/injection_tophat.dat b/examples/GRB/injection_tophat/injection_tophat.dat new file mode 100644 index 0000000..6e2f5fd --- /dev/null +++ b/examples/GRB/injection_tophat/injection_tophat.dat @@ -0,0 +1,70 @@ +2020-01-02T00:00:00.000 radio-3GHz 43.183030 0.500000 +2020-04-28T14:07:03.529 radio-3GHz 12.897802 0.500000 +2020-08-24T04:14:07.059 radio-3GHz 14.176549 0.500000 +2020-12-19T18:21:10.588 radio-3GHz 15.600486 0.500000 +2021-04-16T08:28:14.118 radio-3GHz 16.770202 0.500000 +2021-08-11T22:35:17.647 radio-3GHz 17.667325 0.500000 +2021-12-07T12:42:21.176 radio-3GHz 18.362647 0.500000 +2022-04-04T02:49:24.706 radio-3GHz 18.920391 0.500000 +2022-07-30T16:56:28.235 radio-3GHz 19.378047 0.500000 +2022-11-25T07:03:31.765 radio-3GHz 19.763130 0.500000 +2023-03-22T21:10:35.294 radio-3GHz 20.095193 0.500000 +2023-07-18T11:17:38.824 radio-3GHz 20.382696 0.500000 +2023-11-13T01:24:42.353 radio-3GHz 20.639471 0.500000 +2024-03-09T15:31:45.882 radio-3GHz 20.868518 0.500000 +2024-07-05T05:38:49.412 radio-3GHz 21.075244 0.500000 +2024-10-30T19:45:52.941 radio-3GHz 21.262475 0.500000 +2025-02-25T09:52:56.471 radio-3GHz 21.434785 0.500000 +2025-06-23T00:00:00.000 radio-3GHz 21.595849 0.500000 +2020-01-02T00:00:00.000 radio-6GHz 42.932172 0.500000 +2020-05-23T18:51:25.714 radio-6GHz 13.649171 0.500000 +2020-10-13T13:42:51.429 radio-6GHz 15.357011 0.500000 +2021-03-05T08:34:17.143 radio-6GHz 16.939701 0.500000 +2021-07-26T03:25:42.857 radio-6GHz 18.103814 0.500000 +2021-12-15T22:17:08.571 radio-6GHz 18.961353 0.500000 +2022-05-07T17:08:34.286 radio-6GHz 19.613470 0.500000 +2022-09-27T12:00:00.000 radio-6GHz 20.132481 0.500000 +2023-02-17T06:51:25.714 radio-6GHz 20.557361 0.500000 +2023-07-10T01:42:51.429 radio-6GHz 20.916358 0.500000 +2023-11-29T20:34:17.143 radio-6GHz 21.226381 0.500000 +2024-04-20T15:25:42.857 radio-6GHz 21.496434 0.500000 +2024-09-10T10:17:08.571 radio-6GHz 21.738757 0.500000 +2025-01-31T05:08:34.286 radio-6GHz 21.953505 0.500000 +2025-06-23T00:00:00.000 radio-6GHz 22.148992 0.500000 +2020-01-02T00:00:00.000 X-ray-1keV 45.244683 0.500000 +2020-04-06T04:34:17.143 X-ray-1keV 27.419029 0.500000 +2020-07-10T09:08:34.286 X-ray-1keV 28.154821 0.500000 +2020-10-13T13:42:51.429 X-ray-1keV 29.331756 0.500000 +2021-01-16T18:17:08.571 X-ray-1keV 30.434724 0.500000 +2021-04-21T22:51:25.714 X-ray-1keV 31.344310 0.500000 +2021-07-26T03:25:42.857 X-ray-1keV 32.078558 0.500000 +2021-10-29T08:00:00.000 X-ray-1keV 32.676029 0.500000 +2022-02-01T12:34:17.143 X-ray-1keV 33.170617 0.500000 +2022-05-07T17:08:34.286 X-ray-1keV 33.588215 0.500000 +2022-08-10T21:42:51.429 X-ray-1keV 33.944921 0.500000 +2022-11-14T02:17:08.571 X-ray-1keV 34.258470 0.500000 +2023-02-17T06:51:25.714 X-ray-1keV 34.532106 0.500000 +2023-05-23T11:25:42.857 X-ray-1keV 34.777604 0.500000 +2023-08-26T16:00:00.000 X-ray-1keV 34.999168 0.500000 +2023-11-29T20:34:17.143 X-ray-1keV 35.201125 0.500000 +2024-03-04T01:08:34.286 X-ray-1keV 35.385382 0.500000 +2024-06-07T05:42:51.429 X-ray-1keV 35.554281 0.500000 +2024-09-10T10:17:08.571 X-ray-1keV 35.713502 0.500000 +2024-12-14T14:51:25.714 X-ray-1keV 35.857474 0.500000 +2025-03-19T19:25:42.857 X-ray-1keV 35.994215 0.500000 +2025-06-23T00:00:00.000 X-ray-1keV 36.123736 0.500000 +2020-01-02T00:00:00.000 bessellv 38.912444 0.500000 +2020-05-23T18:51:25.714 bessellv 22.761672 0.500000 +2020-10-13T13:42:51.429 bessellv 24.469513 0.500000 +2021-03-05T08:34:17.143 bessellv 26.052203 0.500000 +2021-07-26T03:25:42.857 bessellv 27.216316 0.500000 +2021-12-15T22:17:08.571 bessellv 28.073855 0.500000 +2022-05-07T17:08:34.286 bessellv 28.725972 0.500000 +2022-09-27T12:00:00.000 bessellv 29.244983 0.500000 +2023-02-17T06:51:25.714 bessellv 29.669863 0.500000 +2023-07-10T01:42:51.429 bessellv 30.028860 0.500000 +2023-11-29T20:34:17.143 bessellv 30.338882 0.500000 +2024-04-20T15:25:42.857 bessellv 30.608936 0.500000 +2024-09-10T10:17:08.571 bessellv 30.851259 0.500000 +2025-01-31T05:08:34.286 bessellv 31.066006 0.500000 +2025-06-23T00:00:00.000 bessellv 31.261493 0.500000 diff --git a/examples/GRB/injection_tophat/lightcurves.png b/examples/GRB/injection_tophat/lightcurves.png index 3df8a4f..c139248 100644 Binary files a/examples/GRB/injection_tophat/lightcurves.png and b/examples/GRB/injection_tophat/lightcurves.png differ diff --git a/examples/GRB/injection_tophat/results_production.npz b/examples/GRB/injection_tophat/results_production.npz index 73ece35..094da02 100644 Binary files a/examples/GRB/injection_tophat/results_production.npz and b/examples/GRB/injection_tophat/results_production.npz differ diff --git a/examples/GRB/injection_tophat/results_training.npz b/examples/GRB/injection_tophat/results_training.npz index 8557168..2d5249c 100644 Binary files a/examples/GRB/injection_tophat/results_training.npz and b/examples/GRB/injection_tophat/results_training.npz differ diff --git a/examples/KN/outdir_AT2017gfo_Bu2019lm/results_production.npz b/examples/KN/outdir_AT2017gfo_Bu2019lm/results_production.npz index c8d60bd..25d81f6 100644 Binary files a/examples/KN/outdir_AT2017gfo_Bu2019lm/results_production.npz and b/examples/KN/outdir_AT2017gfo_Bu2019lm/results_production.npz differ diff --git a/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py new file mode 100644 index 0000000..4f35b0c --- /dev/null +++ b/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py @@ -0,0 +1,48 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.BenchmarkerFluxes import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.utils import Filter + + +name = "gaussian" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + +for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: + if metric_name == "$\\mathcal{L}_2$": + file_ending = "L2" + else: + file_ending = "Linf" + + B = Benchmarker(name = name, + model_dir = model_dir, + MODEL = AfterglowpyPCA, + filters = FILTERS, + metric_name = metric_name, + ) + + + for filt in FILTERS: + + fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\\alpha_w$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) + fig.savefig(f"./benchmarks/benchmark_{filt}_{file_ending}.pdf", dpi = 200) + + B.print_correlations(filter = filt) + + fig, ax = B.plot_worst_lightcurves() + fig.savefig(f"./benchmarks/worst_lightcurves_{file_ending}.pdf", dpi = 200) + + +fig, ax = B.plot_error_distribution() +fig.savefig(f"./benchmarks/error_distribution.pdf", dpi = 200) + +fig, ax = B.plot_error_over_time() +fig.savefig(f"./benchmarks/error_over_time.pdf", dpi = 200) + + + + + + diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png new file mode 100644 index 0000000..eca6b33 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png new file mode 100644 index 0000000..9c48ca3 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png new file mode 100644 index 0000000..52ebbfd Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png new file mode 100644 index 0000000..6ea54ae Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf new file mode 100644 index 0000000..5a83f0d Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf new file mode 100644 index 0000000..9feb22c Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf new file mode 100644 index 0000000..1c1ed14 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf new file mode 100644 index 0000000..6e18a1b Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf new file mode 100644 index 0000000..6b12b59 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf new file mode 100644 index 0000000..7fc0988 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf new file mode 100644 index 0000000..e06c6ed Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf new file mode 100644 index 0000000..b682087 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf new file mode 100644 index 0000000..56450fd Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf new file mode 100644 index 0000000..4da5137 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf new file mode 100644 index 0000000..3d77ef1 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf new file mode 100644 index 0000000..5b83c50 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf b/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf new file mode 100644 index 0000000..403316f Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf new file mode 100644 index 0000000..ed53697 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf new file mode 100644 index 0000000..9972a81 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png b/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png new file mode 100644 index 0000000..b82a885 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf new file mode 100644 index 0000000..8d1f376 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf new file mode 100644 index 0000000..d846900 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf differ diff --git a/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py new file mode 100644 index 0000000..6041f03 --- /dev/null +++ b/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py @@ -0,0 +1,67 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.AfterglowData import AfterglowpyData + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 # days +n_times = 200 + + +numin = 1e9 # Hz +numax = 2.5e18 # Hz (10 keV) +n_nu = 256 + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'alphaWing': (0.2, 3.5, "uniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8, 0, "uniform") +} + + + +jet_name = "gaussian" +jet_conversion = {"tophat": -1, + "gaussian": 0} + +n_training = 45_000 +n_val = 0 +n_test = 0 + +n_pool = 24 + + + +####################### +### CREATE RAW DATA ### +####################### +name = jet_name +outdir = f"./model/" + +jet_type = jet_conversion[jet_name] + + + +creator = AfterglowpyData(outdir = outdir, + jet_type = jet_type, + n_training = n_training, + n_val = n_val, + n_test = n_test, + parameter_distributions = parameter_distributions, + n_pool = n_pool, + tmin = tmin, + tmax = tmax, + n_times = n_times, + numin = numin, + numax = numax, + n_nu = n_nu) \ No newline at end of file diff --git a/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py new file mode 100644 index 0000000..74c2ca1 --- /dev/null +++ b/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py @@ -0,0 +1,64 @@ +import numpy as np +from fiesta.train.AfterglowData import AfterglowpyData + +############# +### SETUP ### +############# + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'alphaWing': (0.2, 3.5, "uniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8, 0, "uniform") +} + + + +name = "tophat" +outdir = f"./model/" + +n_training = 0 +n_val = 0 +n_test = 0 + +n_pool = 24 +size = 20_000 + + +####################### +### CREATE RAW DATA ### +####################### +creator = AfterglowpyData(outdir = outdir, + n_training = 0, + n_val = 0, + n_test = 0, + n_pool = n_pool) + +#import h5py +#with h5py.File(creator.outfile, "r+") as f: +# unproblematic = np.unique(np.where(~np.isinf(f["special_train"]["01"]["y"]))[0]) +# +# X = f["special_train"]["01"]["X"][unproblematic] +# y = f["special_train"]["01"]["y"][unproblematic] +# breakpoint() +# creator._save_to_file(X, y, group = "special_train", label = "02", comment = "log10_E0 (54, 57) log10_n0 (-6, -4) thetaCore (0.4, np.pi/5)") + + + +inclination = np.random.uniform(0, np.pi/2, size = size) +log10_E0 = np.random.uniform(54, 57, size = size) +thetaCore = np.random.uniform(0.4, np.pi/5, size= size) +alphaWing = np.random.uniform(0.2, 3.5, size = size) +log10_n0 = np.random.uniform(-6, -4, size = size) +p = np.random.uniform(2, 3, size = size) +log10_epsilon_e = np.random.uniform(-4, 0, size = size) +log10_epsilon_B = np.random.uniform(-8, 0, size = size) + +X = np.array([inclination, log10_E0, thetaCore, alphaWing, log10_n0, p, log10_epsilon_e, log10_epsilon_B]).T + +creator.create_special_data(X, label = "01", comment = "log10_E0 (54, 57) log10_n0 (-6, -4) thetaCore (0.4, np.pi/5)") \ No newline at end of file diff --git a/flux_models/afterglowpy_gaussian/model/gaussian.pkl b/flux_models/afterglowpy_gaussian/model/gaussian.pkl new file mode 100644 index 0000000..7293261 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/model/gaussian.pkl differ diff --git a/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl b/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl new file mode 100644 index 0000000..a41a1f5 Binary files /dev/null and b/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl differ diff --git a/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py new file mode 100644 index 0000000..82a3371 --- /dev/null +++ b/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py @@ -0,0 +1,107 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.FluxTrainer import PCATrainer, DataManager +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.train.neuralnets import NeuralnetConfig +from fiesta.utils import Filter + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 + + +numin = 1e9 # Hz +numax = 2.5e18 + + +n_training = 70_000 +n_val = 5000 +n_pca = 100 + +name = "gaussian" +outdir = f"./model/" +file = outdir + "afterglowpy_raw_data.h5" + +config = NeuralnetConfig(output_size=n_pca, + nb_epochs=100_000, + hidden_layer_sizes = [256, 512, 256], + learning_rate =8e-3) + + +############### +### TRAINER ### +############### + + +data_manager = DataManager(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=["02"]) + +data_manager.print_file_info() +trainer = PCATrainer(name, + outdir, + data_manager = data_manager, + plots_dir=f"./benchmarks/", + n_pca = n_pca, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowpyPCA(name, + outdir, + filters = FILTERS) + +for filt in lc_model.Filters: + with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, data_manager.mask] + + y_raw = y_raw.reshape(256, len(lc_model.times)) + y_raw = np.exp(y_raw) + y_raw = np.array([np.interp(filt.nu, lc_model.metadata["nus"], column) for column in y_raw.T]) + y_raw = -48.6 + -1 * np.log10(y_raw*1e-3 / 1e23) * 2.5 + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + y_predict = lc_model.predict(X_example)[filt.name] + + plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") + plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") + upper_bound = y_predict + 1 + lower_bound = y_predict - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/afterglowpy_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/afterglowpy_tophat/benchmark_afterglowpy_tophat.py b/flux_models/afterglowpy_tophat/benchmark_afterglowpy_tophat.py new file mode 100644 index 0000000..b46aaf9 --- /dev/null +++ b/flux_models/afterglowpy_tophat/benchmark_afterglowpy_tophat.py @@ -0,0 +1,45 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.BenchmarkerFluxes import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.utils import Filter + + +name = "tophat" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + +for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: + if metric_name == "$\\mathcal{L}_2$": + file_ending = "L2" + else: + file_ending = "Linf" + + B = Benchmarker(name = name, + model_dir = model_dir, + MODEL = AfterglowpyPCA, + filters = FILTERS, + metric_name = metric_name, + ) + + + for filt in FILTERS: + + fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) + fig.savefig(f"./benchmarks/benchmark_{filt}_{file_ending}.pdf", dpi = 200) + + B.print_correlations(filter = filt) + + fig, ax = B.plot_worst_lightcurves() + fig.savefig(f"./benchmarks/worst_lightcurves_{file_ending}.pdf", dpi = 200) + + +fig, ax = B.plot_error_distribution() +fig.savefig(f"./benchmarks/error_distribution.pdf", dpi = 200) + +fig, ax = B.plot_error_over_time() +fig.savefig(f"./benchmarks/error_over_time.pdf", dpi = 200) + + + diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png new file mode 100644 index 0000000..be608bd Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png new file mode 100644 index 0000000..2a3fdcb Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png new file mode 100644 index 0000000..ffbd76a Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png new file mode 100644 index 0000000..f9fbacc Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf new file mode 100644 index 0000000..986818e Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf new file mode 100644 index 0000000..975cb2f Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf new file mode 100644 index 0000000..a5348c1 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf new file mode 100644 index 0000000..833cec6 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf new file mode 100644 index 0000000..82cbf50 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf new file mode 100644 index 0000000..214a1ac Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf new file mode 100644 index 0000000..4d7d3d9 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf new file mode 100644 index 0000000..8180d59 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf new file mode 100644 index 0000000..8d2eb66 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf new file mode 100644 index 0000000..43f9395 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf new file mode 100644 index 0000000..89c23e9 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf new file mode 100644 index 0000000..af74378 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf new file mode 100644 index 0000000..ec5718c Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf new file mode 100644 index 0000000..83213b9 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf new file mode 100644 index 0000000..7030bd8 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png b/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png new file mode 100644 index 0000000..36ea0fe Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf new file mode 100644 index 0000000..506ca89 Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf new file mode 100644 index 0000000..7cfd65b Binary files /dev/null and b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf differ diff --git a/flux_models/afterglowpy_tophat/create_data_afterglowpy_tophat.py b/flux_models/afterglowpy_tophat/create_data_afterglowpy_tophat.py new file mode 100644 index 0000000..8003d74 --- /dev/null +++ b/flux_models/afterglowpy_tophat/create_data_afterglowpy_tophat.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.AfterglowData import AfterglowpyData + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 # days +n_times = 200 + + +numin = 1e9 # Hz +numax = 2.5e18 # Hz (10keV) +n_nu = 256 + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8, 0, "uniform") +} + + + +jet_name = "tophat" +jet_conversion = {"tophat": -1, + "gaussian": 0} + +n_training = 0 +n_val = 0 +n_test = 0 + +n_pool = 1 + + + +####################### +### CREATE RAW DATA ### +####################### +name = jet_name +outdir = f"./model/" + +jet_type = jet_conversion[jet_name] + +creator = AfterglowpyData(outdir = outdir, + jet_type = jet_type, + n_training = n_training, + n_val = n_val, + n_test = n_test, + parameter_distributions = parameter_distributions, + n_pool = n_pool, + tmin = tmin, + tmax = tmax, + n_times = n_times, + numin = numin, + numax = numax, + n_nu = n_nu) diff --git a/flux_models/afterglowpy_tophat/create_special_data_afterglowpy_tophat.py b/flux_models/afterglowpy_tophat/create_special_data_afterglowpy_tophat.py new file mode 100644 index 0000000..87e05e6 --- /dev/null +++ b/flux_models/afterglowpy_tophat/create_special_data_afterglowpy_tophat.py @@ -0,0 +1,52 @@ +import numpy as np +from fiesta.train.AfterglowData import AfterglowpyData + +############# +### SETUP ### +############# + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8, 0, "uniform") +} + + + +name = "tophat" +outdir = f"./model/" + +n_training = 0 +n_val = 0 +n_test = 0 + +n_pool = 24 +size = 5000 + + +####################### +### CREATE RAW DATA ### +####################### +creator = AfterglowpyData(outdir = outdir, + n_training = 0, + n_val = 0, + n_test = 0, + n_pool = n_pool) + + +inclination = np.random.uniform(0, np.pi/2, size = size) +log10_E0 = np.random.uniform(54, 57, size = size) +thetaCore = np.random.uniform(0.01, np.pi/5, size= size) +log10_n0 = np.random.uniform(-6, 2, size = size) +p = np.random.uniform(2,3,size = size) +log10_epsilon_e = np.random.uniform(-3, 0, size = size) +log10_epsilon_B = np.random.uniform(-3, 0, size = size) + +X = np.array([inclination, log10_E0, thetaCore, log10_n0, p, log10_epsilon_e, log10_epsilon_B]).T + +creator.create_special_data(X, label = "02", comment = "higher E0, epsilon_e, epsilon_B") \ No newline at end of file diff --git a/flux_models/afterglowpy_tophat/model/tophat.pkl b/flux_models/afterglowpy_tophat/model/tophat.pkl new file mode 100644 index 0000000..1f1f1c5 Binary files /dev/null and b/flux_models/afterglowpy_tophat/model/tophat.pkl differ diff --git a/flux_models/afterglowpy_tophat/model/tophat_metadata.pkl b/flux_models/afterglowpy_tophat/model/tophat_metadata.pkl new file mode 100644 index 0000000..59c785d Binary files /dev/null and b/flux_models/afterglowpy_tophat/model/tophat_metadata.pkl differ diff --git a/flux_models/afterglowpy_tophat/train_afterglowpy_tophat.py b/flux_models/afterglowpy_tophat/train_afterglowpy_tophat.py new file mode 100644 index 0000000..da6f48f --- /dev/null +++ b/flux_models/afterglowpy_tophat/train_afterglowpy_tophat.py @@ -0,0 +1,106 @@ +import numpy as np +import matplotlib.pyplot as plt +from scipy import stats +import h5py + +from fiesta.train.FluxTrainer import PCATrainer, DataManager +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.train.neuralnets import NeuralnetConfig +from fiesta.utils import Filter + + +############# +### SETUP ### +############# + +tmin = 1 # days +tmax = 2000 + +numin = 1e9 # Hz +numax = 2.5e18 + +n_training = 30_000 +n_val = 5000 +n_pca = 100 + +name = "tophat" +outdir = f"./model/" +file = outdir + "afterglowpy_raw_data.h5" + + +config = NeuralnetConfig(output_size=n_pca, + nb_epochs=100_000, + hidden_layer_sizes = [256, 512, 256], + learning_rate =8e-3) + + +############### +### TRAINER ### +############### + + +data_manager = DataManager(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=["02"]) + +trainer = PCATrainer(name, + outdir, + data_manager = data_manager, + plots_dir=f"./benchmarks/", + n_pca = n_pca, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowpyPCA(name, + outdir, + filters = FILTERS) + +for filt in lc_model.Filters: + with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, data_manager.mask] + + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + y_raw = np.exp(y_raw) + y_raw = np.array([np.interp(filt.nu, lc_model.metadata["nus"], column) for column in y_raw.T]) + y_raw = -48.6 + -1 * np.log10(y_raw*1e-3 / 1e23) * 2.5 + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + y_predict = lc_model.predict(X_example)[filt.name] + + plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") + plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") + upper_bound = y_predict + 1 + lower_bound = y_predict - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/afterglowpy_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py new file mode 100644 index 0000000..95a5e76 --- /dev/null +++ b/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py @@ -0,0 +1,71 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.AfterglowData import PyblastafterglowData +from mpi4py import MPI +comm = MPI.COMM_WORLD +size = comm.Get_size() +rank = comm.Get_rank() + + +############# +### SETUP ### +############# + +tmin = 0.01 # days +tmax = 2000 +n_times = 250 + + +numin = 1e9 # Hz +numax = 2.5e19 # Hz (100 keV) +n_nu = 256 + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'alphaWing': (0.2, 3.5, "uniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8,0, "uniform"), + 'Gamma0': (100, 1000, "uniform") +} + + + +jet_name = "gaussian" +jet_conversion = {"tophat": -1, + "gaussian": 0} + +n_training = 10 +n_val = 2 +n_test = 2 + + +####################### +### CREATE RAW DATA ### +####################### +name = jet_name +outdir = f"./model/" + +jet_type = jet_conversion[jet_name] + + + +creator = PyblastafterglowData(outdir = outdir, + jet_type = jet_type, + n_training = n_training, + n_val = n_val, + n_test = n_test, + tmin = tmin, + tmax = tmax, + n_times = n_times, + numin = numin, + numax = numax, + n_nu = n_nu, + parameter_distributions = parameter_distributions, + rank = rank, + path_to_exec = "/home/aya/work/hkoehn/fiesta/PyBlastAfterglowMag/src/pba.out") \ No newline at end of file diff --git a/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py new file mode 100644 index 0000000..69f3af3 --- /dev/null +++ b/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py @@ -0,0 +1,45 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.BenchmarkerFluxes import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.utils import Filter + + +name = "tophat" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv"] + +for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: + if metric_name == "$\\mathcal{L}_2$": + file_ending = "L2" + else: + file_ending = "Linf" + + B = Benchmarker(name = name, + model_dir = model_dir, + MODEL = AfterglowpyPCA, + filters = FILTERS, + metric_name = metric_name, + ) + + + for filt in FILTERS: + + fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$", "$\\Gamma_0$"]) + fig.savefig(f"./benchmarks/benchmark_{filt}_{file_ending}.pdf", dpi = 200) + + B.print_correlations(filter = filt) + + fig, ax = B.plot_worst_lightcurves() + fig.savefig(f"./benchmarks/worst_lightcurves_{file_ending}.pdf", dpi = 200) + + +fig, ax = B.plot_error_distribution() +fig.savefig(f"./benchmarks/error_distribution.pdf", dpi = 200) + +fig, ax = B.plot_error_over_time() +fig.savefig(f"./benchmarks/error_over_time.pdf", dpi = 200) + + + diff --git a/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py new file mode 100644 index 0000000..f6d2a64 --- /dev/null +++ b/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py @@ -0,0 +1,72 @@ +import numpy as np + +from fiesta.train.AfterglowData import PyblastafterglowData +from mpi4py import MPI +comm = MPI.COMM_WORLD +size = comm.Get_size() +rank = comm.Get_rank() + + +############# +### SETUP ### +############# + +tmin = 0.01 # days +tmax = 2000 +n_times = 250 + + +numin = 1e9 # Hz +numax = 2.5e19 # Hz (100 keV) +n_nu = 256 + + +parameter_distributions = { + 'inclination_EM': (0, np.pi/2, "uniform"), + 'log10_E0': (47, 57, "uniform"), + 'thetaCore': (0.01, np.pi/5, "loguniform"), + 'log10_n0': (-6, 2, "uniform"), + 'p': (2.01, 3, "uniform"), + 'log10_epsilon_e': (-4, 0, "uniform"), + 'log10_epsilon_B': (-8,0, "uniform"), + 'Gamma0': (100, 1000, "uniform") +} + + + +jet_name = "tophat" +jet_conversion = {"tophat": -1, + "gaussian": 0, + "powerlaw": 4} + +n_training = 100 +n_val = 10 +n_test = 10 + +retrain_weights = None + + +####################### +### CREATE RAW DATA ### +####################### +name = jet_name +outdir = f"./model/" + +jet_type = jet_conversion[jet_name] + + + +creator = PyblastafterglowData(outdir = outdir, + jet_type = jet_type, + n_training = n_training, + n_val = n_val, + n_test = n_test, + tmin = tmin, + tmax = tmax, + n_times = n_times, + numin = numin, + numax = numax, + n_nu = n_nu, + parameter_distributions = parameter_distributions, + rank = rank, + path_to_exec = "/hppfs/scratch/06/di35kuf/pyblastafterglow/PyBlastAfterglowMag/src/pba.out") diff --git a/flux_models/pyblastafterglow_tophat/join_h5files.py b/flux_models/pyblastafterglow_tophat/join_h5files.py new file mode 100644 index 0000000..881db49 --- /dev/null +++ b/flux_models/pyblastafterglow_tophat/join_h5files.py @@ -0,0 +1,37 @@ +import os +import h5py +import numpy as np +import shutil +import tqdm + +##### + +directory = "./model" + +##### + +outfile = os.path.join(directory, "pyblastafterglow_raw_data.h5") +file_list = [f for f in os.listdir(directory) if f.endswith(".h5")] +shutil.copy(os.path.join(directory, file_list[0]), outfile) + +with h5py.File(outfile, "a") as f: + + for file in tqdm.tqdm(file_list[1:]): + file = h5py.File(os.path.join(directory, file)) + for group in ["train", "val", "test"]: + X = file[group]["X"] + Xset = f[group]["X"] + Xset.resize(Xset.shape[0]+X.shape[0], axis = 0) + Xset[-X.shape[0]:] = X + + y = file[group]["y"] + yset = f[group]["y"] + yset.resize(yset.shape[0]+y.shape[0], axis = 0) + yset[-y.shape[0]:] = y + file.close() + + print("train: ", f["train"]["y"].shape[0]) + print("val: ", f["val"]["y"].shape[0]) + print("test: ", f["test"]["y"].shape[0]) + + diff --git a/flux_models/pyblastafterglow_tophat/plot_training_data_distribution.py b/flux_models/pyblastafterglow_tophat/plot_training_data_distribution.py new file mode 100644 index 0000000..8c5f390 --- /dev/null +++ b/flux_models/pyblastafterglow_tophat/plot_training_data_distribution.py @@ -0,0 +1,36 @@ +import numpy as np +import corner +import matplotlib.pyplot as plt +import h5py + +default_corner_kwargs = dict(bins=40, + smooth=1., + label_kwargs=dict(fontsize=16), + title_kwargs=dict(fontsize=16), + #color = "blue" + # quantiles=[], + # levels=[0.9], + plot_density=False, + plot_datapoints=True, + plot_contours = False, + fill_contours=False, + max_n_ticks=4, + min_n_ticks=3, + save=False, + truth_color="red") + +####### + +file = "./model/pyblastafterglow_raw_data.h5" + +####### + + +with h5py.File(file, "r") as f: + X = f["train"]["X"][:] + parameter_names = f["parameter_names"][()] + + +corner.corner(X, **default_corner_kwargs, labels = parameter_names) +plt.show() +breakpoint() \ No newline at end of file diff --git a/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py new file mode 100644 index 0000000..978c919 --- /dev/null +++ b/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py @@ -0,0 +1,101 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.FluxTrainer import PCATrainer, DataManager +from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.train.neuralnets import NeuralnetConfig +from fiesta.utils import Filter + +############# +### SETUP ### +############# + +tmin = 1 # days +tmax = 2000 + + +numin = 1e9 # Hz +numax = 1e17 + +n_training = 50_000 +n_val = 5000 +n_pca = 100 + +name = "tophat" +outdir = f"./model/" +file = outdir + "pyblastafterglow_raw_data.h5" + +config = NeuralnetConfig(output_size=n_pca, + nb_epochs=100_000, + hidden_layer_sizes = [256, 512, 256], + learning_rate =8e-3) + +############### +### TRAINER ### +############### + + +data_manager = DataManager(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=[]) + +trainer = PCATrainer(name, + outdir, + data_manager = data_manager, + plots_dir=f"./benchmarks/", + n_pca = n_pca, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowpyPCA(name, + outdir, + filters = FILTERS) + +for filt in lc_model.Filters: + with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, data_manager.mask] + + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + y_raw = np.exp(y_raw) + y_raw = np.array([np.interp(filt.nu, lc_model.metadata["nus"], column) for column in y_raw.T]) + y_raw = -48.6 + -1 * np.log10(y_raw*1e-3 / 1e23) * 2.5 + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + y_predict = lc_model.predict(X_example)[filt.name] + + plt.plot(lc_model.times, y_raw, color = "red", label="pyblast_afterglow") + plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") + upper_bound = y_predict + 1 + lower_bound = y_predict - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.legend() + plt.gca().invert_yaxis() + + plt.savefig(f"./benchmarks/pyblastafterglow_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/transfer_data_from_pkl_to_h5.py b/flux_models/transfer_data_from_pkl_to_h5.py new file mode 100644 index 0000000..6a47be5 --- /dev/null +++ b/flux_models/transfer_data_from_pkl_to_h5.py @@ -0,0 +1,44 @@ +import numpy as np +import pickle +import h5py +import os +from prompt_toolkit import prompt + + +filein = "./afterglowpy/gaussian/afterglowpy_raw_data.pkl" +fileout = "./afterglowpy/gaussian/afterglowpy_raw_data.h5" + +with open(filein, "rb") as file: + data = pickle.load(file) + jet_type = data["jet_type"] + times = data["times"] + nus = data["nus"] + parameter_distributions = data["parameter_distributions"] + parameter_names = data["parameter_names"] + train_X_raw, train_y_raw = data["train_X_raw"], data["train_y_raw"] + val_X_raw, val_y_raw = data["val_X_raw"], data["val_y_raw"] + test_X_raw, test_y_raw = data["test_X_raw"], data["test_y_raw"] + +print("Loaded data from file.") +if os.path.exists(fileout): + user_input = prompt("Warning, will overwrite existing h5 file. Continue?") + user_input = user_input.strip().lower() + if user_input not in ["y", "yes"]: + exit() + +with h5py.File(fileout, "w") as f: + f.create_dataset("times", data = times) + f.create_dataset("nus", data = nus) + f.create_dataset("parameter_names", data = parameter_names) + f.create_dataset("parameter_distributions", data = str(parameter_distributions)) + f.create_dataset("jet_type", data = jet_type) + f.create_group("train"); f.create_group("val"); f.create_group("test"); f.create_group("special_train") + + f["train"].create_dataset("X", data = train_X_raw, maxshape=(None, len(parameter_names)), chunks = (1000, len(parameter_names))) + f["train"].create_dataset("y", data = train_y_raw, maxshape=(None, len(times)*len(nus)), chunks = (1000, len(times)*len(nus))) + + f["val"].create_dataset("X", data = val_X_raw, maxshape=(None, len(parameter_names)), chunks = (1000, len(parameter_names))) + f["val"].create_dataset("y", data = val_y_raw, maxshape=(None, len(times)*len(nus)), chunks = (1000, len(times)*len(nus))) + + f["test"].create_dataset("X", data = test_X_raw, maxshape=(None, len(parameter_names)), chunks = (1000, len(parameter_names))) + f["test"].create_dataset("y", data = test_y_raw, maxshape=(None, len(times)*len(nus)), chunks = (1000, len(times)*len(nus))) \ No newline at end of file diff --git a/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl b/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl new file mode 100644 index 0000000..73a4949 Binary files /dev/null and b/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl b/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl new file mode 100644 index 0000000..bc6b620 Binary files /dev/null and b/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl b/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl new file mode 100644 index 0000000..ae9d474 Binary files /dev/null and b/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl b/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl new file mode 100644 index 0000000..5d14206 Binary files /dev/null and b/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl b/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl new file mode 100644 index 0000000..ca8e7dc Binary files /dev/null and b/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl differ diff --git a/trained_models/GRB/train_afterglowpy_tophat.py b/lightcurve_models/GRB/train_afterglowpy_tophat.py similarity index 76% rename from trained_models/GRB/train_afterglowpy_tophat.py rename to lightcurve_models/GRB/train_afterglowpy_tophat.py index 3af5a8b..3ec2de3 100644 --- a/trained_models/GRB/train_afterglowpy_tophat.py +++ b/lightcurve_models/GRB/train_afterglowpy_tophat.py @@ -2,6 +2,7 @@ import matplotlib.pyplot as plt from fiesta.train.SurrogateTrainer import AfterglowpyTrainer +from fiesta.train.Benchmarker import Benchmarker from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel from fiesta.train.neuralnets import NeuralnetConfig from fiesta.utils import Filter @@ -44,15 +45,16 @@ """ -FILTERS = ["radio-3GHz", "radio-6GHz"] +FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"] +FILTERS = ["radio-6GHz"] parameter_grid = { - 'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2], - 'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.], - 'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10], - 'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0], - 'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0], - 'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0], - 'log10_epsilon_B': [-8, -6, -4, -2., -1., 0] + 'inclination_EM': np.linspace(0, np.pi/4, 12), + 'log10_E0': np.linspace(47, 56, 19), + 'thetaCore': np.logspace(-2, np.log10(np.pi/5), 12), + 'log10_n0': np.linspace(-6, 2, 17), + 'p': np.linspace(2.01, 3.0, 10), + 'log10_epsilon_e': np.linspace(-4, 0, 9), + 'log10_epsilon_B': np.linspace(-8, 0, 9) } @@ -68,17 +70,33 @@ ### TRAINER ### ############### +B = Benchmarker(name = name, + parameter_grid = parameter_grid, + model_dir = outdir, + filters = FILTERS, + n_test_data = 2000, + metric_name = "$\\mathcal{L}_\infty$", + remake_test_data = False, + jet_type = jet_conversion[jet_name], + ) + +ww = B.error_distribution + +weight_grids = ww["radio-6GHz"] + # TODO: perhaps also want to train on the full LC, without the SVD? # TODO: train to output flux, not the mag? trainer = AfterglowpyTrainer(name, outdir, FILTERS, parameter_grid, + weight_grids = weight_grids, jet_type = jet_conversion[jet_name], tmin = tmin, tmax = tmax, - plots_dir="./figures/", - svd_ncoeff=40, + use_log_spacing = True, + plots_dir=f"./benchmarks/{name}", + svd_ncoeff=30, save_raw_data=True, save_preprocessed_data=True, remake_training_data = True, @@ -127,5 +145,5 @@ plt.legend() plt.gca().invert_yaxis() - plt.savefig(f"./figures/afterglowpy_{name}_{filt}_example.png") + plt.savefig(f"./benchmarks/{name}/afterglowpy_{name}_{filt}_example.png") plt.close() \ No newline at end of file diff --git a/trained_models/KN/Bu2019lm/.gitignore b/lightcurve_models/KN/Bu2019lm/.gitignore similarity index 100% rename from trained_models/KN/Bu2019lm/.gitignore rename to lightcurve_models/KN/Bu2019lm/.gitignore diff --git a/trained_models/KN/Bu2019lm/2massh.pkl b/lightcurve_models/KN/Bu2019lm/2massh.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/2massh.pkl rename to lightcurve_models/KN/Bu2019lm/2massh.pkl diff --git a/trained_models/KN/Bu2019lm/2massj.pkl b/lightcurve_models/KN/Bu2019lm/2massj.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/2massj.pkl rename to lightcurve_models/KN/Bu2019lm/2massj.pkl diff --git a/trained_models/KN/Bu2019lm/2massks.pkl b/lightcurve_models/KN/Bu2019lm/2massks.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/2massks.pkl rename to lightcurve_models/KN/Bu2019lm/2massks.pkl diff --git a/trained_models/KN/Bu2019lm/Bu2019lm_metadata.pkl b/lightcurve_models/KN/Bu2019lm/Bu2019lm_metadata.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/Bu2019lm_metadata.pkl rename to lightcurve_models/KN/Bu2019lm/Bu2019lm_metadata.pkl diff --git a/trained_models/KN/Bu2019lm/ps1__g.pkl b/lightcurve_models/KN/Bu2019lm/ps1__g.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/ps1__g.pkl rename to lightcurve_models/KN/Bu2019lm/ps1__g.pkl diff --git a/trained_models/KN/Bu2019lm/ps1__i.pkl b/lightcurve_models/KN/Bu2019lm/ps1__i.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/ps1__i.pkl rename to lightcurve_models/KN/Bu2019lm/ps1__i.pkl diff --git a/trained_models/KN/Bu2019lm/ps1__r.pkl b/lightcurve_models/KN/Bu2019lm/ps1__r.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/ps1__r.pkl rename to lightcurve_models/KN/Bu2019lm/ps1__r.pkl diff --git a/trained_models/KN/Bu2019lm/ps1__y.pkl b/lightcurve_models/KN/Bu2019lm/ps1__y.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/ps1__y.pkl rename to lightcurve_models/KN/Bu2019lm/ps1__y.pkl diff --git a/trained_models/KN/Bu2019lm/ps1__z.pkl b/lightcurve_models/KN/Bu2019lm/ps1__z.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/ps1__z.pkl rename to lightcurve_models/KN/Bu2019lm/ps1__z.pkl diff --git a/trained_models/KN/Bu2019lm/sdssu.pkl b/lightcurve_models/KN/Bu2019lm/sdssu.pkl similarity index 100% rename from trained_models/KN/Bu2019lm/sdssu.pkl rename to lightcurve_models/KN/Bu2019lm/sdssu.pkl diff --git a/trained_models/KN/train_Bu2019lm.py b/lightcurve_models/KN/train_Bu2019lm.py similarity index 100% rename from trained_models/KN/train_Bu2019lm.py rename to lightcurve_models/KN/train_Bu2019lm.py diff --git a/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl b/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl new file mode 100644 index 0000000..767b6f5 Binary files /dev/null and b/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl differ diff --git a/lightcurve_models/afterglowpy/gaussian/bessellv.pkl b/lightcurve_models/afterglowpy/gaussian/bessellv.pkl new file mode 100644 index 0000000..6e6b36a Binary files /dev/null and b/lightcurve_models/afterglowpy/gaussian/bessellv.pkl differ diff --git a/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl b/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl new file mode 100644 index 0000000..73e0bd1 Binary files /dev/null and b/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl differ diff --git a/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl b/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl new file mode 100644 index 0000000..5e6f67b Binary files /dev/null and b/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl differ diff --git a/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl b/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl new file mode 100644 index 0000000..6058076 Binary files /dev/null and b/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl differ diff --git a/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl b/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl new file mode 100644 index 0000000..73a4949 Binary files /dev/null and b/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl differ diff --git a/lightcurve_models/afterglowpy/tophat/bessellv.pkl b/lightcurve_models/afterglowpy/tophat/bessellv.pkl new file mode 100644 index 0000000..bc6b620 Binary files /dev/null and b/lightcurve_models/afterglowpy/tophat/bessellv.pkl differ diff --git a/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl b/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl new file mode 100644 index 0000000..ae9d474 Binary files /dev/null and b/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl differ diff --git a/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl b/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl new file mode 100644 index 0000000..5d14206 Binary files /dev/null and b/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl differ diff --git a/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl b/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl new file mode 100644 index 0000000..ca8e7dc Binary files /dev/null and b/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl differ diff --git a/lightcurve_models/benchmark_afterglowpy_gaussian.py b/lightcurve_models/benchmark_afterglowpy_gaussian.py new file mode 100644 index 0000000..14b2a74 --- /dev/null +++ b/lightcurve_models/benchmark_afterglowpy_gaussian.py @@ -0,0 +1,64 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +from fiesta.utils import Filter + + +name = "gaussian" +model_dir = f"./afterglowpy/{name}/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + +parameter_grid = { + 'inclination_EM': np.linspace(0, np.pi/4, 12), + 'log10_E0': np.linspace(47, 56, 19), + 'thetaWing': np.linspace(0.01, np.pi/5, 12), + 'xCore': np.linspace(0.05, 1, 20), + 'log10_n0': np.linspace(-6, 2, 17), + 'p': np.linspace(2.01, 3.0, 10), + 'log10_epsilon_e': np.linspace(-4, 0, 9), + 'log10_epsilon_B': np.linspace(-8, 0, 9) +} + +for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: + if metric_name == "$\\mathcal{L}_2$": + file_ending = "L2" + else: + file_ending = "Linf" + + B = Benchmarker(name = name, + parameter_grid = parameter_grid, + model_dir = model_dir, + MODEL = AfterglowpyLightcurvemodel, + filters = FILTERS, + n_test_data = 2000, + metric_name = metric_name, + remake_test_data = False, + jet_type = 0, + ) + + fig, ax = B.plot_error_distribution("radio-6GHz") + + + for filt in FILTERS: + + fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{w}}$", "$x_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) + fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200) + + B.print_correlations(filter = filt) + + + if metric_name == "$\\mathcal{L}_\infty$": + fig, ax = B.plot_error_distribution(filt) + fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200) + + + fig, ax = B.plot_worst_lightcurves() + fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200) + + + + + + diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png new file mode 100644 index 0000000..c21c817 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png new file mode 100644 index 0000000..306d201 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png new file mode 100644 index 0000000..c9fc842 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png new file mode 100644 index 0000000..b5cd9b8 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf new file mode 100644 index 0000000..48a6150 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf new file mode 100644 index 0000000..03d5c63 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf new file mode 100644 index 0000000..3d29635 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf new file mode 100644 index 0000000..85eba89 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf new file mode 100644 index 0000000..380e869 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf new file mode 100644 index 0000000..51b7614 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf new file mode 100644 index 0000000..0b859a3 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf new file mode 100644 index 0000000..3723121 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf new file mode 100644 index 0000000..1f15167 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf new file mode 100644 index 0000000..26e7d08 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf new file mode 100644 index 0000000..5be4ef7 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/correlations_before.txt b/lightcurve_models/benchmarks/gaussian/correlations_before.txt new file mode 100644 index 0000000..9566b17 --- /dev/null +++ b/lightcurve_models/benchmarks/gaussian/correlations_before.txt @@ -0,0 +1,108 @@ +Correlations for filter radio-3GHz: + +inclination_EM: 0.13648909886212593 +log10_E0: 0.33645145631212375 +thetaWing: 0.024265589755934246 +xCore: -0.0998657462258901 +log10_n0: -0.18244632046614845 +p: -0.12178999832348425 +log10_epsilon_e: 0.028199807808143303 +log10_epsilon_B: 0.11382668130975754 + + + +Correlations for filter radio-6GHz: + +inclination_EM: 0.07900887361655033 +log10_E0: 0.20617615787512814 +thetaWing: 0.0061923518520167874 +xCore: -0.1609517797089711 +log10_n0: -0.17330437332816612 +p: 0.06921641433003754 +log10_epsilon_e: 0.0010812683989245774 +log10_epsilon_B: 0.03164077022449176 + + + +Correlations for filter bessellv: + +inclination_EM: 0.10876851455212108 +log10_E0: 0.31350345293188253 +thetaWing: 0.06548400328241566 +xCore: -0.12897435382322575 +log10_n0: -0.24090642768022807 +p: -0.15163486371234677 +log10_epsilon_e: 0.051617546936042646 +log10_epsilon_B: 0.06200398370473675 + + + +Correlations for filter X-ray-1keV: + +inclination_EM: 0.0937755880248725 +log10_E0: 0.20646061403424484 +thetaWing: 0.029171682247324363 +xCore: -0.13630793956135479 +log10_n0: -0.11379719745935056 +p: -0.19303802440344964 +log10_epsilon_e: 0.04816653484776429 +log10_epsilon_B: 0.031864402035493226 + + + + + +Loaded SurrogateLightcurveModel with filters ['radio-3GHz', 'radio-6GHz', 'bessellv', 'X-ray-1keV'] + + + +Correlations for filter radio-3GHz: + +inclination_EM: 0.22729615522584115 +log10_E0: 0.2854912376107409 +thetaWing: -0.0655575192504771 +xCore: -0.1511753578064876 +log10_n0: -0.14326622677000614 +p: 0.003718908215404851 +log10_epsilon_e: 0.09145017016387194 +log10_epsilon_B: 0.13900251196367636 + + + +Correlations for filter radio-6GHz: + +inclination_EM: 0.15482259690303216 +log10_E0: 0.31706796584780933 +thetaWing: -0.04603719649371961 +xCore: -0.20588033447258053 +log10_n0: -0.19114656242441438 +p: 0.0195090896910469 +log10_epsilon_e: 0.08416371935468658 +log10_epsilon_B: 0.10182239506588713 + + + +Correlations for filter bessellv: + +inclination_EM: 0.131150012020771 +log10_E0: 0.26739763810111317 +thetaWing: -0.03328134154282784 +xCore: -0.20478505987062182 +log10_n0: -0.2029206212123128 +p: -0.053992018841870645 +log10_epsilon_e: 0.014003969294619653 +log10_epsilon_B: 0.0470155844120538 + + + +Correlations for filter X-ray-1keV: + +inclination_EM: 0.19483874480498534 +log10_E0: 0.23606543897286883 +thetaWing: -0.04549081935294193 +xCore: -0.2464464697288298 +log10_n0: -0.1599685189157919 +p: -0.10176554058762216 +log10_epsilon_e: 0.008516304543409538 +log10_epsilon_B: 0.05139817839737394 + diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf new file mode 100644 index 0000000..4f4b5f7 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf new file mode 100644 index 0000000..95fd2d6 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf new file mode 100644 index 0000000..5c373b0 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf new file mode 100644 index 0000000..dd7d55b Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png b/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png new file mode 100644 index 0000000..a15ecd3 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png b/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png new file mode 100644 index 0000000..0c216c4 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png new file mode 100644 index 0000000..861db19 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png new file mode 100644 index 0000000..c9a5622 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png differ diff --git a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf new file mode 100644 index 0000000..2974fc1 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf differ diff --git a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf new file mode 100644 index 0000000..0cee2b7 Binary files /dev/null and b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png new file mode 100644 index 0000000..d7ee02f Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png new file mode 100644 index 0000000..3d72a42 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png new file mode 100644 index 0000000..15cb698 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png new file mode 100644 index 0000000..443366e Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf new file mode 100644 index 0000000..b73ccc7 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf new file mode 100644 index 0000000..bcbfffb Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf new file mode 100644 index 0000000..6b44e71 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf new file mode 100644 index 0000000..6101d70 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf new file mode 100644 index 0000000..8835d68 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf new file mode 100644 index 0000000..65edfbf Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf new file mode 100644 index 0000000..35f3c78 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf new file mode 100644 index 0000000..4feced3 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf new file mode 100644 index 0000000..3a25d99 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf new file mode 100644 index 0000000..a2caf84 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf b/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf new file mode 100644 index 0000000..50b896f Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf differ diff --git a/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf b/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf new file mode 100644 index 0000000..9a07ac8 Binary files /dev/null and b/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf differ diff --git a/lightcurve_models/train_afterglowpy_gaussian.py b/lightcurve_models/train_afterglowpy_gaussian.py new file mode 100644 index 0000000..4d7c984 --- /dev/null +++ b/lightcurve_models/train_afterglowpy_gaussian.py @@ -0,0 +1,146 @@ +import numpy as np +import matplotlib.pyplot as plt + +from fiesta.train.SurrogateTrainer import AfterglowpyTrainer +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +from fiesta.train.neuralnets import NeuralnetConfig +from fiesta.utils import Filter + +############# +### SETUP ### +############# + +FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"] +for filter in FILTERS: + filter = Filter(filter) + print(filter.name, filter.nu) + +tmin = 1 +tmax = 1000 +""" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] +parameter_grid = { + 'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2], + 'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.], + 'thetaWing': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10], + 'xCore': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0], + 'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0], + 'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0], + 'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0], + 'log10_epsilon_B': [-8, -6, -4, -2., -1., 0] +} +""" + +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] +parameter_grid = { + 'inclination_EM': np.linspace(0, np.pi/4, 12), + 'log10_E0': np.linspace(47, 56, 19), + 'thetaWing': np.logspace(-2, np.log10(np.pi/5), 12), + 'xCore': np.linspace(0.05, 1, 20), + 'log10_n0': np.linspace(-6, 2, 17), + 'p': np.linspace(2.01, 3.0, 10), + 'log10_epsilon_e': np.linspace(-4, 0, 9), + 'log10_epsilon_B': np.linspace(-8, 0, 9) +} + + + + + +jet_name = "gaussian" +jet_conversion = {"tophat": -1, + "gaussian": 0, + "powerlaw": 4} + +name = "gaussian" +outdir = f"./afterglowpy/{name}/" + +############### +### TRAINER ### +############### + + +# Benchmarker to get weights for the training data + +B = Benchmarker(name = name, + parameter_grid = parameter_grid, + model_dir = outdir, + filters = FILTERS, + n_test_data = 2000, + metric_name = "$\\mathcal{L}_\infty$", + remake_test_data = False, + jet_type = jet_conversion[jet_name], + ) + +ww = B.error_distribution + + +weight_grids = ww["X-ray-1keV"] + +#for p in parameter_grid.keys(): +# weight_grids[p] = np.average([ww[filt][p] for filt in FILTERS], axis = 0) + +# TODO: perhaps also want to train on the full LC, without the SVD? +# TODO: train to output flux, not the mag? +trainer = AfterglowpyTrainer(name, + outdir, + FILTERS, + parameter_grid, + weight_grids = weight_grids, + jet_type = jet_conversion[jet_name], + tmin = tmin, + tmax = tmax, + use_log_spacing = True, + plots_dir=f"./benchmarks/{name}", + svd_ncoeff=30, + save_raw_data=False, + save_preprocessed_data=False, + remake_training_data = False, + n_training_data = 10_000 + ) + +############### +### FITTING ### +############### + +config = NeuralnetConfig(output_size=trainer.svd_ncoeff, + nb_epochs=50_000, + hidden_layer_sizes = [64, 128, 64], + learning_rate =8e-3) + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") + +lc_model = AfterglowpyLightcurvemodel(name, + outdir, + filters = FILTERS) + +for filt in lc_model.filters: + X_example = trainer.val_X_raw[0] + y_raw = trainer.val_y_raw[filt][0] + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + y_predict = lc_model.predict(X_example)[filt] + + plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") + plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") + upper_bound = y_predict + 1 + lower_bound = y_predict - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt}") + plt.legend() + plt.gca().invert_yaxis() + + plt.savefig(f"./benchmarks/{name}/afterglowpy_{name}_{filt}_example.png") + plt.close() \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index d1e24e6..c2d4a9e 100644 --- a/setup.cfg +++ b/setup.cfg @@ -28,6 +28,7 @@ install_requires = sncosmo flowMC joblib + h5py python_requires = >=3.10 diff --git a/src/fiesta.egg-info/PKG-INFO b/src/fiesta.egg-info/PKG-INFO index 05d5959..da026b0 100644 --- a/src/fiesta.egg-info/PKG-INFO +++ b/src/fiesta.egg-info/PKG-INFO @@ -52,3 +52,7 @@ To train your own surrogate models, have a look at some of the example scripts i - `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model. - `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon. + +## Acknowledgements + +The logo was created by [ideogram AI](https://ideogram.ai/). diff --git a/src/fiesta/conversions.py b/src/fiesta/conversions.py index 33a12c6..230ce09 100644 --- a/src/fiesta/conversions.py +++ b/src/fiesta/conversions.py @@ -1,7 +1,7 @@ from fiesta.constants import pc_to_cm import jax import jax.numpy as jnp -from jaxtyping import Array +from jaxtyping import Array, Float import numpy as np def Mpc_to_cm(d: float): diff --git a/src/fiesta/inference/fiesta.py b/src/fiesta/inference/fiesta.py index 078dc81..e55b1c0 100644 --- a/src/fiesta/inference/fiesta.py +++ b/src/fiesta/inference/fiesta.py @@ -18,8 +18,6 @@ from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.utils.PRNG_keys import initialize_rng_keys -import time # TODO: remove me! - default_hyperparameters = { "seed": 0, "n_chains": 20, @@ -108,7 +106,7 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict): def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])): if initial_guess.size == 0: initial_guess_named = self.prior.sample(key, self.Sampler.n_chains) - initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T + initial_guess = jnp.stack([initial_guess_named[key] for key in self.prior.naming]).T self.Sampler.sample(initial_guess, None) # type: ignore def print_summary(self, transform: bool = True): diff --git a/src/fiesta/inference/injection.py b/src/fiesta/inference/injection.py index f62dfac..7b34aaf 100644 --- a/src/fiesta/inference/injection.py +++ b/src/fiesta/inference/injection.py @@ -14,7 +14,7 @@ from fiesta.constants import days_to_seconds, c from fiesta import conversions -import afterglowpy as grb +from fiesta.train.AfterglowData import RunAfterglowpy # TODO: get the parser going def get_parser(**kwargs): @@ -94,6 +94,7 @@ class InjectionRecoveryAfterglowpy: def __init__(self, injection_dict: dict[str, Float], + trigger_time: Float, filters: list[str], jet_type = -1, tmin: Float = 0.1, @@ -112,6 +113,7 @@ def __init__(self, self.filters = [Filter(filt) for filt in filters] print(f"Creating injection with filters: {filters}") self.injection_dict = injection_dict + self.trigger_time = trigger_time self.tmin = tmin self.tmax = tmax self.N_datapoints = N_datapoints @@ -122,57 +124,22 @@ def __init__(self, def create_injection(self): """Create a synthetic injection from the given model and parameters.""" - points = np.random.multinomial(self.N_datapoints, [1/len(self.filters)]*len(self.filters)) # random number of datapoints in each filter - self.data = {} - - for npoints, filt in zip(points, self.filters): - self.injection_dict["nu"] = filt.nu - times = self.create_timegrid(npoints) - mJys = self._call_afterglowpy(times*days_to_seconds, self.injection_dict) - magnitudes = conversions.mJys_to_mag_np(mJys) - mag_err = self.error_budget * np.ones_like(times) - self.data[filt.name] = np.array([times, magnitudes, mag_err]).T + nus = [filt.nu for filt in self.filters] + times = np.logspace(np.log10(self.tmin), np.log10(self.tmax), 200) + afgpy = RunAfterglowpy(self.jet_type, times, nus, [list(self.injection_dict.values())], self.injection_dict.keys()) + _, log_flux = afgpy(0) + mJys = np.exp(log_flux).reshape(len(nus), 200) - - def _call_afterglowpy(self, - times_afterglowpy: Array, - params_dict: dict[str, Float]) -> Float[Array, "n_times"]: - """ - Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires, as well as the nu value. - The output will be a set of mJys. + self.data = {} + points = np.random.multinomial(self.N_datapoints, [1/len(self.filters)]*len(self.filters)) # random number of datapoints in each filter + for j, npoints, filt in zip(range(len(self.filters)), points, self.filters): + times_data = self.create_timegrid(npoints) + mJys_filter = np.interp(times_data, times, mJys[j]) + magnitudes = conversions.mJys_to_mag_np(mJys_filter) + magnitudes = magnitudes + 5 * np.log10(self.injection_dict["luminosity_distance"]/(10*1e-6)) - Args: - Float[Array, "n_times"]: The flux density in mJys at the given times. - """ - - # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z - Z = {} - - Z["jetType"] = params_dict.get("jetType", self.jet_type) - Z["specType"] = params_dict.get("specType", 0) - Z["z"] = params_dict.get("z", 0.0) - Z["xi_N"] = params_dict.get("xi_N", 1.0) - - Z["E0"] = 10 ** params_dict["log10_E0"] - Z["thetaCore"] = params_dict["thetaCore"] - Z["n0"] = 10 ** params_dict["log10_n0"] - Z["p"] = params_dict["p"] - Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] - Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"] - Z["d_L"] = params_dict.get("luminosity_distance", 1e-5)*1e6*3.086e18 - if "inclination_EM" in list(params_dict.keys()): - Z["thetaObs"] = params_dict["inclination_EM"] - else: - Z["thetaObs"] = params_dict["thetaObs"] - if self.jet_type == 1 or self.jet_type == 4: - Z["b"] = params_dict["b"] - if "thetaWing" in list(params_dict.keys()): - Z["thetaWing"] = params_dict["thetaWing"] - - # Afterglowpy returns flux in mJys - mJys = grb.fluxDensity(times_afterglowpy, params_dict["nu"], **Z) - return mJys - + mag_err = self.error_budget * np.ones_like(times_data) + self.data[filt.name] = np.array([times_data + self.trigger_time, magnitudes, mag_err]).T def create_timegrid(self, npoints): """Create a time grid for the injection.""" diff --git a/src/fiesta/inference/lightcurve_model.py b/src/fiesta/inference/lightcurve_model.py index 4dfd7a0..9f448f1 100644 --- a/src/fiesta/inference/lightcurve_model.py +++ b/src/fiesta/inference/lightcurve_model.py @@ -13,8 +13,8 @@ import fiesta.train.neuralnets as fiesta_nn from fiesta.utils import MinMaxScalerJax, inverse_svd_transform -import fiesta.conversions as conversions from fiesta import models_utilities +from fiesta import utils ######################## ### ABSTRACT CLASSES ### @@ -77,7 +77,7 @@ def project_output(self, y: dict[str, Array]) -> dict[str, Array]: """ return y - @partial(jax.jit, static_argnums=(0,)) + @partial(jax.jit, static_argnums=(0,)) # TODO: the jit here can come into conflict with scikit-learn methods used in project_output, maybe only jit self.comput_output def predict(self, x: dict[str, Array]) -> dict[str, Array]: """ Generate the lightcurve y from the unnormalized and untransformed input x. @@ -86,7 +86,7 @@ def predict(self, x: dict[str, Array]) -> dict[str, Array]: x to x tilde and y to y tilde take care of projections (e.g. SVD projections) and normalizations. Args: - x (Array): Input array, unnormalized and untransformed. + x (dict[str, Array]): Input array, unnormalized and untransformed. Returns: Array: Output dict[str, Array], i.e., the desired raw light curve per filter @@ -132,8 +132,8 @@ def __init__(self, self.models = {} # Load the metadata for projections etc - self.load_filters(filters) self.load_metadata() + self.load_filters(filters) self.load_scalers() self.load_times(times) self.load_parameter_names() @@ -227,8 +227,7 @@ def project_output(self, y: dict[str, Array]) -> dict[str, Array]: Returns: dict[str, Array]: Output array transformed to the preprocessed space. """ - return {filter: self.y_scaler[filter].inverse_transform(y[filter]) for filter in self.filters} - + return {filter: self.y_scaler[filter].inverse_transform(y[filter]) for filter in self.filters} class SVDSurrogateLightcurveModel(SurrogateLightcurveModel): @@ -280,6 +279,124 @@ def load_parameter_names(self) -> None: class AfterglowpyLightcurvemodel(SVDSurrogateLightcurveModel): + def __init__(self, + name: str, + directory: str, + filters: list[str] = None, + times: Array = None): + super().__init__(name=name, directory=directory, filters=filters, times=times) + + def load_parameter_names(self) -> None: + self.parameter_names = self.metadata["parameter_names"] + +class PCALightcurveModel(SurrogateLightcurveModel): + + def __init__(self, + name: str, + directory: str, + filters: list[str] = None, + times: Array = None): + + super().__init__(name = name, directory= directory, filters = filters, times = times) + + def load_filters(self, filters: list[str] = None) -> None: + self.nus = self.metadata['nus'] + self.Filters = [] + for filter in filters: + try: + Filter = utils.Filter(filter) + if Filter.nuself.nus[-1]: + continue + self.Filters.append(Filter) + except: + raise Exception(f"Filter {filter} not available.") + + self.filters = [filt.name for filt in self.Filters] + if len(self.filters) == 0: + raise ValueError(f"No filters found that match the trained frequency range {self.nus[0]:.3e} Hz to {self.nus[-1]:.3e} Hz.") + + print(f"Loaded SurrogateLightcurveModel with filters {self.filters}.") + + def load_scalers(self): + self.X_scaler = self.metadata["X_scaler"] + self.y_scaler = self.metadata["y_scaler"] + #self.pca = self.metadata["pca"] + + def load_networks(self) -> None: + filename = os.path.join(self.directory, f"{self.name}.pkl") + state, _ = fiesta_nn.load_model(filename) + self.models = state + + + def project_input(self, x: Array) -> Array: + """ + Project the given input to whatever preprocessed input space we are in. + + Args: + x (Array): Original input array + + Returns: + Array: Transformed input array + """ + x_tilde = self.X_scaler.transform(x) + return x_tilde + + def compute_output(self, x: Array) -> Array: + """ + Apply the trained flax neural network on the given input x. + + Args: + x (dict[str, Array]): Input array of parameters per filter + + Returns: + dict[str, Array]: _description_ + """ + output = self.models.apply_fn({'params': self.models.params}, x) + return output + + def project_output(self, y: Array) -> dict[str, Array]: + """ + Project the computed output to whatever preprocessed output space we are in. + + Args: + y (dict[str, Array]): Output array + + Returns: + dict[str, Array]: Output array transformed to the preprocessed space. + """ + #y = self.pca.inverse_transform(y) + y = self.y_scaler.inverse_transform(y) + + y = jnp.reshape(y, shape = (len(self.metadata["nus"]), len(self.times)) ) + y = jnp.exp(y) + + output = {} + for filt in self.Filters: + mJys = jnp.array([jnp.interp(filt.nu, self.metadata["nus"], column) for column in y.T]) # TODO: get a check here that the filt.nu is in range of the meta data + mag = -48.6 + -1 * jnp.log10(mJys) * 2.5 + -1 * (-26) * 2.5 + output[filt.name] = mag + return output + + + def predict_log_flux(self, x: Array) -> Array: + """ + Predict the total log flux array for the parameters x. + + Args: + x [Array]: raw parameter array + + Returns: + log_flux [Array]: Array of log-fluxes. + """ + + x_tilde = self.X_scaler.transform(x) + y = self.models.apply_fn({'params': self.models.params}, x_tilde) + + logflux = self.y_scaler.inverse_transform(y) + return logflux + +class AfterglowpyPCA(PCALightcurveModel): + def __init__(self, name: str, directory: str, diff --git a/src/fiesta/inference/prior.py b/src/fiesta/inference/prior.py index 220c49c..746c2e0 100644 --- a/src/fiesta/inference/prior.py +++ b/src/fiesta/inference/prior.py @@ -145,6 +145,55 @@ def log_prob(self, x: dict[str, Array]) -> Float: jnp.zeros_like(variable), ) return output + jnp.log(1.0 / (self.xmax - self.xmin)) + + +@jaxtyped(typechecker=typechecker) +class Normal(Prior): + mu: float = 0.0 + sigma: float = 1.0 + + def __repr__(self): + return f"Normal(mu={self.mu}, sigma={self.sigma})" + + def __init__( + self, + mu: Float, + sigma: Float, + naming: list[str], + transforms: dict[str, tuple[str, Callable]] = {}, + **kwargs, + ): + super().__init__(naming, transforms) + assert self.n_dim == 1, "Normal needs to be 1D distributions" + self.mu = mu + self.sigma = sigma + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, " n_samples"]]: + """ + Sample from a normal distribution. + + Parameters + ---------- + rng_key : PRNGKeyArray + A random key to use for sampling. + n_samples : int + The number of samples to draw. + + Returns + ------- + samples : dict + Samples from the distribution. The keys are the names of the parameters. + + """ + samples = jax.random.normal(rng_key, (n_samples,),) + samples = self.mu + self.sigma * samples + return self.add_name(samples[None]) + + def log_prob(self, x: dict[str, Array]) -> Float: + variable = x[self.naming[0]] + return -1/(2*self.sigma**2) * (variable-self.mu)**2 - jnp.sqrt(2*jnp.pi*self.sigma**2) # class DiracDelta(Prior): @@ -168,7 +217,7 @@ def log_prob(self, x: dict[str, Array]) -> Float: # output = jnp.where(variable == self.value, jnp.zeros_like(variable), jnp.zeros_like(variable) - jnp.inf) # return output -class Composite(Prior): +class CompositePrior(Prior): priors: list[Prior] = field(default_factory=list) def __repr__(self): @@ -202,4 +251,25 @@ def log_prob(self, x: dict[str, Float]) -> Float: output = 0.0 for prior in self.priors: output += prior.log_prob(x) + return output + +class Constraint(Prior): + xmin: float + xmax: float + def __init__(self, + naming: list[str], + xmin: Float, + xmax: Float, + transforms: dict[str, tuple[str, Callable]] = {})->None: + super().__init__(naming = naming, transforms=transforms) + self.xmin = xmin + self.xmax = xmax + + def log_prob(self, x: dict[str, Array]) -> Float: + variable = x[self.naming[0]] + output = jnp.where( + (variable > self.xmax) | (variable < self.xmin), + jnp.zeros_like(variable) - jnp.inf, + jnp.zeros_like(variable), + ) return output \ No newline at end of file diff --git a/src/fiesta/inference/prior_dict.py b/src/fiesta/inference/prior_dict.py new file mode 100644 index 0000000..b273dd6 --- /dev/null +++ b/src/fiesta/inference/prior_dict.py @@ -0,0 +1,102 @@ +import jax +import jax.numpy as jnp + +from typing import Callable +from jaxtyping import Array, Float, PRNGKeyArray, jaxtyped +from .prior import Prior, Constraint, CompositePrior + + +class ConstrainedPrior(CompositePrior): + priors: CompositePrior + constraints: list[Constraint] + conversion: Callable + factor: Float + def __init__(self, priors: list, conversion_function: Callable = None, transforms: dict[str, tuple[str, Callable]] = {}): + + super().__init__([prior for prior in priors if not isinstance(prior, Constraint)]) + + self.constraints = [constraint for constraint in priors if isinstance(constraint, Constraint)] + + if conversion_function is None: + self.conversion = lambda x: x + else: + self.conversion = conversion_function + + self._estimate_normalization() + + def _estimate_normalization(self, nrepeats: int = 10, sampling_chunk: int = 50_000): + rng_key = jax.random.key(314159265) + factor_estimates = [] + for _ in range(nrepeats): + rng_key, subkey = jax.random.split(rng_key) + samples = super().sample(subkey, n_samples = sampling_chunk) + constr = ~jnp.isneginf(self.evaluate_constraints(samples)) + factor_estimates.append(sampling_chunk/jnp.sum(constr)) + factor_estimates = jnp.array(factor_estimates) + decimals = int( -jnp.floor(jnp.log10(3*jnp.std(factor_estimates))) ) + self.factor = jnp.round(jnp.mean(factor_estimates), decimals) + + def evaluate_constraints(self, samples): + converted_sample = self.conversion(samples) + log_prob = jnp.zeros_like(samples[self.naming[0]]) + for constraint in self.constraints: + log_prob+=constraint.log_prob(samples) + return log_prob + + def sample( + self, rng_key: PRNGKeyArray, n_samples: int + ) -> dict[str, Float[Array, "n_samples"]]: + + rng_key, subkey = jax.random.split(rng_key) + samples = super().sample(subkey, n_samples) + constr = ~jnp.isneginf(self.evaluate_constraints(samples)) + + while jnp.any(~constr): # not really jax-y but no idea atm how to do implement this logic better + idx = jnp.where(~constr, jnp.arange(n_samples), 0) + idx = jnp.unique(idx)# problems with jit here + rng_key, subkey = jax.random.split(rng_key) + new_samples = super().sample(subkey, idx.shape[0]) + new_constr = ~jnp.isneginf(self.evaluate_constraints(new_samples)) + def update_arrays(old_arr, new_arr): + return old_arr.at[idx].set(new_arr) + samples = jax.tree_util.tree_map(update_arrays, samples, new_samples) # update the samples dic by mapping update_arrays function over it + constr = constr.at[idx].set(new_constr) + + for constraint in self.constraints: + del samples[constraint.naming[0]] + + return samples + + + """ + def check_constraint(state): + _, constr, _ , _ = state + return jnp.all(constr) + + def update_samples(state): + samples, constr, rng_key, super = state + idx = jnp.where(~constr, jnp.arange(constr.shape[0]), 0) + rng_key, subkey = jax.random.split(rng_key) + new_samples = super.sample(subkey, jnp.sum(idx!=0)) + new_constr = self.evaluate_constraints(new_samples) + + samples = jax.tree_util.tree_map(update_arrays, samples, new_samples) + constr = constr.at[idx].set(new_constr) + return samples, constr, rng_key, super + + rng_key, subkey = jax.random.split(rng_key) + init_sample = super().sample(subkey, n_samples) + init_constr = ~jnp.isneginf(self.evaluate_constraints(init_sample)) + init_state = (init_sample, init_constr, rng_key, super()) + + final_state = jax.lax.while_loop(check_constraint, update_samples, init_state) + return final_state[0] + """ + + def log_prob(self, x: dict[str, Float]) -> Float: + output = self.evaluate_constraints(x) + for prior in self.priors: + output += prior.log_prob(x) + output += jnp.log(self.factor) + return output + diff --git a/src/fiesta/train/AfterglowData.py b/src/fiesta/train/AfterglowData.py new file mode 100644 index 0000000..57bb420 --- /dev/null +++ b/src/fiesta/train/AfterglowData.py @@ -0,0 +1,436 @@ +"""Method to train the surrogate models""" +import os +from xmlrpc.client import Boolean +import numpy as np +import ast +import h5py +import tqdm +from multiprocessing import Pool, Value + +from fiesta.constants import days_to_seconds +import afterglowpy as grb +from PyBlastAfterglowMag.wrappers import run_grb + + +class AfterglowData: + def __init__(self, + outdir: str, + n_training: int, + n_val: int, + n_test: int, + parameter_distributions: dict = None, + jet_type: int = -1, + tmin: float = 1., + tmax: float = 1000., + n_times: int = 100, + use_log_spacing: bool = True, + numin: float = 1e9, + numax: float = 2.5e18, + n_nu: int = 256, + fixed_parameters: dict = {}) -> None: + + self.outdir = outdir + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + self.outfile = os.path.join(self.outdir, self.outfile) + + self.n_training = n_training + self.n_val = n_val + self.n_test = n_test + + if os.path.exists(self.outfile): + self._read_file() + else: + self.jet_type = jet_type + if self.jet_type not in [-1,0]: + raise ValueError(f"Jet type {jet_type} is not supported. Supported jet types are: [-1, 0]") + self.initialize_times(tmin, tmax, n_times, use_log_spacing) # create time array + self.initialize_nus(numin, numax, n_nu) # create frequency array + self.parameter_names = list(parameter_distributions.keys()) + self.parameter_distributions = parameter_distributions + self._initialize_file() # initialize the h5 file the data is later written to + self.n_training_exists, self.n_val_exists, self.n_test_exists = 0, 0, 0 + + print(f"Initialized fiesta.train.AfterglowData \nJet type: {self.jet_type} \nParameters: {self.parameter_names} \nTimes: {self.times[0]} {self.times[-1]} {len(self.times)} \nNus: {self.nus[0]:.3e} {self.nus[-1]:.3e} {len(self.nus)} \nparameter_distributions: {self.parameter_distributions}\nExisting train, val, test: {self.n_training_exists}, {self.n_val_exists}, {self.n_test_exists} \n \n \n") + self.fixed_parameters = fixed_parameters + + self.get_raw_data(self.n_training, "train") # create new data and save it to file + self.get_raw_data(self.n_val, "val") + self.get_raw_data(self.n_test, "test") + + def initialize_times(self, tmin, tmax, n_times, use_log_spacing: bool = True): + if use_log_spacing: + times = np.logspace(np.log10(tmin), np.log10(tmax), num=n_times) + else: + times = np.linspace(tmin, tmax, num=n_times) + self.times = times + + def initialize_nus(self, numin: float, numax: float, n_nu: int): + self.nus = np.logspace(np.log10(numin), np.log10(numax), n_nu) + + def _initialize_file(self,): + with h5py.File(self.outfile, "w") as f: + f.create_dataset("times", data = self.times) + f.create_dataset("nus", data = self.nus) + f.create_dataset("parameter_names", data = self.parameter_names) + f.create_dataset("parameter_distributions", data = str(self.parameter_distributions)) + f.create_dataset("jet_type", data = self.jet_type) + f.create_group("train"); f.create_group("val"); f.create_group("test"); f.create_group("special_train") + + def get_raw_data(self, n: int, group: str): + if group == "train": + training = True + else: + training = False + + nchunks, rest = divmod(n, self.chunk_size) # create raw data in chunks of chunk_size + for chunk in tqdm.tqdm([*(nchunks*[self.chunk_size]), rest], desc = f"Calculating {nchunks+1} chunks of {group} data...", leave = True): + if chunk ==0: + continue + X, y = self.create_raw_data(chunk, training) + X, y = self.fix_nans(X, y) + self._save_to_file(X, y, group) + + def fix_nans(self,X,y): + # fixes any nans that remain from create_raw_data + problematic = np.unique(np.where(np.isnan(y))[0]) + n = len(problematic) + while n>0: + if n> 0.1*len(X): + print(f"Warning: Found many nans for the parameter samples, in total {n} out of {len(X)} samples.") + X_replacement, y_replacement = self.create_raw_data(n) + X[problematic] = X_replacement + y[problematic] = y_replacement + problematic = np.unique(np.where(np.isnan(y))[0]) + n = len(problematic) + return X, y + + def _read_file(self,): + with h5py.File(self.outfile, "r") as f: + self.jet_type = f["jet_type"][()] + self.times = f["times"][:] + self.nus = f["nus"][:] + self.parameter_names = f["parameter_names"][:].astype(str).tolist() + self.n_training_exists = (f["train"]["X"].shape)[0] + self.n_val_exists = (f["val"]["X"].shape)[0] + self.n_test_exists = (f["test"]["X"].shape)[0] + self.parameter_distributions = ast.literal_eval(f["parameter_distributions"][()].decode('utf-8')) + + def create_raw_data(self, n: int, training: bool = True): + """ + Create draws X in the parameter space and run the afterglow model on it. + """ + # Create training data + X_raw = np.empty((n, len(self.parameter_names))) + + if training: + for j, key in enumerate(self.parameter_names): + a, b, distribution = self.parameter_distributions[key] # FIXME + if distribution == "uniform": + X_raw[:,j] = np.random.uniform(a, b, size = n) + elif distribution == "loguniform": + X_raw[:,j] = np.exp(np.random.uniform(np.log(a), np.log(b), size = n)) + else: + for j, key in enumerate(self.parameter_distributions.keys()): + a, b, _ = self.parameter_distributions[key] + X_raw[:, j] = np.random.uniform(a, b, size = n) + + # Ensure that epsilon_e + epsilon_B < 1 + epsilon_e_ind = self.parameter_names.index("log10_epsilon_e") + epsilon_B_ind = self.parameter_names.index("log10_epsilon_B") + epsilon_tot = (10**(X_raw[:, epsilon_e_ind]) + 10**(X_raw[:, epsilon_B_ind])) + mask = epsilon_tot>=1 + X_raw[mask, epsilon_B_ind] += np.log10(0.99/epsilon_tot[mask]) + X_raw[mask, epsilon_e_ind] += np.log10(0.99/epsilon_tot[mask]) + + # Ensure that thetaWing is smaller than pi/2 + if self.jet_type !=-1: + alphaw_ind = self.parameter_names.index("alphaWing") + thetac_ind = self.parameter_names.index("thetaCore") + mask = X_raw[:, alphaw_ind]*X_raw[:, thetac_ind] >= np.pi/2 + X_raw[mask, alphaw_ind] = np.pi/2 * 1/X_raw[mask, thetac_ind] + + X, y = self.run_afterglow_model(X_raw) + return X, y + + def create_special_data(self, X_raw, label:str, comment: str = None): + """Create special training data with pre-specified parameters X. These will be stored in the 'special_train' hdf5 group.""" + + # Ensure that epsilon_e + epsilon_B < 1 + epsilon_e_ind = self.parameter_names.index("log10_epsilon_e") + epsilon_B_ind = self.parameter_names.index("log10_epsilon_B") + epsilon_tot = (10**(X_raw[:, epsilon_e_ind]) + 10**(X_raw[:, epsilon_B_ind])) + mask = epsilon_tot>=1 + X_raw[mask, epsilon_B_ind] += np.log10(0.99/epsilon_tot[mask]) + X_raw[mask, epsilon_e_ind] += np.log10(0.99/epsilon_tot[mask]) + + # Ensure that thetaWing is smaller than pi/2 + if self.jet_type !=-1: + alphaw_ind = self.parameter_names.index("alphaWing") + thetac_ind = self.parameter_names.index("thetaCore") + mask = X_raw[:, alphaw_ind]*X_raw[:, thetac_ind] >= np.pi/2 + X_raw[mask, alphaw_ind] = np.pi/2 * 1/X_raw[mask, thetac_ind] + + X, y = self.run_afterglow_model(X_raw) + X, y = self.fix_nans(X,y) + self._save_to_file(X, y, "special_train", label = label, comment= comment) + + def run_afterglow_model(X): + raise NotImplementedError + + def _save_to_file(self, X, y, group: str, label: str = None, comment: str = None): + with h5py.File(self.outfile, "a") as f: + if "y" in f[group]: + Xset = f[group]["X"] + Xset.resize(Xset.shape[0]+X.shape[0], axis = 0) + Xset[-X.shape[0]:] = X + + yset = f[group]["y"] + yset.resize(yset.shape[0]+y.shape[0], axis = 0) + yset[-y.shape[0]:] = y + + elif label is not None: # when we have special training data + if label in f["special_train"]: + Xset = f["special_train"][label]["X"] + Xset.resize(Xset.shape[0]+X.shape[0], axis = 0) + Xset[-X.shape[0]:] = X + + yset = f[group][label]["y"] + yset.resize(yset.shape[0]+y.shape[0], axis = 0) + yset[-y.shape[0]:] = y + + else: + f["special_train"].create_group(label) + if comment is not None: + f["special_train"][label].attrs["comment"] = comment + f["special_train"][label].create_dataset("X", data = X, maxshape=(None, len(self.parameter_names)), chunks = (self.chunk_size, len(self.parameter_names))) + f["special_train"][label].create_dataset("y", data = y, maxshape=(None, len(self.times)*len(self.nus)), chunks = (self.chunk_size, len(self.times)*len(self.nus))) + + else: + f[group].create_dataset("X", data = X, maxshape=(None, len(self.parameter_names)), chunks = (self.chunk_size, len(self.parameter_names))) + f[group].create_dataset("y", data = y, maxshape=(None, len(self.times)*len(self.nus)), chunks = (self.chunk_size, len(self.times)*len(self.nus))) + +class AfterglowpyData(AfterglowData): + + def __init__(self, n_pool: int, *args, **kwargs): + self.outfile = "afterglowpy_raw_data.h5" + self.n_pool = n_pool + self.chunk_size = 1000 + super().__init__(*args, **kwargs) + + + def run_afterglow_model(self, X): + """Uses multiprocessing to run afterglowpy on the supplied parameters in X.""" + y = np.empty((len(X), len(self.times)*len(self.nus))) + afgpy = RunAfterglowpy(self.jet_type, self.times, self.nus, X, self.parameter_names, self.fixed_parameters) + pool = Pool(processes=self.n_pool) + jobs = [pool.apply_async(func=afgpy, args=(argument,)) for argument in range(len(X))] + pool.close() + for Idx, job in enumerate(tqdm.tqdm(jobs, desc = f"Computing {len(X)} afterglowpy calculations.", leave = False)): + try: + idx, out = job.get() + y[idx] = out + except: + y[Idx] = np.full(len(self.times)*len(self.nus), np.nan) + return X, y + + +class PyblastafterglowData(AfterglowData): + + def __init__(self, path_to_exec: str, rank: int = 0, grb_resolution: int = 12, *args, **kwargs): + self.outfile = f"pyblastafterglow_raw_data_{rank}.h5" + self.chunk_size = 10 + self.rank = rank + self.path_to_exec = path_to_exec + self.grb_resolution = grb_resolution + super().__init__(*args, **kwargs) + + + def run_afterglow_model(self, X): + """Should be run in parallel with different mpi processes to run pyblastafterglow on the parameters in the array X.""" + y = np.empty((len(X), len(self.times)*len(self.nus))) + pbag = RunPyblastafterglow(self.jet_type, self.times, self.nus, X, self.parameter_names, self.fixed_parameters, rank=self.rank, path_to_exec = self.path_to_exec, grb_resolution = self.grb_resolution) + for j in tqdm.tqdm(range(len(X)), desc = f"Computing {len(X)} pyblastafterglow calculations.", leave = False): + try: + idx, out = pbag(j) + y[idx] = out + except: + try: + pbag.n_tb = 3000 # increase blast wave evolution time grid if there is an error + idx, out = pbag(j) + y[idx] = out + pbag.n_tb = 1000 + except: + y[j] = np.full(len(self.times)*len(self.nus), np.nan) + + return X, y + + +class RunAfterglowpy: + def __init__(self, jet_type, times, nus, X, parameter_names, fixed_parameters = {}): + self.jet_type = jet_type + self.times = times + self._times_afterglowpy = self.times * days_to_seconds # afterglowpy takes seconds as input + self.nus = nus + self.X = X + self.parameter_names = parameter_names + self.fixed_parameters = fixed_parameters + + def _call_afterglowpy(self, + params_dict: dict[str, float]): + """ + Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires, as well as the nu value. + The output will be a set of mJys. + + Args: + Float[Array, "n_times"]: The flux density in mJys at the given times. + """ + + # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z + Z = {} + + Z["jetType"] = params_dict.get("jetType", self.jet_type) + Z["specType"] = params_dict.get("specType", 0) + Z["z"] = params_dict.get("z", 0.0) + Z["xi_N"] = params_dict.get("xi_N", 1.0) + + Z["E0"] = 10 ** params_dict["log10_E0"] + Z["n0"] = 10 ** params_dict["log10_n0"] + Z["p"] = params_dict["p"] + Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] + Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"] + Z["d_L"] = 3.086e19 # fix at 10 pc, so that AB magnitude equals absolute magnitude + + if "inclination_EM" in list(params_dict.keys()): + Z["thetaObs"] = params_dict["inclination_EM"] + else: + Z["thetaObs"] = params_dict["thetaObs"] + + if self.jet_type == -1: + Z["thetaCore"] = params_dict["thetaCore"] + + elif self.jet_type == 0: + Z["thetaCore"] = params_dict["thetaCore"] + Z["thetaWing"] = params_dict["thetaCore"]*params_dict["alphaWing"] + + elif self.jet_type == 4: + Z["thetaCore"] = params_dict["thetaCore"] + Z["thetaWing"] = params_dict["thetaCore"]*params_dict["alphaWing"] + Z["b"] = params_dict["b"] + + else: + raise ValueError(f"Provided jet type {self.jet_type} invalid.") + + # Afterglowpy returns flux in mJys + tt, nunu = np.meshgrid(self._times_afterglowpy, self.nus) + mJys = grb.fluxDensity(tt, nunu, **Z) + return mJys + + def __call__(self, idx): + param_dict = dict(zip(self.parameter_names, self.X[idx])) + param_dict.update(self.fixed_parameters) + mJys = self._call_afterglowpy(param_dict) + return idx, np.log(mJys).flatten() + + + +class RunPyblastafterglow: + def __init__(self, jet_type, times, nus, X, parameter_names, fixed_parameters = {}, rank = 0, path_to_exec = "./pba.out", grb_resolution = 12): + self.jet_type = jet_type + jet_conversion = {"-1": "tophat", + "0": "gaussian"} + self.jet_type = jet_conversion[str(self.jet_type)] + times_seconds = times * days_to_seconds # pyblastafterglow takes seconds as input + + # preparing the pyblastafterglow string argument for time array + is_log_uniform = np.allclose(np.diff(np.log(times_seconds)), np.log(times_seconds[1])-np.log(times_seconds[0])) + if is_log_uniform: + log_dt = np.log(times_seconds[1])-np.log(times_seconds[0]) + self.lc_times = f'array logspace {times_seconds[0]:e} {np.exp(log_dt)*times_seconds[-1]:e} {len(times_seconds)}' # pyblastafterglow only takes this string format + else: + dt = times_seconds[1] - times_seconds[0] + self.lc_times = f'array uniform {times_seconds[0]:e} {times_seconds[-1]+dt:e} {len(times_seconds)}' + + # preparing the pyblastafterglow string argument for frequency array + log_dnu = np.log(nus[1]/nus[0]) + self.lc_freqs = f'array logspace {nus[0]:e} {np.exp(log_dnu)*nus[-1]:e} {len(nus)}' # pyblastafterglow only takes this string format + + self.X = X + self.parameter_names = parameter_names + self.fixed_parameters = fixed_parameters + self.rank = rank + self.path_to_exec = path_to_exec + self.grb_resolution = grb_resolution + self.n_tb = 1000 # set default blast wave evolution timegrid to 1000 + + def _call_pyblastafterglow(self, + params_dict: dict[str, float]): + """ + Run pyblastafterglow to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires. + The output will be a set of mJys. + + Args: + Float[Array, "n_times"]: The flux density in mJys at the given times. + """ + # Define jet structure (analytic; gaussian) -- 3 free parameters + struct = dict( + struct= self.jet_type, # type of the structure tophat or gaussian + Eiso_c=np.power(10, params_dict["log10_E0"]), # isotropic equivalent energy of the burst + Gamma0c=params_dict["Gamma0"], # lorentz factor of the core of the jet + M0c=-1., # mass of the ejecta (if -1 -- inferr from Eiso_c and Gamma0c) + n_layers_a=self.grb_resolution # resolution of the jet (number of individual blastwaves) + ) + + if self.jet_type == "tophat": + struct["theta_c"] = params_dict['thetaCore'] # half-opening angle of the winds of the jet + + elif self.jet_type == "gaussian": + struct["theta_c"] = params_dict['thetaCore'] # half-opening angle of the winds of the jet + struct["theta_w"] = params_dict["thetaCore"] * params_dict["alphaWing"] + + else: + raise ValueError(f"Provided jet type {self.jet_type} invalid.") + + # set model parameters + P = dict( + # main model parameters; Uniform ISM -- 2 free parameters + main=dict( + d_l= 3.086e19, # luminocity distance to the source [cm], fix at 10 pc, so that AB magnitude equals absolute magnitude + z = 0.0, # redshift of the source (used in Doppler shifring and EBL table) + n_ism=np.power(10, params_dict["log10_n0"]), # ISM density [cm^-3] (assuming uniform) + theta_obs= params_dict["inclination_EM"], # observer angle [rad] (from pol to jet axis) + lc_freqs= self.lc_freqs, # frequencies for light curve calculation + lc_times= self.lc_times, # times for light curve calculation + tb0=1e1, tb1=1e11, ntb=self.n_tb, # burster frame time grid boundary, resolution, for the simulation + ), + + # ejecta parameters; FS only -- 3 free parameters + grb=dict( + structure=struct, # structure of the ejecta + eps_e_fs=np.power(10, params_dict["log10_epsilon_e"]), # microphysics - FS - frac. energy in electrons + eps_b_fs=np.power(10, params_dict["log10_epsilon_B"]), # microphysics - FS - frac. energy in magnetic fields + p_fs= params_dict["p"], # microphysics - FS - slope of the injection electron spectrum + do_lc='yes', # task - compute light curves + rtol_theta = 1e-1, + # save_spec='yes' # save comoving spectra + # method_synchrotron_fs = 'Joh06', + # method_ne_fs = 'usenprime', + # method_ele_fs = 'analytic', + # method_comp_mode = 'observFlux' + ) + ) + pba_run = run_grb(working_dir= os.getcwd() + f'/tmp_{self.rank}/', # directory to save/load from simulation data + P=P, # all parameters + run=True, # run code itself (if False, it will try to load results) + path_to_cpp=self.path_to_exec, # absolute path to the C++ executable of the code + loglevel="err", # logging level of the code (info or err) + process_skymaps=False # process unstractured sky maps. Only useed if `do_skymap = yes` + ) + mJys = pba_run.GRB.get_lc() + return mJys + + def __call__(self, idx): + param_dict = dict(zip(self.parameter_names, self.X[idx])) + param_dict.update(self.fixed_parameters) + mJys = self._call_pyblastafterglow(param_dict) + return idx, np.log(mJys).flatten() diff --git a/src/fiesta/train/BenchmarkerFluxes.py b/src/fiesta/train/BenchmarkerFluxes.py new file mode 100644 index 0000000..5b5e70a --- /dev/null +++ b/src/fiesta/train/BenchmarkerFluxes.py @@ -0,0 +1,245 @@ +from fiesta.inference.lightcurve_model import LightcurveModel +import afterglowpy as grb +from fiesta.constants import days_to_seconds, c +from fiesta import conversions +from fiesta import utils + +from jaxtyping import Array, Float, Int + +import tqdm +import os +import ast +import h5py +import numpy as np +import matplotlib.pyplot as plt +import matplotlib.colors as colors +from matplotlib.cm import ScalarMappable + +from scipy.integrate import trapezoid +from scipy.interpolate import interp1d + +class Benchmarker: + + def __init__(self, + name: str, + model_dir: str, + filters: list[str], + MODEL = LightcurveModel, + metric_name: str = "$\\mathcal{L}_\\inf$" + ) -> None: + + self.name = name + self.model_dir = model_dir + self.model = MODEL(name = self.name, + directory = self.model_dir, + filters = filters) + + self.times = self.model.times + self.nus = self.model.metadata["nus"] + self.load_filters(filters) + + self.get_test_data() + self.lightcurve_test_data() + + + self.metric_name = metric_name + if metric_name == "$\\mathcal{L}_2$": + self.metric = lambda y: np.sqrt(trapezoid(x= self.times,y=y**2)) + self.metric2d = lambda y: np.sqrt(trapezoid(x = self.nus, y =trapezoid(x = self.times, y = (y**2).reshape(len(self.nus), len(self.times)) ) )) + else: + self.metric = lambda y: np.max(np.abs(y)) + self.metric2d = self.metric + + self.calculate_mismatch() + self.get_error_distribution() + + def __repr__(self) -> str: + return f"Surrogte_Benchmarker(name={self.name}, model_dir={self.model_dir})" + + def load_filters(self, filters: list[str]): + self.filters = [] + for filter in filters: + try: + self.filters.append(utils.Filter(filter)) + except: + raise Exception(f"Filter {filter} not available.") + + def get_test_data(self,): + + file = [f for f in os.listdir(self.model_dir) if f.endswith("_raw_data.h5")][0] + + with h5py.File(os.path.join(self.model_dir, file), "r") as f: + self.parameter_distributions = ast.literal_eval(f["parameter_distributions"][()].decode('utf-8')) + self.parameter_names = f["parameter_names"][:].astype(str).tolist() + self.test_X_raw = f["test"]["X"][:] + y_raw = f["test"]["y"][:] + y_raw = y_raw.reshape(len(self.test_X_raw), len(f["nus"]), len(f["times"]) ) + y_raw = interp1d(f["times"][:], y_raw, axis = 2)(self.times) # interpolate the test data over the time range of the model + y_raw = interp1d(f["nus"][:], y_raw, axis = 1)(self.nus) # interpolate the test data over the frequency range of the model + self.fluxes_raw = y_raw.reshape(len(self.test_X_raw), len(self.nus) * len(self.times) ) + + def lightcurve_test_data(self, ): + + self.n_test_data = len(self.test_X_raw) + self.prediction_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters} + self.test_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters} + self.prediction_log_fluxes = np.empty((self.n_test_data, len(self.nus) * len(self.times))) + for j, X in enumerate(self.test_X_raw): + param_dict = {name: x for name, x in zip(self.parameter_names, X)} + prediction = self.model.predict(param_dict) + self.prediction_log_fluxes[j] = self.model.predict_log_flux(X) + for filt in self.filters: + self.prediction_y_raw[filt.name][j] = prediction[filt.name] + self.test_y_raw[filt.name][j] = self.convert_to_mag(filt.nu, self.fluxes_raw[j]) + + def convert_to_mag(self, nu, flux): + flux = flux.reshape(len(self.model.metadata["nus"]), len(self.model.times)) + flux = np.exp(flux) + flux = np.array([np.interp(nu, self.model.metadata["nus"], column) for column in flux.T]) + mag = -48.6 + -1 * np.log10(flux*1e-3 / 1e23) * 2.5 + return mag + + + ### Diagnostics ### + + def calculate_mismatch(self): + mismatch = {} + for filt in self.filters: + array = np.empty(self.n_test_data) + for j in range(self.n_test_data): + array[j] = self.metric(self.prediction_y_raw[filt.name][j] - self.test_y_raw[filt.name][j]) + mismatch[filt.name] = array + + array = np.empty(self.n_test_data) + for j in range(self.n_test_data): + array[j] = self.metric2d(self.prediction_log_fluxes[j]-self.fluxes_raw[j]) + + mismatch["total"] = array + self.mismatch = mismatch + + def plot_lightcurves_mismatch(self, + filter: str, + parameter_labels: list[str] = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"] + ): + if self.metric_name == "$\\mathcal{L}_2$": + vline = np.sqrt(trapezoid(x = self.times, y = 0.2*np.ones(len(self.times)))) + vmin, vmax = 0, vline*2 + bins = np.linspace(vmin, vmax, 25) + else: + vline = 1. + vmin, vmax = 0, 2*vline + bins = np.linspace(vmin, vmax, 20) + + mismatch = self.mismatch[filter] + + cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")]) + colored_mismatch = cmap(mismatch/vmax) + + label_dic = {p: label for p, label in zip(self.parameter_names, parameter_labels)} + + fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1) + fig.suptitle(f"{filter}: {self.metric_name} norm") + + for j, p in enumerate(self.parameter_names[1:]): + for k, pp in enumerate(self.parameter_names[:j+1]): + sort = np.argsort(mismatch) + + ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1, rasterized = True) + + ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max())) + ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max())) + + + if k!=0: + ax[j,k].set_yticklabels([]) + + if j!=len(self.parameter_names)-2: + ax[j,k].set_xticklabels([]) + + ax[-1,k].set_xlabel(label_dic[pp]) + ax[j,0].set_ylabel(label_dic[p]) + + for cax in ax[j, j+1:]: + cax.set_axis_off() + + ax[0,-1].set_axis_on() + ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,) + ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed") + ax[0,-1].set_yticks([]) + + fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1]) + return fig, ax + + def print_correlations(self, + filter: str,): + mismatch = self.mismatch[filter] + print(f"\n \n \nCorrelations for filter {filter}:\n") + corrcoeff = [] + for j, p in enumerate(self.parameter_names): + print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], mismatch)[0,1]}") + + def get_error_distribution(self): + error_distribution = {} + for j, p in enumerate(self.parameter_names): + p_array = self.test_X_raw[:,j] + #bins = (np.array(self.parameter_distributions[p][:-1]) + np.array(self.parameter_grid[p][1:]))/2 + #bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]] + bins = np.linspace(self.parameter_distributions[p][0], self.parameter_distributions[p][1], 12) + # calculate the error histogram with mismatch as weights + error_distribution[p], _ = np.histogram(p_array, weights = self.mismatch["total"], bins = bins, density = True) + error_distribution[p] = error_distribution[p]/np.sum(error_distribution[p]) + + self.error_distribution = error_distribution + + + def plot_worst_lightcurves(self,): + + fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) + + for cax, filt in zip(ax, self.filters): + ind = np.argmax(self.mismatch[filt.name]) + prediction = self.prediction_y_raw[filt.name][ind] + cax.plot(self.times, prediction, color = "blue") + cax.fill_between(self.times, prediction-1, prediction+1, color = "blue", alpha = 0.2) + cax.plot(self.times, self.test_y_raw[filt.name][ind], color = "red") + cax.invert_yaxis() + cax.set(xlabel = "$t$ in days", ylabel = "mag", xscale = "log", xlim = (self.times[0], self.times[-1])) + cax.set_title(f"{filt.name}", loc = "right", pad = -20) + cax.text(0, 0.05, np.array_str(self.test_X_raw[ind], precision = 2), transform = cax.transAxes, fontsize = 7) + + return fig, ax + + def plot_error_over_time(self,): + + fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) + + for cax, filt in zip(ax, self.filters): + error = np.abs(self.prediction_y_raw[filt.name] - self.test_y_raw[filt.name]) + indices = np.linspace(5, len(self.times)-1, 10).astype(int) + cax.violinplot(error[:, indices], positions = self.times[indices], widths = self.times[indices]/3) + cax.set(xlabel = "$t$ in days", ylabel = "error in mag", xscale = "log", xlim = (self.times[0], self.times[-1]), ylim = (0,1.5)) + cax.set_title(f"{filt.name}", loc = "right", pad = -20) + return fig, ax + + def plot_error_distribution(self,): + mismatch = self.mismatch["total"] + + fig, ax = plt.subplots(len(self.parameter_names), 1, figsize = (4, 18)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.09, right = 0.95) + + for j, cax in enumerate(ax): + p_array = self.test_X_raw[:,j] + p = self.parameter_names[j] + bins = np.linspace(self.parameter_distributions[p][0], self.parameter_distributions[p][1], 12) + + cax.hist(p_array, weights = self.mismatch["total"], color = "blue", bins = bins, density = True, histtype = "step") + cax.set_xlabel(self.parameter_names[j]) + cax.set_yticks([]) + + return fig, ax + + + + diff --git a/src/fiesta/train/Benchmarker.py b/src/fiesta/train/BenchmarkerLightcurves.py similarity index 68% rename from src/fiesta/train/Benchmarker.py rename to src/fiesta/train/BenchmarkerLightcurves.py index d293b51..dd6b117 100644 --- a/src/fiesta/train/Benchmarker.py +++ b/src/fiesta/train/BenchmarkerLightcurves.py @@ -1,4 +1,4 @@ -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +from fiesta.inference.lightcurve_model import LightcurveModel import afterglowpy as grb from fiesta.constants import days_to_seconds from fiesta import conversions @@ -31,6 +31,8 @@ def __init__(self, name: str, model_dir: str, filters: list[str], + parameter_grid: dict, + MODEL = LightcurveModel, n_test_data: int = 3000, remake_test_data: bool = False, metric_name: str = "$\\mathcal{L}_\\inf$", @@ -40,16 +42,15 @@ def __init__(self, self.name = name self.model_dir = model_dir self.load_filters(filters) - self.model = AfterglowpyLightcurvemodel(name = self.name, - directory = self.model_dir, - filters = filters) + self.model = MODEL(name = self.name, + directory = self.model_dir, + filters = filters) self.times = self.model.times self._times_afterglowpy = self.times * days_to_seconds self.jet_type = jet_type self.parameter_names = self.model.metadata["parameter_names"] - train_X_raw = np.load(self.model_dir+"raw_data_training.npz")["X_raw"] - self.parameter_boundaries = np.array([np.min(train_X_raw, axis = 0), np.max(train_X_raw, axis = 0)]) + self.parameter_grid = parameter_grid if os.path.exists(self.model_dir+"/raw_data_test.npz") and not remake_test_data: self.load_test_data() @@ -57,12 +58,15 @@ def __init__(self, self.get_test_data(n_test_data) self.metric_name = metric_name - mask = np.logical_and(self.times>8, self.times<800) + mask = np.logical_and(self.times>1, self.times<1000) if metric_name == "$\\mathcal{L}_2$": self.metric = lambda y: np.sqrt(trapezoid(x= self.times[mask],y=y[mask]**2)) else: self.metric = lambda y: np.max(np.abs(y[mask])) - + + self.calculate_mismatch() + self.get_error_distribution() + def __repr__(self) -> str: return f"Surrogate_Benchmarker(name={self.name}, model_dir={self.model_dir})" @@ -81,7 +85,7 @@ def get_test_data(self, n_test_data): print(f"Determining test data for {n_test_data} random points within parameter grid.") for j in tqdm.tqdm(range(n_test_data)): - test_X_raw[j] = np.random.uniform(low = self.parameter_boundaries[0], high = self.parameter_boundaries[1]) + test_X_raw[j] = np.random.uniform(low = [self.parameter_grid[p][0] for p in self.parameter_names], high = [self.parameter_grid[p][-1] for p in self.parameter_names]) param_dict = {name: x for name, x in zip(self.parameter_names, test_X_raw[j])} prediction = self.model.predict(param_dict) @@ -134,7 +138,6 @@ def _call_afterglowpy(self, Z["xi_N"] = params_dict.get("xi_N", 1.0) Z["E0"] = 10 ** params_dict["log10_E0"] - Z["thetaCore"] = params_dict["thetaCore"] Z["n0"] = 10 ** params_dict["log10_n0"] Z["p"] = params_dict["p"] Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] @@ -144,24 +147,30 @@ def _call_afterglowpy(self, Z["thetaObs"] = params_dict["inclination_EM"] else: Z["thetaObs"] = params_dict["thetaObs"] - if self.jet_type == 1 or self.jet_type == 4: + + if self.jet_type == -1: + Z["thetaCore"] = params_dict["thetaCore"] + + if "thetaWing" in list(params_dict.keys()): #for Gaussian and power law jets + Z["thetaWing"] = params_dict["thetaWing"] + Z["thetaCore"] = params_dict["xCore"]*params_dict["thetaWing"] + + if self.jet_type == 4: Z["b"] = params_dict["b"] - if "thetaWing" in list(params_dict.keys()): - Z["thetaWing"] = params_dict["thetaWing"] # Afterglowpy returns flux in mJys mJys = grb.fluxDensity(self._times_afterglowpy, params_dict["nu"], **Z) return mJys - def calculate_mismatch(self, filter): - mismatch = np.empty(self.n_test_data) - - for j in range(self.n_test_data): - mismatch[j] = self.metric(self.prediction_y_raw[filter][j] - self.test_y_raw[filter][j]) - - return mismatch - + def calculate_mismatch(self): + mismatch = {} + for filt in self.filters: + array = np.empty(self.n_test_data) + for j in range(self.n_test_data): + array[j] = self.metric(self.prediction_y_raw[filt.name][j] - self.test_y_raw[filt.name][j]) + mismatch[filt.name] = array + self.mismatch = mismatch def plot_lightcurves_mismatch(self, filter: str, @@ -173,10 +182,10 @@ def plot_lightcurves_mismatch(self, vline = np.sqrt(trapezoid(x = self.times, y = np.ones(len(self.times)))) else: bins = np.arange(0, 3, 0.5) - vmin, vmax = 0, 3 + vmin, vmax = 0, 2 vline = 1. - mismatch = self.calculate_mismatch(filter) + mismatch = self.mismatch[filter] cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")]) colored_mismatch = cmap(mismatch/vmax) @@ -219,25 +228,66 @@ def plot_lightcurves_mismatch(self, def print_correlations(self, filter: str,): - mismatch = self.calculate_mismatch(filter) + mismatch = self.mismatch[filter] print(f"\n \n \nCorrelations for filter {filter}:\n") corrcoeff = [] for j, p in enumerate(self.parameter_names): print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], mismatch)[0,1]}") + + def get_error_distribution(self): + + error_distribution = {filt.name: {} for filt in self.filters} + + for filt in self.filters: + for j, p in enumerate(self.parameter_names): + p_array = self.test_X_raw[:,j] + bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 + bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]] + # calculate the error histogram with mismatch as weights + error_distribution[filt.name][p], _ = np.histogram(p_array, weights = self.mismatch[filt.name], bins = bins, density = True) + error_distribution[filt.name][p] = error_distribution[filt.name][p]/np.sum(error_distribution[filt.name][p]) + + self.error_distribution = error_distribution + + + def plot_worst_lightcurves(self,): + + fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) + + for cax, filt in zip(ax, self.filters): + ind = np.argmax(self.mismatch[filt.name]) + prediction = self.prediction_y_raw[filt.name][ind] + cax.plot(self.times, prediction, color = "blue") + cax.fill_between(self.times, prediction-1, prediction+1, color = "blue", alpha = 0.2) + cax.plot(self.times, self.test_y_raw[filt.name][ind], color = "red") + cax.invert_yaxis() + cax.set(xlabel = "$t$ in days", ylabel = "mag") + cax.set_title(f"{filt.name}", loc = "right", pad = -20) + + return fig, ax - def plot_worst_lightcurve(self,filter): - - mismatch = self.calculate_mismatch(filter) - ind = np.argsort(mismatch)[-1] + def plot_error_distribution(self, filter): + mismatch = self.mismatch[filter] + + fig, ax = plt.subplots(len(self.parameter_names), 1, figsize = (4, 18)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.09, right = 0.95) + + for j, cax in enumerate(ax): + p_array = self.test_X_raw[:,j] + p = self.parameter_names[j] + bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 + bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]] + + cax.hist(p_array, weights = self.mismatch[filter], color = "blue", bins = bins, density = True, histtype = "step") + cax.set_xlabel(self.parameter_names[j]) + cax.set_yticks([]) - fig, ax = plt.subplots(1,1) - fig.suptitle(f"{filter}") - ax.plot(self.times, self.prediction_y_raw[filter][ind], color = "blue") - ax.fill_between(self.times, self.prediction_y_raw[filter][ind]-1, self.prediction_y_raw[filter][ind]+1, color = "blue", alpha = 0.2) - ax.plot(self.times, self.test_y_raw[filter][ind], color = "red") - plt.gca().invert_yaxis() - ax.set(xlabel = "$t$ in days", ylabel = "mag") return fig, ax + + + + diff --git a/src/fiesta/train/FluxTrainer.py b/src/fiesta/train/FluxTrainer.py new file mode 100644 index 0000000..5943307 --- /dev/null +++ b/src/fiesta/train/FluxTrainer.py @@ -0,0 +1,333 @@ +"""Method to train the surrogate models""" + +import os +import numpy as np + +import jax +import jax.numpy as jnp +from jaxtyping import Array, Float, Int + +from fiesta.utils import MinMaxScalerJax, StandardScalerJax, PCAdecomposer +from fiesta import utils +from fiesta import conversions +from fiesta import models_utilities +import fiesta.train.neuralnets as fiesta_nn + +import matplotlib.pyplot as plt +import pickle +import h5py +from typing import Callable + + +class FluxTrainer: + """Abstract class for training a collection of surrogate""" + + name: str + outdir: str + parameter_names: list[str] + + preprocessing_metadata: dict[str, dict[str, float]] + + X_raw: Float[Array, "n_batch n_params"] + y_raw: dict[str, Float[Array, "n_batch n_times"]] + + X: Float[Array, "n_batch n_input_surrogate"] + y: dict[str, Float[Array, "n_batch n_output_surrogate"]] + + trained_states: dict[str, fiesta_nn.TrainState] + + def __init__(self, + name: str, + outdir: str, + plots_dir: str = None, + ) -> None: + + self.name = name + self.outdir = outdir + # Check if directories exists, otherwise, create: + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + + self.plots_dir = plots_dir + if not os.path.exists(self.plots_dir): + os.makedirs(self.plots_dir) + + + # To be loaded by child classes + self.parameter_names = None + + self.preprocessing_metadata = {} + + self.train_X_raw = None + self.train_y_raw = None + + self.val_X_raw = None + self.val_y_raw = None + + def __repr__(self) -> str: + return f"FluxTrainer(name={self.name})" + + def preprocess(self): + + print("Preprocessing data by scaling to mean 0 and std 1. . .") + self.X_scaler = StandardScalerJax() + self.X = self.X_scaler.fit_transform(self.train_X_raw) + + self.y_scaler = StandardScalerJax() + self.y = self.y_scaler.fit_transform(self.train_y_raw) + + # Save the metadata + self.preprocessing_metadata["X_scaler"] = self.X_scaler + self.preprocessing_metadata["y_scaler"] = self.y_scaler + print("Preprocessing data . . . done") + + def fit(self, + config: fiesta_nn.NeuralnetConfig = None, + key: jax.random.PRNGKey = jax.random.PRNGKey(0), + verbose: bool = True): + """ + The config controls which architecture is built and therefore should not be specified here. + + Args: + config (nn.NeuralnetConfig, optional): _description_. Defaults to None. + """ + + # Get default choices if no config is given + if config is None: + config = fiesta_nn.NeuralnetConfig() + self.config = config + + input_ndim = len(self.parameter_names) + + + # Create neural network and initialize the state + net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) + key, subkey = jax.random.split(key) + state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) + + # Perform training loop + state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y, self.val_X, self.val_y, verbose=verbose) + # Plot and save the plot if so desired + if self.plots_dir is not None: + plt.figure(figsize=(10, 5)) + ls = "-o" + ms = 3 + plt.plot([i+1 for i in range(len(train_losses))], train_losses, ls, markersize=ms, label="Train", color="red") + plt.plot([i+1 for i in range(len(val_losses))], val_losses, ls, markersize=ms, label="Validation", color="blue") + plt.legend() + plt.xlabel("Epoch") + plt.ylabel("MSE loss") + plt.yscale('log') + plt.title("Learning curves") + plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{self.name}.png")) + plt.close() + + self.trained_state = state + + def save(self): + """ + Save the trained model and all the used metadata to the outdir. + """ + # Save the metadata + if not os.path.exists(self.outdir): + os.makedirs(self.outdir) + meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl") + + save = {} + save["times"] = self.times + save["nus"] = self.nus + save["parameter_names"] = self.parameter_names + save.update(self.preprocessing_metadata) + + with open(meta_filename, "wb") as meta_file: + pickle.dump(save, meta_file) + + # Save the NN + model = self.trained_state + fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{self.name}.pkl") + + def _save_preprocessed_data(self): + print("Saving preprocessed data . . .") + np.savez(os.path.join(self.outdir, "afterglow_preprocessed_data.npz"), train_X=self.train_X, train_y= self.train_y, val_X = self.val_X, val_y = self.val_y) + print("Saving preprocessed data . . . done") + +class PCATrainer(FluxTrainer): + + def __init__(self, + name: str, + outdir: str, + data_manager, + n_pca: Int = 100, + plots_dir: str = None, + save_preprocessed_data: bool = False): + + super().__init__(name = name, + outdir = outdir, + plots_dir = plots_dir) + + self.n_pca = n_pca + self.save_preprocessed_data = save_preprocessed_data + self.data_manager = data_manager + self.parameter_names = data_manager.parameter_names + self.times = data_manager.times + self.nus = data_manager.nus + + self.plots_dir = plots_dir + if self.plots_dir is not None and not os.path.exists(self.plots_dir): + os.makedirs(self.plots_dir) + + self.preprocess() + + if save_preprocessed_data: + self._save_preprocessed_data() + + def preprocess(self): + print(f"Fitting PCA model with {self.n_pca} components to the provided data.") + self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_data_from_file(self.n_pca) + print(f"PCA model accounts for a share {np.sum(self.y_scaler.explained_variance_ratio_)} of the total variance in the training data. This value is hopefully close to 1.") + self.preprocessing_metadata["X_scaler"] = self.X_scaler + self.preprocessing_metadata["y_scaler"] = self.y_scaler + print("Preprocessing data . . . done") + + def load_parameter_names(self): + raise NotImplementedError + + def load_times(self): + raise NotImplementedError + + def load_raw_data(self): + raise NotImplementedError + +class DataManager: + + def __init__(self, + file: str, + n_training: Int, + n_val: Int, + tmin: Float, + tmax: Float, + numin: Float = 1e9, + numax: Float = 2.5e18, + special_training: list = [], + ): + + self.file = file + self.n_training = n_training + self.n_val = n_val + + self.tmin = tmin + self.tmax = tmax + self.numin = numin + self.numax = numax + + self.special_training = special_training + + self.read_metadata_from_file() + self.set_up_domain_mask() + + def read_metadata_from_file(self,)->None: + with h5py.File(self.file, "r") as f: + self.times_data = f["times"][:] + self.nus_data = f["nus"][:] + self.parameter_names = f["parameter_names"][:].astype(str).tolist() + self.n_training_exists = f["train"]["X"].shape[0] + self.n_val_exists = f["val"]["X"].shape[0] + + def set_up_domain_mask(self,)->None: + """Trims the stored data down to the time and frequency range desired for training.""" + + if self.tminself.times_data.max(): + print(f"\nWarning: provided time range {self.tmin, self.tmax} is too wide for the data stored in file. Using range {max(self.times_data.min(), self.tmin), min(self.times_data.max(), self.tmax)} instead.\n") + time_mask = np.logical_and(self.times_data>=self.tmin, self.times_data<=self.tmax) + self.times = self.times_data[time_mask] + self.n_times = len(self.times) + + if self.numinself.nus_data.max(): + print(f"\nWarning: provided frequency range {self.numin, self.numax} is too wide for the data stored in file. Using range {max(self.nus_data.min(), self.numin), min(self.nus_data.max(), self.numax)} instead.\n") + nu_mask = np.logical_and(self.nus_data>=self.numin, self.nus_data<=self.numax) + self.nus = self.nus_data[nu_mask] + self.n_nus = len(self.nus) + + mask = nu_mask[:, None] & time_mask + self.mask = mask.flatten() + + def get_data_from_file(self,): + with h5py.File(self.file, "r") as f: + if self.n_training>self.n_training_exists: + raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.") + self.train_X_raw = f["train"]["X"][:self.n_training] + self.train_y_raw = f["train"]["y"][:self.n_training, self.mask] + + for label in self.special_training: + self.train_X_raw = np.concatenate((self.train_X_raw, f["special_train"][label]["X"][:])) + self.train_y_raw = np.concatenate((self.train_y_raw, f["special_train"][label]["y"][:, self.mask])) + + if self.n_val>self.n_val_exists: + raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to validate with {self.n_val} data points.") + self.val_X_raw = f["val"]["X"][:self.n_val] + self.val_y_raw = f["val"]["y"][:self.n_val, self.mask] + + def preprocess_data_from_file(self, n_components: int)->None: + Xscaler, yscaler = StandardScalerJax(), PCAdecomposer(n_components=n_components) + with h5py.File(self.file, "r") as f: + # preprocess the training data + if self.n_training>self.n_training_exists: + raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.") + + train_X_raw = f["train"]["X"][:self.n_training] + for label in self.special_training: + train_X_raw = np.concatenate((train_X_raw, f["special_train"][label]["X"][:])) + train_X = Xscaler.fit_transform(train_X_raw) + + loaded = f["train"]["y"][:15_000, self.mask] + if np.any(np.isinf(loaded)): + raise ValueError(f"Found inftys in training data.") + yscaler.fit(loaded) # only load 15k cause otherwise the array might get too large + train_y = np.empty((self.n_training, n_components)) + n_loaded = 0 + for chunk in f["train"]["y"].iter_chunks(): + loaded = f["train"]["y"][chunk][:, self.mask] + if np.any(np.isinf(loaded)): + raise ValueError(f"Found inftys in training data.") + train_y[n_loaded:n_loaded+len(loaded)] = yscaler.transform(loaded) + n_loaded += len(loaded) + if n_loaded >= self.n_training: + break + for label in self.special_training: + special_train_y = yscaler.transform(f["special_train"][label]["y"][:, self.mask]) + train_y = np.concatenate((train_y, special_train_y)) + + # preprocess validation data + if self.n_val>self.n_val_exists: + raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to train with {self.n_val} data points.") + val_X_raw = f["val"]["X"][:self.n_val] + val_X = Xscaler.transform(val_X_raw) + val_y_raw = f["val"]["y"][:self.n_val, self.mask] + val_y = yscaler.transform(val_y_raw) + + return train_X, train_y, val_X, val_y, Xscaler, yscaler + + + def pass_data(self, object): + object.parameter_names = self.parameter_names + object.train_X_raw = self.train_X_raw + object.train_y_raw = self.train_y_raw + object.val_X_raw = self.val_X_raw + object.val_y_raw = self.val_y_raw + object.times = self.times + object.nus = self.nus + + + def print_file_info(self,): + with h5py.File(self.file, "r") as f: + print(f"Times: {f['times'][0]} {f['times'][-1]}") + print(f"Nus: {f['nus'][0]} {f['nus'][-1]}") + print(f"Parameter distributions: {f['parameter_distributions'][()].decode('utf-8')}") + print("\n") + print(f"Training data: {self.n_training_exists}") + print(f"Validation data: {self.n_val_exists}") + print(f"Test data: {f['test']['X'].shape[0]}") + print("Special data:") + for key in f['special_train'].keys(): + print(f"\t {key}: {f['special_train'][key]['X'].shape[0]} description: {f['special_train'][key].attrs['comment']}") + print("\n \n") \ No newline at end of file diff --git a/src/fiesta/train/SurrogateTrainer.py b/src/fiesta/train/LightcurveTrainer.py similarity index 90% rename from src/fiesta/train/SurrogateTrainer.py rename to src/fiesta/train/LightcurveTrainer.py index b703fa9..e0f486a 100644 --- a/src/fiesta/train/SurrogateTrainer.py +++ b/src/fiesta/train/LightcurveTrainer.py @@ -29,7 +29,6 @@ class SurrogateTrainer: filters: list[Filter] parameter_names: list[str] - validation_fraction: Float preprocessing_metadata: dict[str, dict[str, float]] # TODO: why do we have so many datasets? @@ -56,7 +55,6 @@ class SurrogateTrainer: def __init__(self, name: str, outdir: str, - validation_fraction: Float = 0.2, save_raw_data: bool = False, save_preprocessed_data: bool = False) -> None: @@ -72,9 +70,6 @@ def __init__(self, # To be loaded by child classes self.filters = None self.parameter_names = None - self.plots_dir = None - - self.validation_fraction = validation_fraction self.preprocessing_metadata = {} self.X_raw = None @@ -82,6 +77,7 @@ def __init__(self, self.X = None self.y = None + self.weights = None def __repr__(self) -> str: return f"SurrogateTrainer(name={self.name})" @@ -126,7 +122,9 @@ def fit(self, input_ndim = len(self.parameter_names) for filt in self.filters: - + + print(f"\n\n Training {filt.name}... \n\n") + # Create neural network and initialize the state net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) key, subkey = jax.random.split(key) @@ -244,17 +242,18 @@ def __init__(self, save_preprocessed_data: If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False. """ - super().__init__(name=name, outdir=outdir, validation_fraction=validation_fraction) + super().__init__(name=name, outdir=outdir) self.plots_dir = plots_dir if self.plots_dir is not None and not os.path.exists(self.plots_dir): os.makedirs(self.plots_dir) + + self.validation_fraction = validation_fraction self.svd_ncoeff = svd_ncoeff self.tmin = tmin self.tmax = tmax self.dt = dt - self.plots_dir = plots_dir self.save_raw_data = save_raw_data self.save_preprocessed_data = save_preprocessed_data @@ -300,6 +299,13 @@ def preprocess(self): for filt in tqdm.tqdm(self.filters): y_scaler = MinMaxScalerJax() + + completely_problematic = np.where(np.all(np.isinf(self.train_y_raw[filt.name]), axis = 1))[0] + problematic = np.unique(np.where(np.isinf(self.train_y_raw[filt.name]))[0]) + + if len(problematic)!=0: + raise Exception(f"There were infs in the magnitudes for filter {filt.name}.") + data = y_scaler.fit_transform(self.train_y_raw[filt.name]) # Do SVD decomposition on the training data @@ -498,6 +504,7 @@ def __init__(self, outdir: str, filters: list[str], parameter_grid: dict[str, list[float]], + weight_grids: dict = None, n_training_data: Int = 5000, jet_type: Int = -1, fixed_parameters: dict[str, Float] = {}, @@ -536,6 +543,10 @@ def __init__(self, self.n_times = n_times dt = (tmax - tmin) / n_times self.parameter_grid = parameter_grid + self.weight_grids = weight_grids + if self.weight_grids is None: + self.weight_grids = {p: np.full_like(self.parameter_grid[p], 1/len(self.parameter_grid[p])) for p in self.parameter_grid.keys()} + self.fixed_parameters = fixed_parameters self.use_log_spacing = use_log_spacing @@ -561,6 +572,7 @@ def __init__(self, plots_dir=plots_dir, save_raw_data=save_raw_data, save_preprocessed_data=save_preprocessed_data) + def load_filters(self, filters: list[str]): self.filters = [] @@ -595,23 +607,21 @@ def create_raw_data(self): TODO: for now we train per filter, but best to change this! """ # Create training data + nus = jnp.array([filt.nu for filt in self.filters]) X_raw = np.empty((self.n_training_data, len(self.parameter_names))) y_raw = {filt.name: np.empty((self.n_training_data, len(self.times))) for filt in self.filters} for j, key in enumerate(self.parameter_grid.keys()): - X_raw[:,j] = np.random.choice(self.parameter_grid[key], size = self.n_training_data, replace = True) + X_raw[:,j] = np.random.choice(self.parameter_grid[key], size = self.n_training_data, replace = True, p = self.weight_grids[key]) print(f"Creating the afterglowpy training dataset on grid with {self.n_training_data} points.") for i in tqdm.tqdm(range(self.n_training_data)): - for filt in self.filters: - param_dict = dict(zip(self.parameter_names, X_raw[i])) - param_dict.update(self.fixed_parameters) - param_dict["nu"] = filt.nu # Add nu per filter before calling afterglowpy - - # Create and save output - mJys = self._call_afterglowpy(param_dict) - y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys) + param_dict = dict(zip(self.parameter_names, X_raw[i])) + param_dict.update(self.fixed_parameters) + mJys = self._call_afterglowpy(param_dict, nus) + for k, filt in enumerate(self.filters): + y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys[k]) self.train_X_raw = X_raw self.train_y_raw = y_raw @@ -623,16 +633,13 @@ def create_raw_data(self): print(f"Creating the afterglowpy validation dataset on {self.n_val_data} random points within grid.") for i in tqdm.tqdm(range(self.n_val_data)): - X_raw[i] = [np.random.uniform(self.parameter_grid[p][0], self.parameter_grid[p][-1]) for p in self.parameter_names] + X_raw[i] = [np.random.uniform(self.parameter_grid[p][0], self.parameter_grid[p][-1]) for p in self.parameter_names] + param_dict = dict(zip(self.parameter_names, X_raw[i])) + param_dict.update(self.fixed_parameters) + mJys = self._call_afterglowpy(param_dict, nus) - for filt in self.filters: - param_dict = dict(zip(self.parameter_names, X_raw[i])) - param_dict.update(self.fixed_parameters) - param_dict["nu"] = filt.nu # Add nu per filter before calling afterglowpy - - # Create and save output - mJys = self._call_afterglowpy(param_dict) - y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys) + for k, filt in enumerate(self.filters): + y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys[k]) self.val_X_raw = X_raw self.val_y_raw = y_raw @@ -643,23 +650,35 @@ def _read_files(self,): raw_data_train = np.load(self.outdir+"/raw_data_training.npz") raw_data_validation = np.load(self.outdir+'/raw_data_validation.npz') - self.n_training_data = 4000 - self.n_val_data = 1000 - training_y_raw = {} val_y_raw = {} - select = np.random.choice(range(0, len(raw_data_train["X_raw"])), size = self.n_training_data, replace = False) - select_val = np.random.choice(range(0, len(raw_data_validation["X_raw"])), size = self.n_val_data, replace = False) - for filt in self.filters: - training_y_raw[filt.name] = raw_data_train[filt.name][select] - val_y_raw[filt.name] = raw_data_validation[filt.name][select_val] - return raw_data_train["X_raw"][select], training_y_raw, raw_data_validation["X_raw"][select_val], val_y_raw + training_y_raw[filt.name] = raw_data_train[filt.name] + val_y_raw[filt.name] = raw_data_validation[filt.name] + + self.n_training_data = len(raw_data_train["X_raw"]) + self.n_val_data = len(raw_data_validation["X_raw"]) + return raw_data_train["X_raw"], training_y_raw, raw_data_validation["X_raw"], val_y_raw + + def _set_weights(self,): # TODO: dev legacy + weights = {filt.name: jnp.ones(self.n_training_data) for filt in self.filters} + + + for j, p in enumerate(self.parameter_names): + bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 + bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]+0.1] # the right boundary of the bins needs to be a little larger because of digitize + indices = np.digitize(self.train_X_raw[:,j], bins) - 1 + for filt in self.filters: + weights[filt.name] *= self.weight_grids[filt.name][p][indices] + + self.weights = {filt.name: weights[filt.name]/np.sum(weights[filt.name]) for filt in self.filters} + def _call_afterglowpy(self, - params_dict: dict[str, Float]) -> Float[Array, "n_times"]: + params_dict: dict[str, Float], + nus) -> Float[Array, "n_times"]: """ Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires, as well as the nu value. The output will be a set of mJys. @@ -677,7 +696,6 @@ def _call_afterglowpy(self, Z["xi_N"] = params_dict.get("xi_N", 1.0) Z["E0"] = 10 ** params_dict["log10_E0"] - Z["thetaCore"] = params_dict["thetaCore"] Z["n0"] = 10 ** params_dict["log10_n0"] Z["p"] = params_dict["p"] Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] @@ -687,11 +705,18 @@ def _call_afterglowpy(self, Z["thetaObs"] = params_dict["inclination_EM"] else: Z["thetaObs"] = params_dict["thetaObs"] - if self.jet_type == 1 or self.jet_type == 4: + + if self.jet_type == -1: + Z["thetaCore"] = params_dict["thetaCore"] + + if "thetaWing" in list(params_dict.keys()): #for Gaussian and power law jets + Z["thetaWing"] = params_dict["thetaWing"] + Z["thetaCore"] = params_dict["xCore"]*params_dict["thetaWing"] + + if self.jet_type == 4: Z["b"] = params_dict["b"] - if "thetaWing" in list(params_dict.keys()): - Z["thetaWing"] = params_dict["thetaWing"] # Afterglowpy returns flux in mJys - mJys = grb.fluxDensity(self._times_afterglowpy, params_dict["nu"], **Z) + tt, nunu = np.meshgrid(self._times_afterglowpy, nus) + mJys = grb.fluxDensity(tt, nunu, **Z) return mJys \ No newline at end of file diff --git a/src/fiesta/utils.py b/src/fiesta/utils.py index 8cafad3..9b3347f 100644 --- a/src/fiesta/utils.py +++ b/src/fiesta/utils.py @@ -1,6 +1,7 @@ import jax.numpy as jnp from jax.scipy.stats import truncnorm from jaxtyping import Array, Float +import jax import numpy as np import pandas as pd import scipy.interpolate as interp @@ -12,6 +13,11 @@ from sncosmo.bandpasses import _BANDPASSES, _BANDPASS_INTERPOLATORS import sncosmo + +#################### +### DATA SCALERS ### +#################### + class MinMaxScalerJax(object): """ MinMaxScaler like sklearn does it, but for JAX arrays since sklearn might not be JAX-compatible? @@ -40,6 +46,102 @@ def fit_transform(self, x: Array) -> Array: self.fit(x) return self.transform(x) + +class StandardScalerJax(object): + """ + StandardScaler like sklearn does it, but for JAX arrays since sklearn might not be JAX-compatible? + + Note: assumes that input has dynamical range: it will not catch errors due to constant input (leading to zero division) + """ + + def __init__(self, + mu: Array = None, + sigma: Array = None): + + self.mu = mu + self.sigma = sigma + + def fit(self, x: Array) -> None: + self.mu = jnp.average(x, axis=0) + self.sigma = jnp.std(x, axis=0) + + def transform(self, x: Array) -> Array: + return (x - self.mu) / self.sigma + + def inverse_transform(self, x: Array) -> Array: + return x * self.sigma + self.mu + + def fit_transform(self, x: Array) -> Array: + self.fit(x) + return self.transform(x) + +class PCAdecomposer(object): + """ + PCA decomposer like sklearn does it. Based on https://github.com/alonfnt/pcax/tree/main. + """ + def __init__(self, n_components: int, solver: str = "randomized"): + self.n_components = n_components + self.solver = solver + + def fit(self, x: Array)-> None: + if self.solver == "full": + self._fit_full(x, self.n_components) + elif self.solver == "randomized": + rng = jax.random.PRNGKey(self.n_components) + self._fit_randomized(x, rng) + else: + raise ValueError("solver parameter is not correct") + + def _fit_full(self, x: Array): + n_samples, n_features = x.shape + self.means = jnp.mean(x, axis=0, keepdims=True) + x = x - self.means + + _, S, Vt = jax.scipy.linalg.svd(x, full_matrices= False) + + self.explained_variance_ = (S[:self.n_components] ** 2) / (n_samples - 1) + total_var = jnp.sum(S ** 2) / (n_samples - 1) + self.explained_variance_ratio_ = self.explained_variance_ / total_var + + self.Vt = Vt[:self.n_components] + + def _fit_randomized(self, x: Array, rng, n_iter = 5): + """Randomized PCA based on Halko et al [https://doi.org/10.48550/arXiv.1007.5510].""" + n_samples, n_features = x.shape + self.means = jnp.mean(x, axis=0, keepdims=True) + x = x - self.means + + # Generate n_features normal vectors of the given size + size = jnp.minimum(2 * self.n_components, n_features) + Q = jax.random.normal(rng, shape=(n_features, size)) + + def step_fn(q, _): + q, _ = jax.scipy.linalg.lu(x @ q, permute_l=True) + q, _ = jax.scipy.linalg.lu(x.T @ q, permute_l=True) + return q, None + + Q, _ = jax.lax.scan(step_fn, init=Q, xs=None, length=n_iter) + Q, _ = jax.scipy.linalg.qr(x @ Q, mode="economic") + B = Q.T @ x + + _, S, Vt = jax.scipy.linalg.svd(B, full_matrices=False) + + self.explained_variance_ = (S[:self.n_components] ** 2) / (n_samples - 1) + total_var = jnp.sum(S ** 2) / (n_samples - 1) + self.explained_variance_ratio_ = self.explained_variance_ / total_var + + self.Vt = Vt[:self.n_components] + + def transform(self, x: Array)->Array: + return jnp.dot(x - self.means, self.Vt.T) + + def inverse_transform(self, x: Array)->Array: + return jnp.dot(x, self.Vt) + self.means + + def fit_transform(self, x: Array)-> Array: + self.fit(x) + return self.transform(x) + def inverse_svd_transform(x: Array, VA: Array, nsvd_coeff: int = 10) -> Array: @@ -222,6 +324,19 @@ def load_event_data(filename): return data +def write_event_data(filename: str, data: dict): + """ + Takes a magnitude dict and writes it to filename. + The magnitude dict should have filters as keys, the arrays should have the structure [[mjd, mag, err]]. + """ + with open(filename, "w") as out: + for filt in data.keys(): + for data_point in data[filt]: + time = Time(data_point[0], format = "mjd") + filt_name = filt.replace("_", ":") + line = f"{time.isot} {filt_name} {data_point[1]:f} {data_point[2]:f}" + out.write(line +"\n") + ######################### ### Filters ### ######################### diff --git a/src/fiestaEM.egg-info/PKG-INFO b/src/fiestaEM.egg-info/PKG-INFO new file mode 100644 index 0000000..a59f6cf --- /dev/null +++ b/src/fiestaEM.egg-info/PKG-INFO @@ -0,0 +1,58 @@ +Metadata-Version: 2.1 +Name: fiestaEM +Version: 0.0.1 +Summary: Fast inference of electromagnetic signals with JAX +Home-page: https://github.com/ThibeauWouters/fiestaEM +Author: Thibeau Wouters +Author-email: thibeauwouters@gmail.com +License: MIT +Keywords: sampling,inference,astrophysics,kilonovae,gamma-ray bursts +Requires-Python: >=3.10 +Description-Content-Type: text/markdown +License-File: LICENSE +Requires-Dist: jax>=0.4.24 +Requires-Dist: jaxlib>=0.4.24 +Requires-Dist: numpy<2.0.0 +Requires-Dist: pandas<2.0.0 +Requires-Dist: jaxtyping +Requires-Dist: beartype +Requires-Dist: tqdm +Requires-Dist: scipy<=1.14.0 +Requires-Dist: ml_collections +Requires-Dist: astropy +Requires-Dist: sncosmo +Requires-Dist: flowMC +Requires-Dist: joblib + +# fiesta 🎉 + +`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x + +![fiesta logo](docs/fiesta_logo.jpeg) + +**NOTE:** `fiesta` is currently under development -- stay tuned! + +## Installation + +pip installation is currently work in progress. Install from source by cloning this Github repository and running +``` +pip install -e . +``` + +NOTE: This is using an older and custom version of `flowMC`. Install by cloning the `flowMC` version at [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`). + +## Training surrogate models + +To train your own surrogate models, have a look at some of the example scripts in the repository for inspiration, under `trained_models` + +- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model. +- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure. + +## Examples + +- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model. +- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon. + +## Acknowledgements + +The logo was created by [ideogram AI](https://ideogram.ai/). diff --git a/src/fiestaEM.egg-info/SOURCES.txt b/src/fiestaEM.egg-info/SOURCES.txt new file mode 100644 index 0000000..52a1484 --- /dev/null +++ b/src/fiestaEM.egg-info/SOURCES.txt @@ -0,0 +1,9 @@ +LICENSE +README.md +pyproject.toml +setup.cfg +src/fiestaEM.egg-info/PKG-INFO +src/fiestaEM.egg-info/SOURCES.txt +src/fiestaEM.egg-info/dependency_links.txt +src/fiestaEM.egg-info/requires.txt +src/fiestaEM.egg-info/top_level.txt \ No newline at end of file diff --git a/src/fiestaEM.egg-info/dependency_links.txt b/src/fiestaEM.egg-info/dependency_links.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/fiestaEM.egg-info/dependency_links.txt @@ -0,0 +1 @@ + diff --git a/src/fiestaEM.egg-info/requires.txt b/src/fiestaEM.egg-info/requires.txt new file mode 100644 index 0000000..283afba --- /dev/null +++ b/src/fiestaEM.egg-info/requires.txt @@ -0,0 +1,13 @@ +jax>=0.4.24 +jaxlib>=0.4.24 +numpy<2.0.0 +pandas<2.0.0 +jaxtyping +beartype +tqdm +scipy<=1.14.0 +ml_collections +astropy +sncosmo +flowMC +joblib diff --git a/src/fiestaEM.egg-info/top_level.txt b/src/fiestaEM.egg-info/top_level.txt new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/src/fiestaEM.egg-info/top_level.txt @@ -0,0 +1 @@ + diff --git a/trained_models/.gitignore b/trained_models/.gitignore deleted file mode 100644 index 1dcc44e..0000000 --- a/trained_models/.gitignore +++ /dev/null @@ -1,9 +0,0 @@ -figures/ -raw_data.npz -preprocessed_data.npz -preprocessed_data_training.npz -preprocessed_data_validation.npz -raw_data_test.npz -raw_data_training.npz -raw_data_validation.npz - diff --git a/trained_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl b/trained_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl deleted file mode 100644 index 4356bf1..0000000 Binary files a/trained_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl and /dev/null differ diff --git a/trained_models/GRB/afterglowpy/tophat/bessellv.pkl b/trained_models/GRB/afterglowpy/tophat/bessellv.pkl deleted file mode 100644 index 79ef2ee..0000000 Binary files a/trained_models/GRB/afterglowpy/tophat/bessellv.pkl and /dev/null differ diff --git a/trained_models/GRB/afterglowpy/tophat/radio-3GHz.pkl b/trained_models/GRB/afterglowpy/tophat/radio-3GHz.pkl deleted file mode 100644 index 3863868..0000000 Binary files a/trained_models/GRB/afterglowpy/tophat/radio-3GHz.pkl and /dev/null differ diff --git a/trained_models/GRB/afterglowpy/tophat/radio-6GHz.pkl b/trained_models/GRB/afterglowpy/tophat/radio-6GHz.pkl deleted file mode 100644 index 02caac0..0000000 Binary files a/trained_models/GRB/afterglowpy/tophat/radio-6GHz.pkl and /dev/null differ diff --git a/trained_models/GRB/afterglowpy/tophat/tophat_metadata.pkl b/trained_models/GRB/afterglowpy/tophat/tophat_metadata.pkl deleted file mode 100644 index f30e6b8..0000000 Binary files a/trained_models/GRB/afterglowpy/tophat/tophat_metadata.pkl and /dev/null differ