diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml index c73e032..7176ffb 100644 --- a/.github/workflows/pylint.yml +++ b/.github/workflows/pylint.yml @@ -7,7 +7,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: ["3.8", "3.9", "3.10"] + python-version: ["3.10", "3.11", "3.12"] steps: - uses: actions/checkout@v4 - name: Set up Python ${{ matrix.python-version }} @@ -18,6 +18,7 @@ jobs: run: | python -m pip install --upgrade pip pip install pylint + python -m pip install . - name: Analysing the code with pylint run: | - pylint $(git ls-files '*.py') + pylint --fail-under=5 --disable=C,R,W $(git ls-files 'src/*.py') diff --git a/.github/workflows/unittest.yml b/.github/workflows/unittest.yml new file mode 100644 index 0000000..067887a --- /dev/null +++ b/.github/workflows/unittest.yml @@ -0,0 +1,40 @@ +name: Unittest + +on: + push: + pull_request: + +jobs: + test: + runs-on: ${{ matrix.os }} + defaults: + run: + shell: bash + strategy: + fail-fast: false + matrix: + os: ["ubuntu-latest"] + python-version: ["3.10", "3.11", "3.12"] + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python with Conda ${{ matrix.python-version }} + uses: conda-incubator/setup-miniconda@v3 + with: + python-version: ${{ matrix.python-version }} + architecture: ${{ 'x86_64' }} + miniforge-version: latest + use-mamba: true + mamba-version: "*" + activate-environment: fiesta_env + + - name: Install Python dependencies + run: | + python -m pip install --upgrade pip setuptools wheel + python -m pip install pytest pytest-cov pytest-aiohttp sqlparse freezegun PyJWT joblib coveralls + python -m pip install . + + - name: Perform tests with pytest + run: | + python -m coverage run --source fiesta -m pytest tests/*.py \ No newline at end of file diff --git a/.gitignore b/.gitignore index bee8a64..7763605 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,11 @@ __pycache__ +.coverage +*egg-info +/flux_models/*/benchmarks +/lightcurve_models/*/benchmarks +/flux_models/*/model/*.h5 +/lightcurve_models/*/model/*.h5 +*nohup.out +*.sbatch +*/log +*/err \ No newline at end of file diff --git a/benchmarks/GRB/benchmark_afterglowpy_tophat.py b/benchmarks/GRB/benchmark_afterglowpy_tophat.py index f2521ff..3467773 100644 --- a/benchmarks/GRB/benchmark_afterglowpy_tophat.py +++ b/benchmarks/GRB/benchmark_afterglowpy_tophat.py @@ -18,7 +18,6 @@ file_ending = "Linf" B = Benchmarker(name = name, - parameter_grid = parameter_grid, model_dir = model_dir, MODEL = AfterglowpyLightcurvemodel, filters = FILTERS, diff --git a/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py index 4f35b0c..6cd56d7 100644 --- a/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py +++ b/flux_models/afterglowpy_gaussian/benchmark_afterglowpy_gaussian.py @@ -1,45 +1,35 @@ -import numpy as np -import matplotlib.pyplot as plt - -from fiesta.train.BenchmarkerFluxes import Benchmarker -from fiesta.inference.lightcurve_model import AfterglowpyPCA -from fiesta.utils import Filter +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowFlux name = "gaussian" model_dir = f"./model/" FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] -for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: - if metric_name == "$\\mathcal{L}_2$": - file_ending = "L2" - else: - file_ending = "Linf" - - B = Benchmarker(name = name, - model_dir = model_dir, - MODEL = 
AfterglowpyPCA, - filters = FILTERS, - metric_name = metric_name, - ) - - - for filt in FILTERS: - - fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\\alpha_w$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) - fig.savefig(f"./benchmarks/benchmark_{filt}_{file_ending}.pdf", dpi = 200) +lc_model = AfterglowFlux(name, + directory = model_dir, + filters = FILTERS, + model_type= "MLP") + +for metric_name in ["L2", "Linf"]: + - B.print_correlations(filter = filt) + benchmarker = Benchmarker( + model = lc_model, + data_file = "./model/afterglowpy_raw_data.h5", + metric_name = metric_name + ) - fig, ax = B.plot_worst_lightcurves() - fig.savefig(f"./benchmarks/worst_lightcurves_{file_ending}.pdf", dpi = 200) - - -fig, ax = B.plot_error_distribution() -fig.savefig(f"./benchmarks/error_distribution.pdf", dpi = 200) - -fig, ax = B.plot_error_over_time() -fig.savefig(f"./benchmarks/error_over_time.pdf", dpi = 200) + benchmarker.benchmark() + + benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\iota$", + "$\log_{10}(E_0)$", + "$\\theta_{\\mathrm{c}}$", + "$\\alpha_{\\mathrm{w}}$", + "$\log_{10}(n_{\mathrm{ism}})$", + "$p$", + "$\log_{10}(\\epsilon_e)$", + "$\log_{10}(\\epsilon_B)$"]) diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png deleted file mode 100644 index eca6b33..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_X-ray-1keV_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png deleted file mode 100644 index 9c48ca3..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_bessellv_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png deleted file mode 100644 index 52ebbfd..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-3GHz_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png b/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png deleted file mode 100644 index 6ea54ae..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/afterglowpy_gaussian_radio-6GHz_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf deleted file mode 100644 index 5a83f0d..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf deleted file mode 100644 index 9feb22c..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf 
b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf deleted file mode 100644 index 1c1ed14..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf deleted file mode 100644 index 6e18a1b..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf deleted file mode 100644 index 6b12b59..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf deleted file mode 100644 index 7fc0988..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_bessellv_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf deleted file mode 100644 index e06c6ed..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf deleted file mode 100644 index b682087..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf deleted file mode 100644 index 56450fd..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-3GHz_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf deleted file mode 100644 index 4da5137..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf deleted file mode 100644 index 3d77ef1..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf deleted file mode 100644 index 5b83c50..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/benchmark_radio-6GHz_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf b/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf deleted file mode 100644 index 403316f..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/error_distribution.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf 
b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf deleted file mode 100644 index ed53697..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf b/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf deleted file mode 100644 index 9972a81..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/error_over_time_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png b/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png deleted file mode 100644 index b82a885..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/learning_curves_gaussian.png and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf deleted file mode 100644 index 8d1f376..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf b/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf deleted file mode 100644 index d846900..0000000 Binary files a/flux_models/afterglowpy_gaussian/benchmarks/worst_lightcurves_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py index 6041f03..15387ec 100644 --- a/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py +++ b/flux_models/afterglowpy_gaussian/create_data_afterglowpy_gaussian.py @@ -1,19 +1,17 @@ import numpy as np -import matplotlib.pyplot as plt - from fiesta.train.AfterglowData import AfterglowpyData ############# ### SETUP ### ############# -tmin = 0.1 # days +tmin = 1e-4 # days tmax = 2000 # days -n_times = 200 +n_times = 250 numin = 1e9 # Hz -numax = 2.5e18 # Hz (10 keV) +numax = 2.5e19 # Hz (100 keV) n_nu = 256 @@ -34,7 +32,7 @@ jet_conversion = {"tophat": -1, "gaussian": 0} -n_training = 45_000 +n_training = 20_000 n_val = 0 n_test = 0 diff --git a/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py index 74c2ca1..f3f3c14 100644 --- a/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py +++ b/flux_models/afterglowpy_gaussian/create_special_data_afterglowpy_gaussian.py @@ -1,25 +1,16 @@ +from jax.random import PRNGKey import numpy as np + from fiesta.train.AfterglowData import AfterglowpyData +from fiesta.inference.prior_dict import ConstrainedPrior +from fiesta.inference.prior import Uniform, Constraint ############# ### SETUP ### ############# -parameter_distributions = { - 'inclination_EM': (0, np.pi/2, "uniform"), - 'log10_E0': (47, 57, "uniform"), - 'thetaCore': (0.01, np.pi/5, "loguniform"), - 'alphaWing': (0.2, 3.5, "uniform"), - 'log10_n0': (-6, 2, "uniform"), - 'p': (2.01, 3, "uniform"), - 'log10_epsilon_e': (-4, 0, "uniform"), - 'log10_epsilon_B': (-8, 0, "uniform") -} - - - -name = "tophat" +name = "gaussian" outdir = f"./model/" n_training = 0 @@ -27,6 +18,7 @@ n_test = 0 n_pool = 24 + size = 20_000 @@ -39,26 +31,28 @@ n_test = 0, n_pool = n_pool) -#import h5py -#with h5py.File(creator.outfile, "r+") as f: -# unproblematic = 
np.unique(np.where(~np.isinf(f["special_train"]["01"]["y"]))[0]) -# -# X = f["special_train"]["01"]["X"][unproblematic] -# y = f["special_train"]["01"]["y"][unproblematic] -# breakpoint() -# creator._save_to_file(X, y, group = "special_train", label = "02", comment = "log10_E0 (54, 57) log10_n0 (-6, -4) thetaCore (0.4, np.pi/5)") - - - -inclination = np.random.uniform(0, np.pi/2, size = size) -log10_E0 = np.random.uniform(54, 57, size = size) -thetaCore = np.random.uniform(0.4, np.pi/5, size= size) -alphaWing = np.random.uniform(0.2, 3.5, size = size) -log10_n0 = np.random.uniform(-6, -4, size = size) -p = np.random.uniform(2, 3, size = size) -log10_epsilon_e = np.random.uniform(-4, 0, size = size) -log10_epsilon_B = np.random.uniform(-8, 0, size = size) - -X = np.array([inclination, log10_E0, thetaCore, alphaWing, log10_n0, p, log10_epsilon_e, log10_epsilon_B]).T - -creator.create_special_data(X, label = "01", comment = "log10_E0 (54, 57) log10_n0 (-6, -4) thetaCore (0.4, np.pi/5)") \ No newline at end of file +def conversion_function(sample): + converted_sample = sample + converted_sample["thetaWing"] = converted_sample["thetaCore"] * converted_sample["alphaWing"] + converted_sample["epsilon_tot"] = 10**(converted_sample["log10_epsilon_B"]) + 10**(converted_sample["log10_epsilon_e"]) + return converted_sample + +prior = ConstrainedPrior([ + Uniform(xmin=0., xmax=np.pi/2, naming=["inclination_EM"]), + Uniform(xmin=54., xmax=57., naming=["log10_E0"]), + Uniform(xmin=0.35, xmax=np.pi/5, naming=["thetaCore"]), + Uniform(0.2, 3.5, naming=["alphaWing"]), + Uniform(xmin=-6.,xmax=-4.,naming=["log10_n0"]), + Uniform(xmin=2., xmax=3., naming=["p"]), + Uniform(xmin=-4., xmax=0., naming=["log10_epsilon_e"]), + Uniform(xmin=-8.,xmax=0., naming=["log10_epsilon_B"]), + Constraint(xmin=0., xmax=1., naming=["epsilon_tot"]), + Constraint(xmin=0., xmax=np.pi/2, naming=["thetaWing"]) + ], + conversion_function) + +X = prior.sample(PRNGKey(2728), n_samples=size) +X = [X[p] for p in creator.parameter_names] +X = np.transpose(X) + +creator.create_special_data(X, label = "01", comment = "log10_E0 (54, 57) log10_n0 (-6, -4) thetaCore (0.35, np.pi/5)") \ No newline at end of file diff --git a/flux_models/afterglowpy_gaussian/model/gaussian.pkl b/flux_models/afterglowpy_gaussian/model/gaussian.pkl index 7293261..9ed8608 100644 Binary files a/flux_models/afterglowpy_gaussian/model/gaussian.pkl and b/flux_models/afterglowpy_gaussian/model/gaussian.pkl differ diff --git a/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl b/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl index a41a1f5..02aa398 100644 Binary files a/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl and b/flux_models/afterglowpy_gaussian/model/gaussian_metadata.pkl differ diff --git a/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py index 82a3371..e12e0de 100644 --- a/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py +++ b/flux_models/afterglowpy_gaussian/train_afterglowpy_gaussian.py @@ -2,10 +2,9 @@ import matplotlib.pyplot as plt import h5py -from fiesta.train.FluxTrainer import PCATrainer, DataManager -from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.train.FluxTrainer import PCATrainer +from fiesta.inference.lightcurve_model import AfterglowFlux from fiesta.train.neuralnets import NeuralnetConfig -from fiesta.utils import Filter ############# ### SETUP ### @@ -16,41 +15,38 @@ numin = 1e9 # Hz -numax 
= 2.5e18 +numax = 5e18 - -n_training = 70_000 -n_val = 5000 -n_pca = 100 +n_training = 80_000 +n_val = 7500 +n_pca = 200 name = "gaussian" outdir = f"./model/" file = outdir + "afterglowpy_raw_data.h5" config = NeuralnetConfig(output_size=n_pca, - nb_epochs=100_000, + nb_epochs=150_000, hidden_layer_sizes = [256, 512, 256], - learning_rate =8e-3) - + learning_rate =4e-3) ############### ### TRAINER ### ############### -data_manager = DataManager(file = file, - n_training= n_training, +data_manager_args = dict(file = file, + n_training= n_training, n_val= n_val, tmin= tmin, tmax= tmax, numin = numin, - numax = numax, - special_training=["02"]) + numax = numax, + special_training=["01"]) -data_manager.print_file_info() trainer = PCATrainer(name, outdir, - data_manager = data_manager, + data_manager_args = data_manager_args, plots_dir=f"./benchmarks/", n_pca = n_pca, save_preprocessed_data=False @@ -60,48 +56,51 @@ ### FITTING ### ############### - trainer.fit(config=config) trainer.save() ############# -### TEST ### +### TEST ### ############# print("Producing example lightcurve . . .") FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] -lc_model = AfterglowpyPCA(name, +lc_model = AfterglowFlux(name, outdir, - filters = FILTERS) + filters = FILTERS, + model_type = "MLP") -for filt in lc_model.Filters: - with h5py.File(file, "r") as f: - X_example = f["val"]["X"][-1] - y_raw = f["val"]["y"][-1, data_manager.mask] - y_raw = y_raw.reshape(256, len(lc_model.times)) - y_raw = np.exp(y_raw) - y_raw = np.array([np.interp(filt.nu, lc_model.metadata["nus"], column) for column in y_raw.T]) - y_raw = -48.6 + -1 * np.log10(y_raw*1e-3 / 1e23) * 2.5 - +with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, trainer.data_manager.mask] + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + mJys = np.exp(y_raw) + # Turn into a dict: this is how the model expects the input X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} # Get the prediction lightcurve - y_predict = lc_model.predict(X_example)[filt.name] + _, y_predict = lc_model.predict_abs_mag(X_example) + + + for filt in lc_model.Filters: + + y_val = filt.get_mag(mJys, lc_model.nus) + + plt.plot(lc_model.times, y_val, color = "red", label="afterglowpy") + plt.plot(lc_model.times, y_predict[filt.name], color = "blue", label="Surrogate prediction") + upper_bound = y_predict[filt.name] + 1 + lower_bound = y_predict[filt.name] - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.xlabel("$t$ in days") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) - plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") - plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") - upper_bound = y_predict + 1 - lower_bound = y_predict - 1 - plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) - - plt.ylabel(f"mag for {filt.name}") - plt.legend() - plt.gca().invert_yaxis() - plt.xscale('log') - plt.xlim(lc_model.times[0], lc_model.times[-1]) - - plt.savefig(f"./benchmarks/afterglowpy_{name}_{filt.name}_example.png") - plt.close() \ No newline at end of file + plt.savefig(f"./benchmarks/afterglowpy_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/afterglowpy_gaussian_cVAE/benchmark_afterglowpy_gaussian.py 
b/flux_models/afterglowpy_gaussian_cVAE/benchmark_afterglowpy_gaussian.py new file mode 100644 index 0000000..7cfb684 --- /dev/null +++ b/flux_models/afterglowpy_gaussian_cVAE/benchmark_afterglowpy_gaussian.py @@ -0,0 +1,31 @@ +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowFlux + + +name = "gaussian" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + + +lc_model = AfterglowFlux(name, + model_dir, + filters = FILTERS, + model_type = "CVAE") + +for metric_name in ["L2", "Linf"]: + + benchmarker = Benchmarker( + model = lc_model, + data_file = "../afterglowpy_gaussian/model/afterglowpy_raw_data.h5", + metric_name = metric_name + ) + + benchmarker.benchmark() + benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\\alpha_w$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) + + + + + + + diff --git a/flux_models/afterglowpy_gaussian_cVAE/model/gaussian.pkl b/flux_models/afterglowpy_gaussian_cVAE/model/gaussian.pkl new file mode 100644 index 0000000..872c6e8 Binary files /dev/null and b/flux_models/afterglowpy_gaussian_cVAE/model/gaussian.pkl differ diff --git a/flux_models/afterglowpy_gaussian_cVAE/model/gaussian_metadata.pkl b/flux_models/afterglowpy_gaussian_cVAE/model/gaussian_metadata.pkl new file mode 100644 index 0000000..45b3199 Binary files /dev/null and b/flux_models/afterglowpy_gaussian_cVAE/model/gaussian_metadata.pkl differ diff --git a/flux_models/afterglowpy_gaussian_cVAE/train_afterglowpy_gaussian.py b/flux_models/afterglowpy_gaussian_cVAE/train_afterglowpy_gaussian.py new file mode 100644 index 0000000..a6aa301 --- /dev/null +++ b/flux_models/afterglowpy_gaussian_cVAE/train_afterglowpy_gaussian.py @@ -0,0 +1,109 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.FluxTrainer import CVAETrainer, DataManager +from fiesta.inference.lightcurve_model import AfterglowFlux +from fiesta.train.neuralnets import NeuralnetConfig + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 + + +numin = 1e9 # Hz +numax = 5e18 + + +n_training = 80_000 +n_val = 7500 +image_size = np.array([32, 25]) + +name = "gaussian" +outdir = f"./model/" +file = "../afterglowpy_gaussian/model/afterglowpy_raw_data.h5" + +config = NeuralnetConfig(output_size= int(np.prod(image_size)), + nb_epochs=50_000, + hidden_layer_sizes = [600, 400, 200], + learning_rate =5e-4) + + +############### +### TRAINER ### +############### + + +data_manager_args = dict(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=[]) + +trainer = CVAETrainer(name, + outdir, + data_manager_args = data_manager_args, + plots_dir=f"./benchmarks/", + image_size=image_size, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . 
.") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowFlux(name, + outdir, + filters = FILTERS, + model_type = "CVAE") + + +with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, trainer.data_manager.mask] + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + mJys = np.exp(y_raw) + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + _, y_predict = lc_model.predict_abs_mag(X_example) + + + for filt in lc_model.Filters: + + y_val = filt.get_mag(mJys, lc_model.nus) + + plt.plot(lc_model.times, y_val, color = "red", label="afterglowpy") + plt.plot(lc_model.times, y_predict[filt.name], color = "blue", label="Surrogate prediction") + upper_bound = y_predict[filt.name] + 1 + lower_bound = y_predict[filt.name] - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.xlabel("$t$ in days") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/afterglowpy_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png deleted file mode 100644 index be608bd..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_X-ray-1keV_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png deleted file mode 100644 index 2a3fdcb..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_bessellv_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png deleted file mode 100644 index ffbd76a..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-3GHz_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png b/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png deleted file mode 100644 index f9fbacc..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/afterglowpy_tophat_radio-6GHz_example.png and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf deleted file mode 100644 index 986818e..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf deleted file mode 100644 index 975cb2f..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf deleted file mode 100644 index a5348c1..0000000 
Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_X-ray-1keV_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf deleted file mode 100644 index 833cec6..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf deleted file mode 100644 index 82cbf50..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf deleted file mode 100644 index 214a1ac..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_bessellv_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf deleted file mode 100644 index 4d7d3d9..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf deleted file mode 100644 index 8180d59..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf deleted file mode 100644 index 8d2eb66..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-3GHz_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf deleted file mode 100644 index 43f9395..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf deleted file mode 100644 index 89c23e9..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf deleted file mode 100644 index af74378..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/benchmark_radio-6GHz_Linf_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf deleted file mode 100644 index ec5718c..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/error_distribution.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf deleted file mode 100644 index 83213b9..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/error_over_time.pdf and /dev/null 
differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf b/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf deleted file mode 100644 index 7030bd8..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/error_over_time_before.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png b/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png deleted file mode 100644 index 36ea0fe..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/learning_curves_tophat.png and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf deleted file mode 100644 index 506ca89..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_L2.pdf and /dev/null differ diff --git a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf b/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf deleted file mode 100644 index 7cfd65b..0000000 Binary files a/flux_models/afterglowpy_tophat/benchmarks/worst_lightcurves_Linf.pdf and /dev/null differ diff --git a/flux_models/pyblastafterglow_gaussian/benchmark_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian/benchmark_pyblastafterglow_gaussian.py new file mode 100644 index 0000000..4cee8c5 --- /dev/null +++ b/flux_models/pyblastafterglow_gaussian/benchmark_pyblastafterglow_gaussian.py @@ -0,0 +1,38 @@ +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowFlux + + +name = "gaussian" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + + +lc_model = AfterglowFlux(name, + directory = model_dir, + filters = FILTERS, + model_type= "MLP") + +for metric_name in ["L2", "Linf"]: + + + benchmarker = Benchmarker( + model = lc_model, + data_file = "./model/pyblastafterglow_raw_data.h5", + metric_name = metric_name + ) + + benchmarker.benchmark() + + benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\iota$", + "$\log_{10}(E_0)$", + "$\\theta_{\\mathrm{c}}$", + "$\\alpha_{\\mathrm{w}}$", + "$\log_{10}(n_{\mathrm{ism}})$", + "$p$", + "$\log_{10}(\\epsilon_e)$", + "$\log_{10}(\\epsilon_B)$", + "$\\Gamma_0$"]) + + + + diff --git a/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py index 95a5e76..65118d3 100644 --- a/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py +++ b/flux_models/pyblastafterglow_gaussian/create_pyblastafterglow_gaussian.py @@ -1,5 +1,4 @@ import numpy as np -import matplotlib.pyplot as plt from fiesta.train.AfterglowData import PyblastafterglowData from mpi4py import MPI @@ -21,7 +20,6 @@ numax = 2.5e19 # Hz (100 keV) n_nu = 256 - parameter_distributions = { 'inclination_EM': (0, np.pi/2, "uniform"), 'log10_E0': (47, 57, "uniform"), @@ -33,16 +31,16 @@ 'log10_epsilon_B': (-8,0, "uniform"), 'Gamma0': (100, 1000, "uniform") } - - jet_name = "gaussian" jet_conversion = {"tophat": -1, - "gaussian": 0} + "gaussian": 0, + "powerlaw": 4} + +n_training = 20 +n_val = 4 +n_test = 4 -n_training = 10 -n_val = 2 -n_test = 2 ####################### @@ -56,16 +54,16 @@ creator = PyblastafterglowData(outdir = outdir, - jet_type = jet_type, - n_training = n_training, - n_val = n_val, - n_test = n_test, - tmin = tmin, - tmax = tmax, - n_times = n_times, - 
numin = numin, - numax = numax, - n_nu = n_nu, - parameter_distributions = parameter_distributions, - rank = rank, - path_to_exec = "/home/aya/work/hkoehn/fiesta/PyBlastAfterglowMag/src/pba.out") \ No newline at end of file + jet_type = jet_type, + n_training = n_training, + n_val = n_val, + n_test = n_test, + tmin = tmin, + tmax = tmax, + n_times = n_times, + numin = numin, + numax = numax, + n_nu = n_nu, + parameter_distributions = parameter_distributions, + rank = rank, + path_to_exec = "/hppfs/scratch/06/di35kuf/pyblastafterglow/PyBlastAfterglowMag/src/pba.out") diff --git a/flux_models/pyblastafterglow_gaussian/model/gaussian.pkl b/flux_models/pyblastafterglow_gaussian/model/gaussian.pkl new file mode 100644 index 0000000..8117b2b Binary files /dev/null and b/flux_models/pyblastafterglow_gaussian/model/gaussian.pkl differ diff --git a/flux_models/pyblastafterglow_gaussian/model/gaussian_metadata.pkl b/flux_models/pyblastafterglow_gaussian/model/gaussian_metadata.pkl new file mode 100644 index 0000000..6eefe53 Binary files /dev/null and b/flux_models/pyblastafterglow_gaussian/model/gaussian_metadata.pkl differ diff --git a/flux_models/pyblastafterglow_gaussian/train_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian/train_pyblastafterglow_gaussian.py new file mode 100644 index 0000000..1077dbd --- /dev/null +++ b/flux_models/pyblastafterglow_gaussian/train_pyblastafterglow_gaussian.py @@ -0,0 +1,106 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.FluxTrainer import PCATrainer +from fiesta.inference.lightcurve_model import AfterglowFlux +from fiesta.train.neuralnets import NeuralnetConfig + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 + + +numin = 1e9 # Hz +numax = 5e18 + +n_training = 91670 +n_val = 7676 +n_pca = 200 + +name = "gaussian" +outdir = f"./model/" +file = outdir + "pyblastafterglow_raw_data.h5" + +config = NeuralnetConfig(output_size=n_pca, + nb_epochs=300_000, + hidden_layer_sizes = [256, 512, 256], + learning_rate =4e-3) + +############### +### TRAINER ### +############### + + +data_manager_args = dict(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=[]) + +trainer = PCATrainer(name, + outdir, + data_manager_args = data_manager_args, + plots_dir=f"./benchmarks/", + n_pca = n_pca, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . 
.") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowFlux(name, + outdir, + filters = FILTERS, + model_type = "MLP") + + +with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, trainer.data_manager.mask] + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + mJys = np.exp(y_raw) + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + _, y_predict = lc_model.predict_abs_mag(X_example) + + + for filt in lc_model.Filters: + + y_val = filt.get_mag(mJys, lc_model.nus) + + plt.plot(lc_model.times, y_val, color = "red", label="pyblastafterglow") + plt.plot(lc_model.times, y_predict[filt.name], color = "blue", label="Surrogate prediction") + upper_bound = y_predict[filt.name] + 1 + lower_bound = y_predict[filt.name] - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.xlabel("$t$ in days") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/pyblastafterglow_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/pyblastafterglow_gaussian_cVAE/benchmark_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian_cVAE/benchmark_pyblastafterglow_gaussian.py new file mode 100644 index 0000000..db686e7 --- /dev/null +++ b/flux_models/pyblastafterglow_gaussian_cVAE/benchmark_pyblastafterglow_gaussian.py @@ -0,0 +1,38 @@ +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowFlux + + +name = "gaussian" +model_dir = f"./model/" +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + + +lc_model = AfterglowFlux(name, + directory = model_dir, + filters = FILTERS, + model_type= "CVAE") + +for metric_name in ["L2", "Linf"]: + + + benchmarker = Benchmarker( + model = lc_model, + data_file = "../pyblastafterglow_gaussian/model/pyblastafterglow_raw_data.h5", + metric_name = metric_name + ) + + benchmarker.benchmark() + + benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\iota$", + "$\log_{10}(E_0)$", + "$\\theta_{\\mathrm{c}}$", + "$\\alpha_{\\mathrm{w}}$", + "$\log_{10}(n_{\mathrm{ism}})$", + "$p$", + "$\log_{10}(\\epsilon_e)$", + "$\log_{10}(\\epsilon_B)$", + "$\\Gamma_0$"]) + + + + diff --git a/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian.pkl b/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian.pkl new file mode 100644 index 0000000..571094f Binary files /dev/null and b/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian.pkl differ diff --git a/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian_metadata.pkl b/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian_metadata.pkl new file mode 100644 index 0000000..cd77391 Binary files /dev/null and b/flux_models/pyblastafterglow_gaussian_cVAE/model/gaussian_metadata.pkl differ diff --git a/flux_models/pyblastafterglow_gaussian_cVAE/train_pyblastafterglow_gaussian.py b/flux_models/pyblastafterglow_gaussian_cVAE/train_pyblastafterglow_gaussian.py new file mode 100644 index 0000000..7ccbd91 --- /dev/null +++ b/flux_models/pyblastafterglow_gaussian_cVAE/train_pyblastafterglow_gaussian.py @@ -0,0 +1,106 @@ +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.FluxTrainer import CVAETrainer +from 
fiesta.inference.lightcurve_model import AfterglowFlux +from fiesta.train.neuralnets import NeuralnetConfig + +############# +### SETUP ### +############# + +tmin = 0.1 # days +tmax = 2000 + + +numin = 1e9 # Hz +numax = 5e18 + +n_training = 91670 +n_val = 7676 +image_size = np.array([32, 25]) + +name = "gaussian" +outdir = f"./model/" +file = "../pyblastafterglow_gaussian/model/pyblastafterglow_raw_data.h5" + +config = NeuralnetConfig(output_size= int(np.prod(image_size)), + nb_epochs=300_000, + hidden_layer_sizes = [600, 500, 400, 300, 200], + learning_rate =5e-4) + +############### +### TRAINER ### +############### + + +data_manager_args = dict(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + special_training=[]) + +trainer = CVAETrainer(name, + outdir, + data_manager_args = data_manager_args, + plots_dir=f"./benchmarks/", + image_size= image_size, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") +FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] + +lc_model = AfterglowFlux(name, + outdir, + filters = FILTERS, + model_type = "CVAE") + + +with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, trainer.data_manager.mask] + y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) + mJys = np.exp(y_raw) + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + _, y_predict = lc_model.predict_abs_mag(X_example) + + + for filt in lc_model.Filters: + + y_val = filt.get_mag(mJys, lc_model.nus) + + plt.plot(lc_model.times, y_val, color = "red", label="pyblastafterglow") + plt.plot(lc_model.times, y_predict[filt.name], color = "blue", label="Surrogate prediction") + upper_bound = y_predict[filt.name] + 1 + lower_bound = y_predict[filt.name] - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.xlabel("$t$ in days") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/pyblastafterglow_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py index 69f3af3..567805a 100644 --- a/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py +++ b/flux_models/pyblastafterglow_tophat/benchmark_pyblastafterglow_tophat.py @@ -1,45 +1,35 @@ -import numpy as np -import matplotlib.pyplot as plt +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import AfterglowFlux -from fiesta.train.BenchmarkerFluxes import Benchmarker -from fiesta.inference.lightcurve_model import AfterglowpyPCA -from fiesta.utils import Filter name = "tophat" model_dir = f"./model/" -FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv"] +FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] + +lc_model = AfterglowFlux(name, + directory = model_dir, + filters = FILTERS, + model_type= "MLP") + +for metric_name in ["L2", "Linf"]: -for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: - if metric_name == 
"$\\mathcal{L}_2$": - file_ending = "L2" - else: - file_ending = "Linf" - - B = Benchmarker(name = name, - model_dir = model_dir, - MODEL = AfterglowpyPCA, - filters = FILTERS, - metric_name = metric_name, - ) - - - for filt in FILTERS: - - fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{c}}$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$", "$\\Gamma_0$"]) - fig.savefig(f"./benchmarks/benchmark_{filt}_{file_ending}.pdf", dpi = 200) - B.print_correlations(filter = filt) + benchmarker = Benchmarker( + model = lc_model, + data_file = "./model/pyblastafterglow_raw_data.h5", + metric_name = metric_name + ) - fig, ax = B.plot_worst_lightcurves() - fig.savefig(f"./benchmarks/worst_lightcurves_{file_ending}.pdf", dpi = 200) - - -fig, ax = B.plot_error_distribution() -fig.savefig(f"./benchmarks/error_distribution.pdf", dpi = 200) - -fig, ax = B.plot_error_over_time() -fig.savefig(f"./benchmarks/error_over_time.pdf", dpi = 200) + benchmarker.benchmark() + + benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\iota$", + "$\log_{10}(E_0)$", + "$\\theta_{\\mathrm{c}}$", + "$\log_{10}(n_{\mathrm{ism}})$", + "$p$", "$\\epsilon_E$", + "$\\epsilon_B$", + "$\\Gamma_0$"]) diff --git a/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py index f6d2a64..6ac6030 100644 --- a/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py +++ b/flux_models/pyblastafterglow_tophat/create_pyblastafterglow_tophat.py @@ -2,7 +2,7 @@ from fiesta.train.AfterglowData import PyblastafterglowData from mpi4py import MPI -comm = MPI.COMM_WORLD +comm = getattr(MPI, "COMM_WORLD") size = comm.Get_size() rank = comm.Get_rank() @@ -43,7 +43,6 @@ n_val = 10 n_test = 10 -retrain_weights = None ####################### diff --git a/flux_models/pyblastafterglow_tophat/model/tophat.pkl b/flux_models/pyblastafterglow_tophat/model/tophat.pkl new file mode 100644 index 0000000..d8f0281 Binary files /dev/null and b/flux_models/pyblastafterglow_tophat/model/tophat.pkl differ diff --git a/flux_models/pyblastafterglow_tophat/model/tophat_metadata.pkl b/flux_models/pyblastafterglow_tophat/model/tophat_metadata.pkl new file mode 100644 index 0000000..bd5afc9 Binary files /dev/null and b/flux_models/pyblastafterglow_tophat/model/tophat_metadata.pkl differ diff --git a/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py b/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py index 978c919..58729be 100644 --- a/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py +++ b/flux_models/pyblastafterglow_tophat/train_pyblastafterglow_tophat.py @@ -2,41 +2,40 @@ import matplotlib.pyplot as plt import h5py -from fiesta.train.FluxTrainer import PCATrainer, DataManager -from fiesta.inference.lightcurve_model import AfterglowpyPCA +from fiesta.train.FluxTrainer import PCATrainer +from fiesta.inference.lightcurve_model import AfterglowFlux from fiesta.train.neuralnets import NeuralnetConfig -from fiesta.utils import Filter ############# ### SETUP ### ############# -tmin = 1 # days +tmin = 0.1 # days tmax = 2000 numin = 1e9 # Hz -numax = 1e17 +numax = 2.5e18 -n_training = 50_000 -n_val = 5000 -n_pca = 100 +n_training = 57_600 +n_val = 5760 +n_pca = 200 name = "tophat" outdir = f"./model/" file = outdir + "pyblastafterglow_raw_data.h5" config = NeuralnetConfig(output_size=n_pca, - nb_epochs=100_000, - 
hidden_layer_sizes = [256, 512, 256], - learning_rate =8e-3) + nb_epochs=300_000, + hidden_layer_sizes = [300, 600, 300], + learning_rate =3e-3) ############### ### TRAINER ### ############### -data_manager = DataManager(file = file, +data_manager_args = dict(file = file, n_training= n_training, n_val= n_val, tmin= tmin, @@ -47,7 +46,7 @@ trainer = PCATrainer(name, outdir, - data_manager = data_manager, + data_manager_args = data_manager_args, plots_dir=f"./benchmarks/", n_pca = n_pca, save_preprocessed_data=False @@ -67,35 +66,41 @@ print("Producing example lightcurve . . .") FILTERS = ["radio-3GHz", "X-ray-1keV", "radio-6GHz", "bessellv"] -lc_model = AfterglowpyPCA(name, +lc_model = AfterglowFlux(name, outdir, - filters = FILTERS) + filters = FILTERS, + model_type = "MLP") -for filt in lc_model.Filters: - with h5py.File(file, "r") as f: - X_example = f["val"]["X"][-1] - y_raw = f["val"]["y"][-1, data_manager.mask] +with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-1] + y_raw = f["val"]["y"][-1, trainer.data_manager.mask] y_raw = y_raw.reshape(len(lc_model.nus), len(lc_model.times)) - y_raw = np.exp(y_raw) - y_raw = np.array([np.interp(filt.nu, lc_model.metadata["nus"], column) for column in y_raw.T]) - y_raw = -48.6 + -1 * np.log10(y_raw*1e-3 / 1e23) * 2.5 - + mJys = np.exp(y_raw) + # Turn into a dict: this is how the model expects the input X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} # Get the prediction lightcurve - y_predict = lc_model.predict(X_example)[filt.name] + y_predict = lc_model.predict(X_example) + + + for filt in lc_model.Filters: + + y_val = filt.get_mag(mJys, lc_model.nus) + + plt.plot(lc_model.times, y_val, color = "red", label="pyblastafterglow") + plt.plot(lc_model.times, y_predict[filt.name], color = "blue", label="Surrogate prediction") + upper_bound = y_predict[filt.name] + 1 + lower_bound = y_predict[filt.name] - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.ylabel(f"mag for {filt.name}") + plt.xlabel("$t$ in days") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) - plt.plot(lc_model.times, y_raw, color = "red", label="pyblast_afterglow") - plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") - upper_bound = y_predict + 1 - lower_bound = y_predict - 1 - plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) - - plt.ylabel(f"mag for {filt.name}") - plt.legend() - plt.gca().invert_yaxis() - - plt.savefig(f"./benchmarks/pyblastafterglow_{name}_{filt.name}_example.png") - plt.close() \ No newline at end of file + plt.savefig(f"./benchmarks/pyblastafterglow_{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/lightcurve_models/Bu2019/benchmark_Bu2019.py b/lightcurve_models/Bu2019/benchmark_Bu2019.py new file mode 100644 index 0000000..fcb0e70 --- /dev/null +++ b/lightcurve_models/Bu2019/benchmark_Bu2019.py @@ -0,0 +1,31 @@ +from fiesta.train.Benchmarker import Benchmarker +from fiesta.inference.lightcurve_model import BullaLightcurveModel + + + +name = "Bu2019" +model_dir = f"./model/" +FILTERS = ["2massj", "2massks", "sdssu", "ps1::r"] + +lc_model = BullaLightcurveModel(name, + model_dir, + filters = FILTERS) + +for metric_name in ["L2", "Linf"]: + + + benchmarker = Benchmarker( + model = lc_model, + data_file = "./model/Bu2019_raw_data.h5", + metric_name = metric_name + ) + + benchmarker.benchmark() + + 
benchmarker.plot_lightcurves_mismatch(parameter_labels = ["$\\log_{10}(m_{\\mathrm{ej, dyn}})$", "$\\log_{10}(m_{\\mathrm{ej, wind}})$", "$\\Phi_{\\mathrm{KN}}$", "$\\iota$"]) + + + + + + diff --git a/lightcurve_models/Bu2019/model/Bu2019_2massh.pkl b/lightcurve_models/Bu2019/model/Bu2019_2massh.pkl new file mode 100644 index 0000000..cee9d1b Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_2massh.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_2massj.pkl b/lightcurve_models/Bu2019/model/Bu2019_2massj.pkl new file mode 100644 index 0000000..cea41c4 Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_2massj.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_2massks.pkl b/lightcurve_models/Bu2019/model/Bu2019_2massks.pkl new file mode 100644 index 0000000..ad9b40d Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_2massks.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_metadata.pkl b/lightcurve_models/Bu2019/model/Bu2019_metadata.pkl new file mode 100644 index 0000000..a84a729 Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_metadata.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_ps1::g.pkl b/lightcurve_models/Bu2019/model/Bu2019_ps1::g.pkl new file mode 100644 index 0000000..d2336b3 Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_ps1::g.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_ps1::i.pkl b/lightcurve_models/Bu2019/model/Bu2019_ps1::i.pkl new file mode 100644 index 0000000..72f4b3c Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_ps1::i.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_ps1::r.pkl b/lightcurve_models/Bu2019/model/Bu2019_ps1::r.pkl new file mode 100644 index 0000000..78fcfaf Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_ps1::r.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_ps1::y.pkl b/lightcurve_models/Bu2019/model/Bu2019_ps1::y.pkl new file mode 100644 index 0000000..c6207e4 Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_ps1::y.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_ps1::z.pkl b/lightcurve_models/Bu2019/model/Bu2019_ps1::z.pkl new file mode 100644 index 0000000..d59856a Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_ps1::z.pkl differ diff --git a/lightcurve_models/Bu2019/model/Bu2019_sdssu.pkl b/lightcurve_models/Bu2019/model/Bu2019_sdssu.pkl new file mode 100644 index 0000000..23af652 Binary files /dev/null and b/lightcurve_models/Bu2019/model/Bu2019_sdssu.pkl differ diff --git a/lightcurve_models/Bu2019/train_Bu2019.py b/lightcurve_models/Bu2019/train_Bu2019.py new file mode 100644 index 0000000..5414381 --- /dev/null +++ b/lightcurve_models/Bu2019/train_Bu2019.py @@ -0,0 +1,108 @@ +from matplotlib.pylab import svd +import numpy as np +import matplotlib.pyplot as plt +import h5py + +from fiesta.train.LightcurveTrainer import SVDTrainer +from fiesta.inference.lightcurve_model import BullaLightcurveModel +from fiesta.train.neuralnets import NeuralnetConfig +from fiesta.utils import Filter + +############# +### SETUP ### +############# + +tmin = 1.5 # days +tmax = 20 + + +numin = 1e13 # Hz +numax = 6e15 + +n_training = 1276 +n_val = 160 + +svd_ncoeff = 50 +FILTERS = ["ps1::g", "ps1::r", "ps1::i", "ps1::z", "ps1::y", "2massj", "2massh", "2massks", "sdssu"] + +name = "Bu2019" +outdir = f"./model/" +file = "./model/Bu2019_raw_data.h5" + +config = NeuralnetConfig(output_size= svd_ncoeff, + 
nb_epochs=20_000, + hidden_layer_sizes = [128, 256, 128], + learning_rate =5e-3) + + +############### +### TRAINER ### +############### + + +data_manager_args = dict(file = file, + n_training= n_training, + n_val= n_val, + tmin= tmin, + tmax= tmax, + numin = numin, + numax = numax, + ) + +trainer = SVDTrainer(name, + outdir, + data_manager_args = data_manager_args, + plots_dir=f"./benchmarks/", + svd_ncoeff= svd_ncoeff, + filters= FILTERS, + save_preprocessed_data=False + ) + +############### +### FITTING ### +############### + + +trainer.fit(config=config) +trainer.save() + +############# +### TEST ### +############# + +print("Producing example lightcurve . . .") + +lc_model = BullaLightcurveModel(name, + outdir, + filters = FILTERS) + +for filt in lc_model.Filters: + with h5py.File(file, "r") as f: + X_example = f["val"]["X"][-2] + y_raw = f["val"]["y"][-2, trainer.data_manager.mask] + + y_raw = y_raw.reshape(len(trainer.nus), len(trainer.times)) + y_raw = np.exp(y_raw) + y_raw = filt.get_mag(y_raw, trainer.nus) + + # Turn into a dict: this is how the model expects the input + X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} + + # Get the prediction lightcurve + y_predict = lc_model.predict(X_example)[filt.name] + + plt.plot(lc_model.times, y_raw, color = "red", label="POSSIS") + plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") + upper_bound = y_predict + 1 + lower_bound = y_predict - 1 + plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) + + plt.xlabel(f"$t$ in days") + plt.ylabel(f"mag for {filt.name}") + plt.legend() + plt.gca().invert_yaxis() + plt.xscale('log') + plt.xlim(lc_model.times[0], lc_model.times[-1]) + + plt.savefig(f"./benchmarks/{name}_{filt.name}_example.png") + plt.close() \ No newline at end of file diff --git a/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl b/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl deleted file mode 100644 index 73a4949..0000000 Binary files a/lightcurve_models/GRB/afterglowpy/tophat/X-ray-1keV.pkl and /dev/null differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl b/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl deleted file mode 100644 index bc6b620..0000000 Binary files a/lightcurve_models/GRB/afterglowpy/tophat/bessellv.pkl and /dev/null differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl b/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl deleted file mode 100644 index ae9d474..0000000 Binary files a/lightcurve_models/GRB/afterglowpy/tophat/radio-3GHz.pkl and /dev/null differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl b/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl deleted file mode 100644 index 5d14206..0000000 Binary files a/lightcurve_models/GRB/afterglowpy/tophat/radio-6GHz.pkl and /dev/null differ diff --git a/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl b/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl deleted file mode 100644 index ca8e7dc..0000000 Binary files a/lightcurve_models/GRB/afterglowpy/tophat/tophat_metadata.pkl and /dev/null differ diff --git a/lightcurve_models/GRB/train_afterglowpy_tophat.py b/lightcurve_models/GRB/train_afterglowpy_tophat.py deleted file mode 100644 index 3ec2de3..0000000 --- a/lightcurve_models/GRB/train_afterglowpy_tophat.py +++ /dev/null @@ -1,149 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -from fiesta.train.SurrogateTrainer import AfterglowpyTrainer 
-from fiesta.train.Benchmarker import Benchmarker -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel -from fiesta.train.neuralnets import NeuralnetConfig -from fiesta.utils import Filter - -############# -### SETUP ### -############# - -FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"] -for filter in FILTERS: - filter = Filter(filter) - print(filter.name, filter.nu) - -tmin = 1 -tmax = 1000 - -""" -#grid for radio-6GHz and radio-3GHz -parameter_grid = { - 'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2], - 'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.], - 'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10], - 'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0], - 'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0], - 'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0], - 'log10_epsilon_B': [-8, -6, -4, -2., -1., 0] -} - -#grid for X-ray-1keV and bessellv -parameter_grid = { - 'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2], - 'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.], - 'thetaCore': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10], - 'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0], - 'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0], - 'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0], - 'log10_epsilon_B': [-8, -6, -4, -2., -1., 0] -} - - -""" - -FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"] -FILTERS = ["radio-6GHz"] -parameter_grid = { - 'inclination_EM': np.linspace(0, np.pi/4, 12), - 'log10_E0': np.linspace(47, 56, 19), - 'thetaCore': np.logspace(-2, np.log10(np.pi/5), 12), - 'log10_n0': np.linspace(-6, 2, 17), - 'p': np.linspace(2.01, 3.0, 10), - 'log10_epsilon_e': np.linspace(-4, 0, 9), - 'log10_epsilon_B': np.linspace(-8, 0, 9) -} - - -jet_name = "tophat" -jet_conversion = {"tophat": -1, - "gaussian": 0, - "powerlaw": 4} - -name = "tophat" -outdir = f"./afterglowpy/{name}/" - -############### -### TRAINER ### -############### - -B = Benchmarker(name = name, - parameter_grid = parameter_grid, - model_dir = outdir, - filters = FILTERS, - n_test_data = 2000, - metric_name = "$\\mathcal{L}_\infty$", - remake_test_data = False, - jet_type = jet_conversion[jet_name], - ) - -ww = B.error_distribution - -weight_grids = ww["radio-6GHz"] - -# TODO: perhaps also want to train on the full LC, without the SVD? -# TODO: train to output flux, not the mag? -trainer = AfterglowpyTrainer(name, - outdir, - FILTERS, - parameter_grid, - weight_grids = weight_grids, - jet_type = jet_conversion[jet_name], - tmin = tmin, - tmax = tmax, - use_log_spacing = True, - plots_dir=f"./benchmarks/{name}", - svd_ncoeff=30, - save_raw_data=True, - save_preprocessed_data=True, - remake_training_data = True, - n_training_data = 7000 - ) - -############### -### FITTING ### -############### - -config = NeuralnetConfig(output_size=trainer.svd_ncoeff, - nb_epochs=50_000, - hidden_layer_sizes = [64, 128, 64], - learning_rate = 8e-3) - -trainer.fit(config=config) -trainer.save() - -############# -### TEST ### -############# - -print("Producing example lightcurve . . 
.") - -lc_model = AfterglowpyLightcurvemodel(name, - outdir, - filters = FILTERS) - -for filt in lc_model.filters: - X_example = trainer.val_X_raw[0] - y_raw = trainer.val_y_raw[filt][0] - - # Turn into a dict: this is how the model expects the input - X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} - - # Get the prediction lightcurve - y_predict = lc_model.predict(X_example)[filt] - - plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") - plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") - upper_bound = y_predict + 1 - lower_bound = y_predict - 1 - plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) - - plt.ylabel(f"mag for {filt}") - plt.legend() - plt.gca().invert_yaxis() - - plt.savefig(f"./benchmarks/{name}/afterglowpy_{name}_{filt}_example.png") - plt.close() \ No newline at end of file diff --git a/lightcurve_models/KN/Bu2019lm/.gitignore b/lightcurve_models/KN/Bu2019lm/.gitignore deleted file mode 100644 index 5fe0038..0000000 --- a/lightcurve_models/KN/Bu2019lm/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -true_lcs.npz -raw_data.npz - diff --git a/lightcurve_models/KN/Bu2019lm/2massh.pkl b/lightcurve_models/KN/Bu2019lm/2massh.pkl deleted file mode 100644 index 63075e6..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/2massh.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/2massj.pkl b/lightcurve_models/KN/Bu2019lm/2massj.pkl deleted file mode 100644 index 6e4cc53..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/2massj.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/2massks.pkl b/lightcurve_models/KN/Bu2019lm/2massks.pkl deleted file mode 100644 index 8275919..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/2massks.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/Bu2019lm_metadata.pkl b/lightcurve_models/KN/Bu2019lm/Bu2019lm_metadata.pkl deleted file mode 100644 index fb9c74e..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/Bu2019lm_metadata.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/ps1__g.pkl b/lightcurve_models/KN/Bu2019lm/ps1__g.pkl deleted file mode 100644 index e51df7e..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/ps1__g.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/ps1__i.pkl b/lightcurve_models/KN/Bu2019lm/ps1__i.pkl deleted file mode 100644 index 7a8af85..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/ps1__i.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/ps1__r.pkl b/lightcurve_models/KN/Bu2019lm/ps1__r.pkl deleted file mode 100644 index 6a8f83d..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/ps1__r.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/ps1__y.pkl b/lightcurve_models/KN/Bu2019lm/ps1__y.pkl deleted file mode 100644 index de951b0..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/ps1__y.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/ps1__z.pkl b/lightcurve_models/KN/Bu2019lm/ps1__z.pkl deleted file mode 100644 index 493139a..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/ps1__z.pkl and /dev/null differ diff --git a/lightcurve_models/KN/Bu2019lm/sdssu.pkl b/lightcurve_models/KN/Bu2019lm/sdssu.pkl deleted file mode 100644 index 37defd7..0000000 Binary files a/lightcurve_models/KN/Bu2019lm/sdssu.pkl and /dev/null differ diff --git a/lightcurve_models/KN/train_Bu2019lm.py b/lightcurve_models/KN/train_Bu2019lm.py deleted file mode 100644 index 9e7d428..0000000 
--- a/lightcurve_models/KN/train_Bu2019lm.py +++ /dev/null @@ -1,103 +0,0 @@ -import os -import numpy as np -import matplotlib.pyplot as plt - -from fiesta.train.SurrogateTrainer import BullaSurrogateTrainer -from fiesta.inference.lightcurve_model import BullaLightcurveModel - -from fiesta.train.neuralnets import NeuralnetConfig - -print("Checking whether we found a GPU:") -import jax -print(jax.devices()) - -params = {"axes.grid": True, - "text.usetex" : True, - "font.family" : "serif", - "ytick.color" : "black", - "xtick.color" : "black", - "axes.labelcolor" : "black", - "axes.edgecolor" : "black", - "font.serif" : ["Computer Modern Serif"], - "xtick.labelsize": 16, - "ytick.labelsize": 16, - "axes.labelsize": 16, - "legend.fontsize": 16, - "legend.title_fontsize": 16, - "figure.titlesize": 16} - -plt.rcParams.update(params) - -# All filters that are in the files for this model: -FILTERS = ["ps1__g", "ps1__r", "ps1__i", "ps1__z", "ps1__y", "2massj", "2massh", "2massks", "sdssu"] - -# TODO: need to find a way to locate the files/help users -lc_dir = "/home/urash/twouters/projects/fiesta_dev/fiesta_test/lightcurves/Bu2019lm/lcs/" -name = "Bu2019lm" -outdir = f"./{name}/" -plots_dir = "./figures/" # to make plots - -if not os.path.exists(outdir): - os.makedirs(outdir) -if not os.path.exists(plots_dir): - os.makedirs(plots_dir) - -lc_files = os.listdir(lc_dir) -example_file = os.path.join(lc_dir, lc_files[0]) - -print("example_filename") -print(example_file) - -############### -### TRAINER ### -############### - -print("Defining trainer object, will take around 1 minute for loading and preprocessing") - -bulla_trainer = BullaSurrogateTrainer(name, - outdir, - filters = FILTERS, - data_dir=lc_dir, - tmin = 0.1, - tmax = 14.0, - dt = 0.1, - plots_dir = plots_dir, - save_raw_data = True, - save_preprocessed_data = True) - -print("Filters to train on:") -print(bulla_trainer.filters) - -full_lc_files = [os.path.join(lc_dir, f) for f in lc_files] -print("Example fetching parameters from filename:") -for filename in full_lc_files[:4]: - print("filename") - print(filename) - p = bulla_trainer.extract_parameters_function(filename) - print("Parameters extracted") - print(p) - -# Define the config if you want to change a default parameter -# Here we change the number of epochs to 10_000 -config = NeuralnetConfig(nb_epochs = 10_000, - output_size=bulla_trainer.svd_ncoeff) - -bulla_trainer.fit(config=config, verbose=True) -bulla_trainer.save() - -bulla_trainer._save_raw_data() -bulla_trainer._save_preprocessed_data() - -######################## -### LIGHTCURVE MODEL ### -######################## - -print("Producing example lightcurve . . 
.") - -lc_model = BullaLightcurveModel(name, - outdir, - FILTERS) - -times = bulla_trainer.times - -print("Training done!") \ No newline at end of file diff --git a/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl b/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl deleted file mode 100644 index 767b6f5..0000000 Binary files a/lightcurve_models/afterglowpy/gaussian/X-ray-1keV.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/gaussian/bessellv.pkl b/lightcurve_models/afterglowpy/gaussian/bessellv.pkl deleted file mode 100644 index 6e6b36a..0000000 Binary files a/lightcurve_models/afterglowpy/gaussian/bessellv.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl b/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl deleted file mode 100644 index 73e0bd1..0000000 Binary files a/lightcurve_models/afterglowpy/gaussian/gaussian_metadata.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl b/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl deleted file mode 100644 index 5e6f67b..0000000 Binary files a/lightcurve_models/afterglowpy/gaussian/radio-3GHz.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl b/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl deleted file mode 100644 index 6058076..0000000 Binary files a/lightcurve_models/afterglowpy/gaussian/radio-6GHz.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl b/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl deleted file mode 100644 index 73a4949..0000000 Binary files a/lightcurve_models/afterglowpy/tophat/X-ray-1keV.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/tophat/bessellv.pkl b/lightcurve_models/afterglowpy/tophat/bessellv.pkl deleted file mode 100644 index bc6b620..0000000 Binary files a/lightcurve_models/afterglowpy/tophat/bessellv.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl b/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl deleted file mode 100644 index ae9d474..0000000 Binary files a/lightcurve_models/afterglowpy/tophat/radio-3GHz.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl b/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl deleted file mode 100644 index 5d14206..0000000 Binary files a/lightcurve_models/afterglowpy/tophat/radio-6GHz.pkl and /dev/null differ diff --git a/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl b/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl deleted file mode 100644 index ca8e7dc..0000000 Binary files a/lightcurve_models/afterglowpy/tophat/tophat_metadata.pkl and /dev/null differ diff --git a/lightcurve_models/benchmark_afterglowpy_gaussian.py b/lightcurve_models/benchmark_afterglowpy_gaussian.py deleted file mode 100644 index 14b2a74..0000000 --- a/lightcurve_models/benchmark_afterglowpy_gaussian.py +++ /dev/null @@ -1,64 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -from fiesta.train.Benchmarker import Benchmarker -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel -from fiesta.utils import Filter - - -name = "gaussian" -model_dir = f"./afterglowpy/{name}/" -FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] - -parameter_grid = { - 'inclination_EM': np.linspace(0, np.pi/4, 12), - 'log10_E0': np.linspace(47, 56, 19), - 'thetaWing': np.linspace(0.01, np.pi/5, 12), - 'xCore': np.linspace(0.05, 1, 20), - 
'log10_n0': np.linspace(-6, 2, 17), - 'p': np.linspace(2.01, 3.0, 10), - 'log10_epsilon_e': np.linspace(-4, 0, 9), - 'log10_epsilon_B': np.linspace(-8, 0, 9) -} - -for metric_name in ["$\\mathcal{L}_2$", "$\\mathcal{L}_\infty$"]: - if metric_name == "$\\mathcal{L}_2$": - file_ending = "L2" - else: - file_ending = "Linf" - - B = Benchmarker(name = name, - parameter_grid = parameter_grid, - model_dir = model_dir, - MODEL = AfterglowpyLightcurvemodel, - filters = FILTERS, - n_test_data = 2000, - metric_name = metric_name, - remake_test_data = False, - jet_type = 0, - ) - - fig, ax = B.plot_error_distribution("radio-6GHz") - - - for filt in FILTERS: - - fig, ax = B.plot_lightcurves_mismatch(filter =filt, parameter_labels = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_{\\mathrm{w}}$", "$x_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"]) - fig.savefig(f"./benchmarks/{name}/benchmark_{filt}_{file_ending}.pdf", dpi = 200) - - B.print_correlations(filter = filt) - - - if metric_name == "$\\mathcal{L}_\infty$": - fig, ax = B.plot_error_distribution(filt) - fig.savefig(f"./benchmarks/{name}/error_distribution_{filt}.pdf", dpi = 200) - - - fig, ax = B.plot_worst_lightcurves() - fig.savefig(f"./benchmarks/{name}/worst_lightcurves_{file_ending}.pdf", dpi = 200) - - - - - - diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png deleted file mode 100644 index c21c817..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_X-ray-1keV_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png deleted file mode 100644 index 306d201..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_bessellv_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png deleted file mode 100644 index c9fc842..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-3GHz_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png b/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png deleted file mode 100644 index b5cd9b8..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/afterglowpy_gaussian_radio-6GHz_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf deleted file mode 100644 index 48a6150..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf deleted file mode 100644 index 03d5c63..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf deleted file mode 100644 index 3d29635..0000000 Binary files 
a/lightcurve_models/benchmarks/gaussian/benchmark_X-ray-1keV_Linf_before.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf deleted file mode 100644 index 85eba89..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf deleted file mode 100644 index 380e869..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_bessellv_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf deleted file mode 100644 index 51b7614..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf deleted file mode 100644 index 0b859a3..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf deleted file mode 100644 index 3723121..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-3GHz_Linf_before.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf deleted file mode 100644 index 1f15167..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf deleted file mode 100644 index 26e7d08..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf b/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf deleted file mode 100644 index 5be4ef7..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/benchmark_radio-6GHz_Linf_before.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/correlations_before.txt b/lightcurve_models/benchmarks/gaussian/correlations_before.txt deleted file mode 100644 index 9566b17..0000000 --- a/lightcurve_models/benchmarks/gaussian/correlations_before.txt +++ /dev/null @@ -1,108 +0,0 @@ -Correlations for filter radio-3GHz: - -inclination_EM: 0.13648909886212593 -log10_E0: 0.33645145631212375 -thetaWing: 0.024265589755934246 -xCore: -0.0998657462258901 -log10_n0: -0.18244632046614845 -p: -0.12178999832348425 -log10_epsilon_e: 0.028199807808143303 -log10_epsilon_B: 0.11382668130975754 - - - -Correlations for filter radio-6GHz: - -inclination_EM: 0.07900887361655033 -log10_E0: 0.20617615787512814 -thetaWing: 0.0061923518520167874 -xCore: -0.1609517797089711 -log10_n0: -0.17330437332816612 -p: 0.06921641433003754 -log10_epsilon_e: 0.0010812683989245774 -log10_epsilon_B: 0.03164077022449176 - - - -Correlations for filter bessellv: - -inclination_EM: 0.10876851455212108 -log10_E0: 0.31350345293188253 -thetaWing: 
0.06548400328241566 -xCore: -0.12897435382322575 -log10_n0: -0.24090642768022807 -p: -0.15163486371234677 -log10_epsilon_e: 0.051617546936042646 -log10_epsilon_B: 0.06200398370473675 - - - -Correlations for filter X-ray-1keV: - -inclination_EM: 0.0937755880248725 -log10_E0: 0.20646061403424484 -thetaWing: 0.029171682247324363 -xCore: -0.13630793956135479 -log10_n0: -0.11379719745935056 -p: -0.19303802440344964 -log10_epsilon_e: 0.04816653484776429 -log10_epsilon_B: 0.031864402035493226 - - - - - -Loaded SurrogateLightcurveModel with filters ['radio-3GHz', 'radio-6GHz', 'bessellv', 'X-ray-1keV'] - - - -Correlations for filter radio-3GHz: - -inclination_EM: 0.22729615522584115 -log10_E0: 0.2854912376107409 -thetaWing: -0.0655575192504771 -xCore: -0.1511753578064876 -log10_n0: -0.14326622677000614 -p: 0.003718908215404851 -log10_epsilon_e: 0.09145017016387194 -log10_epsilon_B: 0.13900251196367636 - - - -Correlations for filter radio-6GHz: - -inclination_EM: 0.15482259690303216 -log10_E0: 0.31706796584780933 -thetaWing: -0.04603719649371961 -xCore: -0.20588033447258053 -log10_n0: -0.19114656242441438 -p: 0.0195090896910469 -log10_epsilon_e: 0.08416371935468658 -log10_epsilon_B: 0.10182239506588713 - - - -Correlations for filter bessellv: - -inclination_EM: 0.131150012020771 -log10_E0: 0.26739763810111317 -thetaWing: -0.03328134154282784 -xCore: -0.20478505987062182 -log10_n0: -0.2029206212123128 -p: -0.053992018841870645 -log10_epsilon_e: 0.014003969294619653 -log10_epsilon_B: 0.0470155844120538 - - - -Correlations for filter X-ray-1keV: - -inclination_EM: 0.19483874480498534 -log10_E0: 0.23606543897286883 -thetaWing: -0.04549081935294193 -xCore: -0.2464464697288298 -log10_n0: -0.1599685189157919 -p: -0.10176554058762216 -log10_epsilon_e: 0.008516304543409538 -log10_epsilon_B: 0.05139817839737394 - diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf deleted file mode 100644 index 4f4b5f7..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/error_distribution_X-ray-1keV.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf deleted file mode 100644 index 95fd2d6..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/error_distribution_bessellv.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf deleted file mode 100644 index 5c373b0..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-3GHz.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf b/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf deleted file mode 100644 index dd7d55b..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/error_distribution_radio-6GHz.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png b/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png deleted file mode 100644 index a15ecd3..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/learning_curves_X-ray-1keV.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png b/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png deleted file mode 
100644 index 0c216c4..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/learning_curves_bessellv.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png deleted file mode 100644 index 861db19..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-3GHz.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png b/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png deleted file mode 100644 index c9a5622..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/learning_curves_radio-6GHz.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf deleted file mode 100644 index 2974fc1..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf b/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf deleted file mode 100644 index 0cee2b7..0000000 Binary files a/lightcurve_models/benchmarks/gaussian/worst_lightcurves_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png deleted file mode 100644 index d7ee02f..0000000 Binary files a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_X-ray-1keV_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png deleted file mode 100644 index 3d72a42..0000000 Binary files a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_bessellv_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png deleted file mode 100644 index 15cb698..0000000 Binary files a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-3GHz_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png b/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png deleted file mode 100644 index 443366e..0000000 Binary files a/lightcurve_models/benchmarks/tophat/afterglowpy_tophat_radio-6GHz_example.png and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf deleted file mode 100644 index b73ccc7..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf deleted file mode 100644 index bcbfffb..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_X-ray-1keV_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf deleted file mode 100644 index 6b44e71..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_L2.pdf and /dev/null differ diff --git 
a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf deleted file mode 100644 index 6101d70..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf b/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf deleted file mode 100644 index 8835d68..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_bessellv_Linf_before.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf deleted file mode 100644 index 65edfbf..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf deleted file mode 100644 index 35f3c78..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_radio-3GHz_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf deleted file mode 100644 index 4feced3..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf deleted file mode 100644 index 3a25d99..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf b/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf deleted file mode 100644 index a2caf84..0000000 Binary files a/lightcurve_models/benchmarks/tophat/benchmark_radio-6GHz_Linf_before.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf b/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf deleted file mode 100644 index 50b896f..0000000 Binary files a/lightcurve_models/benchmarks/tophat/worst_lightcurves_L2.pdf and /dev/null differ diff --git a/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf b/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf deleted file mode 100644 index 9a07ac8..0000000 Binary files a/lightcurve_models/benchmarks/tophat/worst_lightcurves_Linf.pdf and /dev/null differ diff --git a/lightcurve_models/train_afterglowpy_gaussian.py b/lightcurve_models/train_afterglowpy_gaussian.py deleted file mode 100644 index 4d7c984..0000000 --- a/lightcurve_models/train_afterglowpy_gaussian.py +++ /dev/null @@ -1,146 +0,0 @@ -import numpy as np -import matplotlib.pyplot as plt - -from fiesta.train.SurrogateTrainer import AfterglowpyTrainer -from fiesta.train.Benchmarker import Benchmarker -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel -from fiesta.train.neuralnets import NeuralnetConfig -from fiesta.utils import Filter - -############# -### SETUP ### -############# - -FILTERS = ["X-ray-1keV", "radio-6GHz", "radio-3GHz", "bessellv"] -for filter in FILTERS: - filter = Filter(filter) - print(filter.name, filter.nu) - -tmin = 1 -tmax = 1000 -""" -FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] -parameter_grid = { - 
'inclination_EM': [0.0, np.pi/24, np.pi/12, np.pi/8, np.pi/6, np.pi*5/24, np.pi/4, np.pi/3, 5*np.pi/12, 1.4, np.pi/2], - 'log10_E0': [46.0, 46.5, 48, 50, 51, 52., 53, 53.5, 54., 54.5, 55.], - 'thetaWing': [0.01, 0.02, 0.03, 0.04, 0.05, 0.06, 0.1, 0.2, 0.3, np.pi/10], - 'xCore': [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0], - 'log10_n0': [-7.0, -6.5, -6.0, -5.0, -4.0, -3.0, -1.0, 1.0], - 'p': [2.01, 2.1, 2.2, 2.4, 2.6, 2.8, 2.9, 3.0], - 'log10_epsilon_e': [-4, -3.5, -3, -2, -1, -0.66, -0.33, 0], - 'log10_epsilon_B': [-8, -6, -4, -2., -1., 0] -} -""" - -FILTERS = ["radio-3GHz", "radio-6GHz", "bessellv", "X-ray-1keV"] -parameter_grid = { - 'inclination_EM': np.linspace(0, np.pi/4, 12), - 'log10_E0': np.linspace(47, 56, 19), - 'thetaWing': np.logspace(-2, np.log10(np.pi/5), 12), - 'xCore': np.linspace(0.05, 1, 20), - 'log10_n0': np.linspace(-6, 2, 17), - 'p': np.linspace(2.01, 3.0, 10), - 'log10_epsilon_e': np.linspace(-4, 0, 9), - 'log10_epsilon_B': np.linspace(-8, 0, 9) -} - - - - - -jet_name = "gaussian" -jet_conversion = {"tophat": -1, - "gaussian": 0, - "powerlaw": 4} - -name = "gaussian" -outdir = f"./afterglowpy/{name}/" - -############### -### TRAINER ### -############### - - -# Benchmarker to get weights for the training data - -B = Benchmarker(name = name, - parameter_grid = parameter_grid, - model_dir = outdir, - filters = FILTERS, - n_test_data = 2000, - metric_name = "$\\mathcal{L}_\infty$", - remake_test_data = False, - jet_type = jet_conversion[jet_name], - ) - -ww = B.error_distribution - - -weight_grids = ww["X-ray-1keV"] - -#for p in parameter_grid.keys(): -# weight_grids[p] = np.average([ww[filt][p] for filt in FILTERS], axis = 0) - -# TODO: perhaps also want to train on the full LC, without the SVD? -# TODO: train to output flux, not the mag? -trainer = AfterglowpyTrainer(name, - outdir, - FILTERS, - parameter_grid, - weight_grids = weight_grids, - jet_type = jet_conversion[jet_name], - tmin = tmin, - tmax = tmax, - use_log_spacing = True, - plots_dir=f"./benchmarks/{name}", - svd_ncoeff=30, - save_raw_data=False, - save_preprocessed_data=False, - remake_training_data = False, - n_training_data = 10_000 - ) - -############### -### FITTING ### -############### - -config = NeuralnetConfig(output_size=trainer.svd_ncoeff, - nb_epochs=50_000, - hidden_layer_sizes = [64, 128, 64], - learning_rate =8e-3) - -trainer.fit(config=config) -trainer.save() - -############# -### TEST ### -############# - -print("Producing example lightcurve . . 
.") - -lc_model = AfterglowpyLightcurvemodel(name, - outdir, - filters = FILTERS) - -for filt in lc_model.filters: - X_example = trainer.val_X_raw[0] - y_raw = trainer.val_y_raw[filt][0] - - # Turn into a dict: this is how the model expects the input - X_example = {k: v for k, v in zip(lc_model.parameter_names, X_example)} - - # Get the prediction lightcurve - y_predict = lc_model.predict(X_example)[filt] - - plt.plot(lc_model.times, y_raw, color = "red", label="afterglowpy") - plt.plot(lc_model.times, y_predict, color = "blue", label="Surrogate prediction") - upper_bound = y_predict + 1 - lower_bound = y_predict - 1 - plt.fill_between(lc_model.times, lower_bound, upper_bound, color='blue', alpha=0.2) - - plt.ylabel(f"mag for {filt}") - plt.legend() - plt.gca().invert_yaxis() - - plt.savefig(f"./benchmarks/{name}/afterglowpy_{name}_{filt}_example.png") - plt.close() \ No newline at end of file diff --git a/src/.gitignore b/src/.gitignore deleted file mode 100644 index 4edf511..0000000 --- a/src/.gitignore +++ /dev/null @@ -1 +0,0 @@ -fiesta.egg-info diff --git a/src/fiesta.egg-info/PKG-INFO b/src/fiesta.egg-info/PKG-INFO deleted file mode 100644 index da026b0..0000000 --- a/src/fiesta.egg-info/PKG-INFO +++ /dev/null @@ -1,58 +0,0 @@ -Metadata-Version: 2.1 -Name: fiesta -Version: 0.0.1 -Summary: Fast inference of electromagnetic signals with JAX -Home-page: https://github.com/thibeauwouters/fiesta -Author: Thibeau Wouters -Author-email: thibeauwouters@gmail.com -License: MIT -Keywords: sampling,inference,astrophysics,kilonovae,gamma-ray bursts -Requires-Python: >=3.10 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: jax>=0.4.24 -Requires-Dist: jaxlib>=0.4.24 -Requires-Dist: numpy<2.0.0 -Requires-Dist: pandas<2.0.0 -Requires-Dist: jaxtyping -Requires-Dist: beartype -Requires-Dist: tqdm -Requires-Dist: scipy<=1.14.0 -Requires-Dist: ml_collections -Requires-Dist: astropy -Requires-Dist: sncosmo -Requires-Dist: flowMC -Requires-Dist: joblib - -# fiesta 🎉 - -`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x - -![fiesta logo](docs/fiesta_logo.jpeg) - -**NOTE:** `fiesta` is currently under development -- stay tuned! - -## Installation - -pip installation is currently work in progress. Install from source by cloning this Github repository and running -``` -pip install -e . -``` - -NOTE: This is using an older and custom version of `flowMC`. Install by cloning the `flowMC` version at [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`). - -## Training surrogate models - -To train your own surrogate models, have a look at some of the example scripts in the repository for inspiration, under `trained_models` - -- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model. -- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure. - -## Examples - -- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model. -- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon. - -## Acknowledgements - -The logo was created by [ideogram AI](https://ideogram.ai/). 
diff --git a/src/fiesta.egg-info/SOURCES.txt b/src/fiesta.egg-info/SOURCES.txt deleted file mode 100644 index ce22b23..0000000 --- a/src/fiesta.egg-info/SOURCES.txt +++ /dev/null @@ -1,9 +0,0 @@ -LICENSE -README.md -pyproject.toml -setup.cfg -src/fiesta.egg-info/PKG-INFO -src/fiesta.egg-info/SOURCES.txt -src/fiesta.egg-info/dependency_links.txt -src/fiesta.egg-info/requires.txt -src/fiesta.egg-info/top_level.txt \ No newline at end of file diff --git a/src/fiesta.egg-info/dependency_links.txt b/src/fiesta.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/fiesta.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/fiesta.egg-info/requires.txt b/src/fiesta.egg-info/requires.txt deleted file mode 100644 index 283afba..0000000 --- a/src/fiesta.egg-info/requires.txt +++ /dev/null @@ -1,13 +0,0 @@ -jax>=0.4.24 -jaxlib>=0.4.24 -numpy<2.0.0 -pandas<2.0.0 -jaxtyping -beartype -tqdm -scipy<=1.14.0 -ml_collections -astropy -sncosmo -flowMC -joblib diff --git a/src/fiesta.egg-info/top_level.txt b/src/fiesta.egg-info/top_level.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/fiesta.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/fiesta/__init__.py b/src/fiesta/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fiesta/constants.py b/src/fiesta/constants.py index 304256e..e575db9 100644 --- a/src/fiesta/constants.py +++ b/src/fiesta/constants.py @@ -2,4 +2,7 @@ c = 299792458.0 # speed of light in vacuum, m/s pc_to_cm = 3.086e18 # parsec to cm -days_to_seconds = 86400.0 # days to seconds \ No newline at end of file +days_to_seconds = 86400.0 # days to seconds +h = 6.62607015e-34 # J s +h_erg_s = 6.6261e-27 # cm^2 g s^{-1} (i.e. erg s) +eV = 1.602176634e-19 # J \ No newline at end of file diff --git a/src/fiesta/conversions.py b/src/fiesta/conversions.py index 230ce09..275db20 100644 --- a/src/fiesta/conversions.py +++ b/src/fiesta/conversions.py @@ -1,23 +1,158 @@ -from fiesta.constants import pc_to_cm import jax import jax.numpy as jnp from jaxtyping import Array, Float import numpy as np +from fiesta.constants import pc_to_cm, h_erg_s, c + + +####################### +# DISTANCE CONVERSION # +####################### + def Mpc_to_cm(d: float): return d * 1e6 * pc_to_cm -# TODO: need a np and jnp version? -# TODO: account for extinction -def mJys_to_mag_np(mJys: np.array): - Jys = 1e-3 * mJys - mag = -48.6 + -1 * np.log10(Jys / 1e23) * 2.5 +################### +# FLUX CONVERSION # +################### + +def Flambda_to_Fnu(F_lambda: Float[Array, "n_lambdas n_times"], lambdas: Float[Array, "n_lambdas"]) -> Float[Array, "n_lambdas n_times"]: + """ + JAX-compatible conversion of wavelength flux in erg cm^{-2} s^{-1} Angström^{-1} to spectral flux density in mJys. + + Args: + F_lambda (Float[Array]): 2D flux density array in erg cm^{-2} s^{-1} Angström^{-1}. The rows correspond to the wavelengths provided in lambdas. + lambdas (Float[Array]): 1D wavelength array in Angström.
+ Returns: + mJys (Float[Array]): 2D spectral flux density array in mJys + nus (Float[Array]): 1D frequency array in Hz + """ + F_lambda = F_lambda.reshape(lambdas.shape[0], -1) + log_F_lambda = jnp.log10(F_lambda) # go to log because of large factors + log_F_nu = log_F_lambda + 2* jnp.log10(lambdas[:, None]) + jnp.log10(3.3356) + 4 # https://en.wikipedia.org/wiki/AB_magnitude + F_nu = 10**(log_F_nu) + F_nu = F_nu[::-1, :] # reverse the order to get lowest frequencies in first row + mJys = 1e3 * F_nu # convert Jys to mJys + + nus = c / (lambdas*1e-10) + nus = nus[::-1] + + return mJys, nus + +def Fnu_to_Flambda(F_nu: Float[Array, "n_nus n_times"], nus: Float[Array, "n_nus"]) -> Float[Array, "n_nus n_times"]: + """ + JAX-compatible conversion of spectral flux density in mJys to wavelength flux in erg cm^{-2} s^{-1} Angström^{-1}. + + Args: + F_nu (Float[Array]): 2D flux density array in mJys. The rows correspond to the frequencies provided in nus. + nus (Float[Array]): 1D frequency array in Hz. + Returns: + F_lambda (Float[Array]): 2D wavelength flux density array in erg cm^{-2} s^{-1} Angström^{-1}. + lambdas (Float[Array]): 1D wavelength array in Angström. + """ + F_nu = F_nu.reshape(nus.shape[0], -1) + log_F_nu = jnp.log10(F_nu) # go to log because of large factors + log_F_nu = log_F_nu - 3 # convert mJys to Jys + log_F_lambda = log_F_nu + 2 * jnp.log10(nus[:, None]) + jnp.log10(3.3356) - 42 + F_lambda = 10**(log_F_lambda) + F_lambda = F_lambda[::-1, :] # reverse the order to get the lowest wavelengths in first row + + lambdas = c / nus + lambdas = lambdas[::-1] * 1e10 + + return F_lambda, lambdas + +def apply_redshift(F_nu: Float[Array, "n_nus n_times"], times: Float[Array, "n_times"], nus: Float[Array, "n_nus"], z: Float): + + F_nu = F_nu * (1 + z) # this is just the frequency redshift, cosmological energy loss and time elongation are taken into account by luminosity_distance + times = times * (1 + z) + nus = nus / (1 + z) + + return F_nu, times, nus + +######################## +# MAGNITUDE CONVERSION # +######################## + +def monochromatic_AB_mag(flux: Float[Array, "n_nus n_times"], + nus: Float[Array, "n_nus"], + nus_filt: Float[Array, "n_nus_filt"], + trans_filt: Float[Array, "n_nus_filt"], + ref_flux: Float) -> Float[Array, "n_times"]: + + interp_col = lambda col: jnp.interp(nus_filt, nus, col) + mJys = jax.vmap(interp_col, in_axes = 1, out_axes = 1)(flux) # apply vectorized interpolation to interpolate columns of 2D array + + mJys = mJys * trans_filt[:, None] + mag = mJys_to_mag_jnp(mJys) + return mag[0] + +def bandpass_AB_mag(flux: Float[Array, "n_nus n_times"], + nus: Float[Array, "n_nus"], + nus_filt: Float[Array, "n_nus_filt"], + trans_filt: Float[Array, "n_nus_filt"], + ref_flux: Float) -> Float[Array, "n_times"]: + """ + This is a JAX-compatible equivalent of sncosmo.TimeSeriesSource.bandmag(). Unlike sncosmo, we use the frequency flux and not wavelength flux, + but this function is tested to yield the same results as the sncosmo version. + + Args: + flux (Float[Array, "n_nus n_times"]): Spectral flux density as a 2D array in mJys.
+ nus (Float[Array, "n_nus"]): Associated frequencies in Hz + nus_filt (Float[Array, "n_nus_filt"]): frequency array of the filter in Hz + trans_filt (Float[Array, "n_nus_filt"]): transmissivity array of the filter in transmitted photons / incoming photons + ref_flux (Float): flux in mJy for which the filter is 0 mag + """ + + interp_col = lambda col: jnp.interp(nus_filt, nus, col) + mJys = jax.vmap(interp_col, in_axes = 1, out_axes = 1)(flux) # apply vectorized interpolation to interpolate columns of 2D array + + log_mJys = jnp.log10(mJys) # go to log because of large factors + log_mJys = log_mJys + jnp.log10(trans_filt[:, None]) + log_mJys = log_mJys - jnp.log10(h_erg_s) - jnp.log10(nus_filt[:, None]) # https://en.wikipedia.org/wiki/AB_magnitude + + max_log_mJys = jnp.max(log_mJys) + integrand = 10**(log_mJys - max_log_mJys) # make the integrand between 0 and 1, otherwise infs could appear + integrate_col = lambda col: jnp.trapezoid(y = col, x = nus_filt) + norm_band_flux = jax.vmap(integrate_col, in_axes = 1)(integrand) # normalized band flux + + log_integrated_flux = jnp.log10(norm_band_flux) + max_log_mJys # reintroduce scale here + mag = -2.5 * log_integrated_flux + 2.5 * jnp.log10(ref_flux) + return mag + +def integrated_AB_mag(flux: Float[Array, "n_nus n_times"], + nus: Float[Array, "n_nus"], + nus_filt: Float[Array, "n_nus_filt"], + trans_filt: Float[Array, "n_nus_filt"]) -> Float[Array, "n_times"]: + + interp_col = lambda col: jnp.interp(nus_filt, nus, col) + mJys = jax.vmap(interp_col, in_axes = 1, out_axes = 1)(flux) # apply vectorized interpolation to interpolate columns of 2D array + + log_mJys = jnp.log10(mJys) # go to log because of large factors + log_mJys = log_mJys + jnp.log10(trans_filt[:, None]) + + max_log_mJys = jnp.max(log_mJys) + integrand = 10**(log_mJys - max_log_mJys) # make the integrand between 0 and 1, otherwise infs could appear + integrate_col = lambda col: jnp.trapezoid(y = col, x = nus_filt) + norm_band_flux = jax.vmap(integrate_col, in_axes = 1)(integrand) # normalized band flux + + log_integrated_flux = jnp.log10(norm_band_flux) + max_log_mJys # reintroduce scale here + log_integrated_flux = log_integrated_flux - jnp.log10(nus_filt[-1] - nus_filt[0]) # divide by integration range + mJys = 10**log_integrated_flux + mag = mJys_to_mag_jnp(mJys) return mag @jax.jit def mJys_to_mag_jnp(mJys: Array): + mag = -48.6 + -1 * jnp.log10(mJys) * 2.5 + 26 * 2.5 # https://en.wikipedia.org/wiki/AB_magnitude + return mag + +# TODO: need a np and jnp version? 
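As a sanity check on the conversion helpers added above (a sketch, assuming fiesta is installed so the functions are importable from fiesta.conversions): the AB zero point of 3631 Jy, i.e. 3.631e6 mJy, should map to magnitude 0 under mJys_to_mag_jnp, and Fnu_to_Flambda should invert Flambda_to_Fnu:

```python
import jax.numpy as jnp
from fiesta.conversions import mJys_to_mag_jnp, Flambda_to_Fnu, Fnu_to_Flambda

# AB zero point: 3631 Jy = 3.631e6 mJy corresponds to magnitude 0.
print(mJys_to_mag_jnp(jnp.array([3.631e6])))  # ~0.0

# Round trip F_lambda -> F_nu -> F_lambda should be the identity.
lambdas = jnp.linspace(3000.0, 9000.0, 8)    # Angström
F_lambda = 1e-15 * jnp.ones((8, 5))          # erg cm^-2 s^-1 Angström^-1
mJys, nus = Flambda_to_Fnu(F_lambda, lambdas)
F_lambda_back, lambdas_back = Fnu_to_Flambda(mJys, nus)
print(jnp.allclose(F_lambda_back, F_lambda, rtol=1e-4),
      jnp.allclose(lambdas_back, lambdas, rtol=1e-6))
```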
+# TODO: account for extinction +def mJys_to_mag_np(mJys: np.array): Jys = 1e-3 * mJys - mag = -48.6 + -1 * jnp.log10(Jys / 1e23) * 2.5 + mag = -48.6 + -1 * np.log10(Jys / 1e23) * 2.5 return mag def mag_app_from_mag_abs(mag_abs: Array, diff --git a/src/fiesta/inference/__init__.py b/src/fiesta/inference/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fiesta/inference/fiesta.py b/src/fiesta/inference/fiesta.py index e55b1c0..630cf5a 100644 --- a/src/fiesta/inference/fiesta.py +++ b/src/fiesta/inference/fiesta.py @@ -189,6 +189,27 @@ def get_samples(self, training: bool = False) -> dict: chains = self.prior.transform(self.prior.add_name(chains.transpose(2, 0, 1))) return chains + def save_results(self, outdir): + # - training phase + name = os.path.join(outdir, f'results_training.npz') + print(f"Saving training samples to {name}") + state = self.Sampler.get_sampler_state(training=True) + chains, log_prob, local_accs, global_accs, loss_vals = state["chains"], state["log_prob"], state["local_accs"], state["global_accs"], state["loss_vals"] + local_accs = jnp.mean(local_accs, axis=0) + global_accs = jnp.mean(global_accs, axis=0) + jnp.savez(name, log_prob=log_prob, local_accs=local_accs, + global_accs=global_accs, loss_vals=loss_vals) + + # - production phase + name = os.path.join(outdir, f'results_production.npz') + print(f"Saving production samples to {name}") + state = self.Sampler.get_sampler_state(training=False) + chains, log_prob, local_accs, global_accs = state["chains"], state["log_prob"], state["local_accs"], state["global_accs"] + local_accs = jnp.mean(local_accs, axis=0) + global_accs = jnp.mean(global_accs, axis=0) + jnp.savez(name, chains=chains, log_prob=log_prob, + local_accs=local_accs, global_accs=global_accs) + def save_hyperparameters(self, outdir): # Convert step_size to list for JSON formatting @@ -257,29 +278,25 @@ def plot_lightcurves(self, zorder = 2 # Predict and convert to apparent magnitudes - mag_bestfit = self.likelihood.model.predict(best_fit_params_named) - d = best_fit_params_named["luminosity_distance"] - for filt in filters: - mag_bestfit[filt] = mag_app_from_mag_abs(mag_bestfit[filt], d) + time_obs, mag_bestfit = self.likelihood.model.predict(best_fit_params_named) + for i, filter_name in enumerate(filters): ax = plt.subplot(len(filters), 1, i + 1) mag = mag_bestfit[filter_name] - t = self.likelihood.model.times - mask = (t >= tmin) & (t <= tmax) - ax.plot(t[mask], mag[mask], color = "blue", label = "Best fit", zorder = zorder) + mask = (time_obs >= tmin) & (time_obs <= tmax) + ax.plot(time_obs[mask], mag[mask], color = "blue", label = "Best fit", zorder = zorder) # Other samples zorder = 1 for sample in samples.T: sample_named = self.prior.add_name(sample) sample_named.update(self.likelihood.fixed_params) - mag = self.likelihood.model.predict(sample_named) - d = sample_named["luminosity_distance"] - for filt in filters: - mag[filt] = mag_app_from_mag_abs(mag[filt], d) + time_obs, mag = self.likelihood.model.predict(sample_named) + mask = (time_obs >= tmin) & (time_obs <= tmax) + for i, filter_name in enumerate(filters): ax = plt.subplot(len(filters), 1, i + 1) - ax.plot(self.likelihood.model.times[mask], mag[filter_name][mask], color = "gray", alpha = 0.05, zorder = zorder) + ax.plot(time_obs[mask], mag[filter_name][mask], color = "gray", alpha = 0.05, zorder = zorder) ### Make pretty for i, filter_name in enumerate(filters): diff --git a/src/fiesta/inference/injection.py b/src/fiesta/inference/injection.py index 7b34aaf..338634e 
100644 --- a/src/fiesta/inference/injection.py +++ b/src/fiesta/inference/injection.py @@ -2,19 +2,16 @@ # TODO: for now, we will only support creating injections from a given model import argparse -import copy -import numpy as np -import jax -import jax.numpy as jnp + +import h5py from jaxtyping import Float, Array +import numpy as np from fiesta.inference.lightcurve_model import LightcurveModel -from fiesta.conversions import mag_app_from_mag_abs +from fiesta.conversions import mag_app_from_mag_abs, apply_redshift from fiesta.utils import Filter -from fiesta.constants import days_to_seconds, c -from fiesta import conversions -from fiesta.train.AfterglowData import RunAfterglowpy +from fiesta.train.AfterglowData import RunAfterglowpy, RunPyblastafterglow # TODO: get the parser going def get_parser(**kwargs): @@ -24,124 +21,216 @@ def get_parser(**kwargs): description="Inference on kilonova and GRB parameters.", add_help=add_help, ) - -class InjectionRecovery: - - def __init__(self, - model: LightcurveModel, - injection_dict: dict[str, Float], - filters: list[str] = None, + +class InjectionBase: + + def __init__(self, + filters: list[str], + trigger_time: float, tmin: Float = 0.1, - tmax: Float = 14.0, + tmax: Float = 1000.0, N_datapoints: int = 10, + t_detect: dict[str, Array] = None, error_budget: Float = 1.0, - randomize_nondetections: bool = False, - randomize_nondetections_fraction: Float = 0.2): + nondetections: bool = False, + nondetections_fraction: Float = 0.2): - self.model = model - # Ensure given filters are also in the trained model - if filters is None: - filters = model.filters - else: - for filt in filters: - if filt not in model.filters: - print(f"Filter {filt} not in model filters. Removing from list") - filters.remove(filt) - + self.Filters = [Filter(filt) for filt in filters] print(f"Creating injection with filters: {filters}") - self.filters = filters - self.injection_dict = injection_dict - self.tmin = tmin - self.tmax = tmax - self.N_datapoints = N_datapoints + self.trigger_time = trigger_time + + if t_detect is not None: + self.t_detect = t_detect + else: + self.create_t_detect(tmin, tmax, N_datapoints) + self.error_budget = error_budget - self.randomize_nondetections = randomize_nondetections - self.randomize_nondetections_fraction = randomize_nondetections_fraction + self.nondetections = nondetections + self.nondetections_fraction = nondetections_fraction + + def create_t_detect(self, tmin, tmax, N): + """Create a time grid for the injection data.""" + + self.t_detect = {} + points_list = np.random.multinomial(N, [1/len(self.Filters)]*len(self.Filters)) # random number of time points in each filter + + for points, Filt in zip(points_list, self.Filters): + t = np.exp(np.random.uniform(np.log(tmin), np.log(tmax), size = points)) + t = np.sort(t) + t[::2] *= np.random.uniform(1, (tmax/tmin)**(1/points), size = len(t[::2])) # correlate the time points + t[::3] *= np.random.uniform(1, (tmax/tmin)**(1/points), size = len(t[::3])) # correlate the time points + t = np.minimum(t, tmax) + self.t_detect[Filt.name] = np.sort(t) + + def create_injection(self, + injection_dict: dict[str, Float]): + raise NotImplementedError + + def randomize_nondetections(self,): + if not self.nondetections: + return + + N = np.sum([len(self.t_detect[Filt.name]) for Filt in self.Filters]) + nondets_list = np.random.multinomial(int(N*self.nondetections_fraction), [1/len(self.Filters)]*len(self.Filters)) # random number of non detections in each filter + + for nondets, Filt in zip(nondets_list, 
self.Filters): + inds = np.random.choice(np.arange(len(self.data[Filt.name])), size=nondets, replace=False) + self.data[Filt.name][inds] += np.array([0, -5., np.inf]) + + + + + +class InjectionSurrogate(InjectionBase): + + def __init__(self, + model: LightcurveModel, + *args, + **kwargs): - def create_injection(self): + self.model = model + super().__init__(*args, **kwargs) + + def create_injection(self, injection_dict): """Create a synthetic injection from the given model and parameters.""" + + injection_dict["luminosity_distance"] = injection_dict.get('luminosity_distance', 1e-5) + injection_dict["redshift"] = injection_dict.get('redshift', 0) + times, mags = self.model.predict(injection_dict) self.data = {} - all_mag_abs = self.model.predict(self.injection_dict) - - for filt in self.filters: - times = self.create_timegrid() - all_mag_app = mag_app_from_mag_abs(all_mag_abs[filt], self.injection_dict["luminosity_distance"]) - mag_app = np.interp(times, self.model.times, all_mag_app) - mag_err = self.error_budget * np.ones_like(times) - - # Randomize to get some non-detections if so desired: - if self.randomize_nondetections: - n_nondetections = int(self.randomize_nondetections_fraction * len(times)) - nondet_indices = np.random.choice(len(times), size = n_nondetections, replace = False) - - mag_app[nondet_indices] -= 5.0 # randomly bump down the magnitude - mag_err[nondet_indices] = np.inf + + for Filt in self.Filters: + t_detect = self.t_detect[Filt.name] + + mag_app = np.interp(t_detect, times, mags[Filt.name]) + + mag_err = self.error_budget * np.sqrt(np.random.chisquare(df=1, size = len(t_detect))) + mag_err = np.maximum(mag_err, 0.01) + mag_err = np.minimum(mag_err, 1) - array = np.array([times, mag_app, mag_err]).T - self.data[filt] = array + array = np.array([t_detect, mag_app, mag_err]).T + self.data[Filt.name] = array + + self.randomize_nondetections() + +class InjectionAfterglowpy(InjectionBase): - def create_timegrid(self): - """Create a time grid for the injection.""" + def __init__(self, + jet_type: int = -1, + *args, + **kwargs): + + self.jet_type = jet_type + super().__init__(*args, **kwargs) - # TODO: create more interesting grids than uniform and same accross all filters? 
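The refactored injection API above separates the detection-time grid (roughly log-uniform per filter, drawn in create_t_detect) from the model evaluation, and attaches chi-distributed magnitude errors clipped to [0.01, 1]. A short usage sketch for InjectionSurrogate follows; the surrogate directory, trigger time, and parameter values are placeholders for illustration only:

```python
import numpy as np
from fiesta.inference.lightcurve_model import BullaLightcurveModel
from fiesta.inference.injection import InjectionSurrogate

model = BullaLightcurveModel("Bu2019", "./model/", filters=["ps1::g", "2massj"])
injection = InjectionSurrogate(model,
                               filters=["ps1::g", "2massj"],
                               trigger_time=60000.0,  # placeholder
                               tmin=0.5, tmax=14.0,
                               N_datapoints=30,
                               error_budget=1.0)

# Illustrative values; the true names and valid ranges come from the model metadata.
params = dict(zip(model.parameter_names, np.full(len(model.parameter_names), 0.1)))
params["luminosity_distance"] = 40.0  # Mpc
injection.create_injection(params)
for filt, arr in injection.data.items():
    print(filt, arr.shape)  # rows of (detection time, apparent mag, mag error)
```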
- return np.linspace(self.tmin, self.tmax, self.N_datapoints) + def create_injection(self, injection_dict): + """Create a synthetic injection from the given model and parameters.""" + + nus = [nu for Filter in self.Filters for nu in Filter.nus] + times = [t for Filter in self.Filters for t in self.t_detect[Filter.name]] + + nus = np.sort(nus) + times = np.sort(times) + + afgpy = RunAfterglowpy(self.jet_type, times, nus, [list(injection_dict.values())], injection_dict.keys()) + _, log_flux = afgpy(0) + mJys = np.exp(log_flux).reshape(len(nus), len(times)) + + self.data = {} + + for Filter in self.Filters: + t_detect = self.t_detect[Filter.name] + mag_abs = Filter.get_mag(mJys, nus) # even when 'luminosity_distance' is passed to RunAfterglowpy, it will return the abs mag (with redshift) + mag_app = mag_app_from_mag_abs(mag_abs, injection_dict["luminosity_distance"]) + mag_app = np.interp(t_detect, times, mag_app) + mag_err = self.error_budget * np.sqrt(np.random.chisquare(df=1, size = len(t_detect))) + mag_err = np.maximum(mag_err, 0.01) + mag_err = np.minimum(mag_err, 1) -class InjectionRecoveryAfterglowpy: + self.data[Filter.name] = np.array([t_detect + self.trigger_time, mag_app, mag_err]).T + + self.randomize_nondetections() + +class InjectionPyblastafterglow(InjectionBase): def __init__(self, - injection_dict: dict[str, Float], - trigger_time: Float, - filters: list[str], - jet_type = -1, - tmin: Float = 0.1, - tmax: Float = 1000.0, - N_datapoints: int = 10, - error_budget: Float = 1.0, - randomize_nondetections: bool = False, - randomize_nondetections_fraction: Float = 0.2): + jet_type: str = "tophat", + *args, + **kwargs): self.jet_type = jet_type - # Ensure given filters are also in the trained model + super().__init__(*args, **kwargs) - if filters is None: - filters = model.filters + def create_injection(self, injection_dict): """Create a synthetic injection from the given model and parameters.""" - self.filters = [Filter(filt) for filt in filters] - print(f"Creating injection with filters: {filters}") - self.injection_dict = injection_dict - self.trigger_time = trigger_time - self.tmin = tmin - self.tmax = tmax - self.N_datapoints = N_datapoints - self.error_budget = error_budget - self.randomize_nondetections = randomize_nondetections - self.randomize_nondetections_fraction = randomize_nondetections_fraction + nus = [nu for Filter in self.Filters for nu in Filter.nus] + times = [t for Filter in self.Filters for t in self.t_detect[Filter.name]] + + nus = np.sort(nus) + times = np.sort(times) + nus = np.logspace(np.log10(nus[0]), np.log10(nus[-1]), len(nus)) #pbag only takes log (or linear) spaced arrays + times = np.logspace(np.log10(times[0]), np.log10(times[-1]), len(times)) + + pbag = RunPyblastafterglow(self.jet_type, times, nus, [list(injection_dict.values())], injection_dict.keys()) + _, log_flux = pbag(0) + mJys = np.exp(log_flux).reshape(len(nus), len(times)) + + self.data = {} + + for Filter in self.Filters: + t_detect = self.t_detect[Filter.name] + + mag_abs = Filter.get_mag(mJys, nus) + mag_app = mag_app_from_mag_abs(mag_abs, injection_dict["luminosity_distance"]) + mag_app = np.interp(t_detect, times, mag_app) + + mag_err = self.error_budget * np.sqrt(np.random.chisquare(df=1, size = len(t_detect))) + mag_err = np.maximum(mag_err, 0.01) + mag_err = np.minimum(mag_err, 1) + + self.data[Filter.name] = np.array([t_detect + self.trigger_time, mag_app, mag_err]).T - def create_injection(self): - """Create a synthetic injection from the given model and parameters."""
self.randomize_nondetections() + + def create_injection_from_file(self, file, injection_dict): + with h5py.File(file) as f: + times = f["times"][:] + nus = f["nus"][:] + parameter_names = f["parameter_names"][:].astype(str).tolist() + test_X_raw = f["test"]["X"][:] + + X = np.array([injection_dict[p] for p in parameter_names]) + + ind = np.argmin(np.sum( ( (test_X_raw - X)/(np.max(test_X_raw, axis=0) - np.min(test_X_raw, axis=0)) )**2, axis=1)) + X = test_X_raw[ind] + + log_flux = f["test"]["y"][ind] - nus = [filt.nu for filt in self.filters] - times = np.logspace(np.log10(self.tmin), np.log10(self.tmax), 200) - afgpy = RunAfterglowpy(self.jet_type, times, nus, [list(self.injection_dict.values())], self.injection_dict.keys()) - _, log_flux = afgpy(0) - mJys = np.exp(log_flux).reshape(len(nus), 200) + print(f"Found suitable injection with {dict(zip(parameter_names, X))}") + mJys = np.exp(log_flux).reshape(len(nus), len(times)) + mJys, times, nus = apply_redshift(mJys, times, nus, injection_dict.get("redshift", 0.0)) self.data = {} - points = np.random.multinomial(self.N_datapoints, [1/len(self.filters)]*len(self.filters)) # random number of datapoints in each filter - for j, npoints, filt in zip(range(len(self.filters)), points, self.filters): - times_data = self.create_timegrid(npoints) - mJys_filter = np.interp(times_data, times, mJys[j]) - magnitudes = conversions.mJys_to_mag_np(mJys_filter) - magnitudes = magnitudes + 5 * np.log10(self.injection_dict["luminosity_distance"]/(10*1e-6)) - - mag_err = self.error_budget * np.ones_like(times_data) - self.data[filt.name] = np.array([times_data + self.trigger_time, magnitudes, mag_err]).T - - def create_timegrid(self, npoints): - """Create a time grid for the injection.""" - return np.linspace(self.tmin, self.tmax, npoints) \ No newline at end of file + for Filter in self.Filters: + t_detect = self.t_detect[Filter.name] + + mag_abs = Filter.get_mag(mJys, nus) + mag_app = mag_app_from_mag_abs(mag_abs, injection_dict["luminosity_distance"]) + mag_app = np.interp(t_detect, times, mag_app) + + mag_err = self.error_budget * np.sqrt(np.random.chisquare(df=1, size = len(t_detect))) + mag_err = np.maximum(mag_err, 0.01) + mag_err = np.minimum(mag_err, 1) + + self.data[Filter.name] = np.array([t_detect + self.trigger_time, mag_app, mag_err]).T + + self.randomize_nondetections() + return dict(zip(parameter_names, X)) + + diff --git a/src/fiesta/inference/lightcurve_model.py b/src/fiesta/inference/lightcurve_model.py index 3956c35..5f81aa5 100644 --- a/src/fiesta/inference/lightcurve_model.py +++ b/src/fiesta/inference/lightcurve_model.py @@ -7,36 +7,68 @@ import jax.numpy as jnp from jaxtyping import Array, Float from functools import partial -from beartype import beartype as typechecker from flax.training.train_state import TrainState import pickle import fiesta.train.neuralnets as fiesta_nn -from fiesta.utils import MinMaxScalerJax, inverse_svd_transform -from fiesta import models_utilities +from fiesta.conversions import mag_app_from_mag_abs, apply_redshift from fiesta import utils + ######################## ### ABSTRACT CLASSES ### ######################## -class LightcurveModel: - """Abstract class for general light curve models""" +class SurrogateModel: + """Abstract class for general surrogate models""" - name: str + name: str + directory: str filters: list[str] parameter_names: list[str] times: Array def __init__(self, - name: str) -> None: + name: str, + directory: str) -> None: self.name = name + self.directory = directory + + 
self.load_metadata() + self.filters = [] - self.parameter_names = [] - self.times = jnp.array([]) def add_name(self, x: Array): - return dict(zip(self.parameter_names, x)) + return dict(zip(self.parameter_names, x)) + + def load_metadata(self) -> None: + print(f"Loading metadata for model {self.name}.") + self.metadata_filename = os.path.join(self.directory, f"{self.name}_metadata.pkl") + assert os.path.exists(self.metadata_filename), f"Metadata file {self.metadata_filename} not found - check the directory {self.directory}" + + # open the file + with open(self.metadata_filename, "rb") as meta_file: + self.metadata = pickle.load(meta_file) + + # make the scaler objects attributes + self.X_scaler = self.metadata["X_scaler"] + self.y_scaler = self.metadata["y_scaler"] + + # load parameter names + self.parameter_names = self.metadata["parameter_names"] + print(f"This surrogate {self.name} should only be used in the following parameter ranges:") + from ast import literal_eval + parameter_distributions = literal_eval(self.metadata["parameter_distributions"]) + for key in parameter_distributions.keys(): + print(f"\t {key}: {parameter_distributions[key][:2]}") + + #load times + self.times = self.metadata["times"] + + #load nus + if "nus" in self.metadata.keys(): + self.nus = self.metadata["nus"] + def project_input(self, x: Array) -> dict[str, Array]: """ @@ -77,10 +109,13 @@ def project_output(self, y: dict[str, Array]) -> dict[str, Array]: """ return y - @partial(jax.jit, static_argnums=(0,)) # TODO: the jit here can come into conflict with scikit-learn methods used in project_output, maybe only jit self.comput_output - def predict(self, x: dict[str, Array]) -> dict[str, Array]: + def convert_to_mag(self, y: Array, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]: + raise NotImplementedError + + @partial(jax.jit, static_argnums=(0,)) + def predict(self, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]: """ - Generate the lightcurve y from the unnormalized and untransformed input x. + Generate the apparent magnitudes from the unnormalized and untransformed input x. Chains the projections with the actual computation of the output. E.g. if the model is a trained surrogate neural network, they represent the map from x tilde to y tilde. The mappings from x to x tilde and y to y tilde take care of projections (e.g. SVD projections) and normalizations. @@ -89,37 +124,60 @@ def predict(self, x: dict[str, Array]) -> dict[str, Array]: x (dict[str, Array]): Input array, unnormalized and untransformed. Returns: - Array: Output dict[str, Array], i.e., the desired raw light curve per filter + times + mag (dict[str, Array]): The desired magnitudes per filter """ # Use saved parameter names to extract the parameters in the correct order into an array x_array = jnp.array([x[name] for name in self.parameter_names]) + + # apply the NN x_tilde = self.project_input(x_array) y_tilde = self.compute_output(x_tilde) y = self.project_output(y_tilde) - return y + + # convert the NN output to apparent magnitude + times, mag = self.convert_to_mag(y, x) + + return times, mag + + def predict_abs_mag(self, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]: + x["luminosity_distance"] = 1e-5 + x["redshift"] = 0. + + return self.predict(x) + + def vpredict(self, X: dict[str, Array]) -> tuple[Array, dict[str, Array]]: + """ + Vectorized prediction function to calculate the apparent magnitudes for several inputs x at the same time. 
+ """ + + X_array = jnp.array([X[name] for name in X.keys()]).T + + def predict_single(x): + param_dict = {key: x[j] for j, key in enumerate(X.keys())} + return self.predict(param_dict) + + times, mag_apps = jax.vmap(predict_single)(X_array) + + return times[0], mag_apps def __repr__(self) -> str: return self.name -class SurrogateLightcurveModel(LightcurveModel): - """Abstract class for models that rely on a surrogate, in the form of a neural network.""" +class LightcurveModel(SurrogateModel): + """Class of surrogate models that predicts the magnitudes per filter.""" directory: str metadata: dict - X_scaler: MinMaxScalerJax - y_scaler: dict[str, MinMaxScalerJax] + X_scaler: object + y_scaler: dict[str, object] models: dict[str, TrainState] - times: Array - tmin: Float - tmax: Float - parameter_names: list[str] def __init__(self, name: str, directory: str, - filters: list[str] = None, - times: Array = None) -> None: + filters: list[str] = None) -> None: """_summary_ Args: @@ -127,71 +185,38 @@ def __init__(self, directory (str): Directory with trained model states and projection metadata such as scalers. filters (list[str]): List of all the filters for which the model should be loaded. """ - super().__init__(name) - self.directory = directory - self.models = {} + super().__init__(name, directory) - # Load the metadata for projections etc - self.load_metadata() + # Load the filters and networks self.load_filters(filters) - self.load_scalers() - self.load_times(times) - self.load_parameter_names() self.load_networks() - def load_metadata(self) -> None: - self.metadata_filename = os.path.join(self.directory, f"{self.name}_metadata.pkl") - assert os.path.exists(self.metadata_filename), f"Metadata file {self.metadata_filename} not found - check the directory {self.directory}" - meta_file = open(self.metadata_filename, "rb") - self.metadata = pickle.load(meta_file) - meta_file.close() - - def load_filters(self, filters: list[str] = None) -> None: + def load_filters(self, filters_args: list[str] = None) -> None: # Save those filters that were given and that were trained and store here already pkl_files = [file for file in os.listdir(self.directory) if file.endswith(".pkl") or file.endswith(".pickle")] - all_available_filters = [file.split(".")[0] for file in pkl_files] + all_available_filters = [(file.split(".")[0]).split("_")[1] for file in pkl_files] - if filters is None: + if filters_args is None: # Use all filters that the surrogate model supports filters = all_available_filters else: # Fetch those filters specified by the user that are available - filters = [f.replace(":", "_") for f in filters] - filters = [f for f in filters if f in all_available_filters] + filters = [f for f in filters_args if f in all_available_filters] if len(filters) == 0: - raise ValueError(f"No filters found in {self.directory} that match the given filters {filters}") + raise ValueError(f"No filters found in {self.directory} that match the given filters {filters_args}.") self.filters = filters - print(f"Loaded SurrogateLightcurveModel with filters {filters}") - - def load_scalers(self): - self.X_scaler, self.y_scaler = {}, {} - for filt in self.filters: - self.X_scaler[filt] = MinMaxScalerJax(min_val=self.metadata[filt]["X_scaler_min"], max_val=self.metadata[filt]["X_scaler_max"]) - self.y_scaler[filt] = self.metadata[filt]["y_scaler"] - - - def load_times(self, times: Array = None) -> None: - if times is None: - times = jnp.array(self.metadata["times"]) - if times.min()self.metadata["times"].max(): - times = 
jnp.array(self.metadata["times"]) - self.times = times - self.tmin = jnp.min(times) - self.tmax = jnp.max(times) + self.Filters = [utils.Filter(filt) for filt in self.filters] + print(f"Loaded SurrogateLightcurveModel with filters {self.filters}.") def load_networks(self) -> None: self.models = {} for filter in self.filters: - filename = os.path.join(self.directory, f"{filter}.pkl") - state, _ = fiesta_nn.load_model(filename) + filename = os.path.join(self.directory, f"{self.name}_{filter}.pkl") + state, _ = fiesta_nn.MLP.load_model(filename) self.models[filter] = state - - def load_parameter_names(self) -> None: - """Implement in child classes""" - raise NotImplementedError - def project_input(self, x: Array) -> dict[str, Array]: + def project_input(self, x: Array) -> Array: """ Project the given input to whatever preprocessed input space we are in. @@ -201,10 +226,10 @@ def project_input(self, x: Array) -> dict[str, Array]: Returns: dict[str, Array]: Transformed input array """ - x_tilde = {filter: self.X_scaler[filter].transform(x) for filter in self.filters} + x_tilde = self.X_scaler.transform(x) return x_tilde - def compute_output(self, x: dict[str, Array]) -> dict[str, Array]: + def compute_output(self, x: Array) -> Array: """ Apply the trained flax neural network on the given input x. @@ -214,8 +239,13 @@ def compute_output(self, x: dict[str, Array]) -> dict[str, Array]: Returns: dict[str, Array]: _description_ """ - # TODO: too convoluted, simplify - return {filter: self.models[filter].apply_fn({'params': self.models[filter].params}, x[filter]) for filter in self.filters} + def apply_model(filter): + model = self.models[filter] + output = model.apply_fn({'params': model.params}, x) + return output + + y = jax.tree.map(apply_model, self.filters) # avoid for loop with jax.tree.map + return dict(zip(self.filters, y)) def project_output(self, y: dict[str, Array]) -> dict[str, Array]: """ @@ -227,77 +257,33 @@ def project_output(self, y: dict[str, Array]) -> dict[str, Array]: Returns: dict[str, Array]: Output array transformed to the preprocessed space. """ - return {filter: self.y_scaler[filter].inverse_transform(y[filter]) for filter in self.filters} - -class SVDSurrogateLightcurveModel(SurrogateLightcurveModel): - - VA: dict[str, Array] - svd_ncoeff: int - - def __init__(self, - name: str, - directory: str, - filters: list[str] = None, - times: Array = None): - """ - Initialize a class to generate lightcurves from a Bulla trained model. - - """ - super().__init__(name=name, directory=directory, times=times, filters=filters) + def inverse_transform(filter): + y_scaler = self.y_scaler[filter] + output = y_scaler.inverse_transform(y[filter]) + return output - self.VA = {filt: self.metadata[filt]["VA"] for filt in filters} - self.svd_ncoeff = {filt: self.metadata[filt]["svd_ncoeff"] for filt in filters} - - def load_parameter_names(self) -> None: - raise NotImplementedError - - def project_output(self, y: dict[str, Array]) -> dict[str, Array]: - """ - Apply the trained flax neural network on the given input x. 
+ y = jax.tree.map(inverse_transform, self.filters) # avoid for loop with jax.tree.map + return jnp.array(y) + + def convert_to_mag(self, y: Array, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]: + mag_abs = y + mag_app = mag_app_from_mag_abs(mag_abs, x["luminosity_distance"]) + return self.times, dict(zip(self.filters, mag_app)) - Args: - x (dict[str, Array]): Input array of parameters +class FluxModel(SurrogateModel): + """Class of surrogate models that predicts the 2D spectral flux density array.""" - Returns: - dict[str, Array]: _description_ - """ - output = {filter: inverse_svd_transform(y[filter], self.VA[filter], self.svd_ncoeff[filter]) for filter in self.filters} - return super().project_output(output) - -class BullaLightcurveModel(SVDSurrogateLightcurveModel): - - def __init__(self, - name: str, - directory: str, - filters: list[str] = None, - times: Array = None): - - super().__init__(name=name, directory=directory, filters=filters, times=times) - - def load_parameter_names(self) -> None: - self.parameter_names = models_utilities.BULLA_PARAMETER_NAMES[self.name] - -class AfterglowpyLightcurvemodel(SVDSurrogateLightcurveModel): - def __init__(self, name: str, directory: str, - filters: list[str] = None, - times: Array = None): - super().__init__(name=name, directory=directory, filters=filters, times=times) - - def load_parameter_names(self) -> None: - self.parameter_names = self.metadata["parameter_names"] - -class PCALightcurveModel(SurrogateLightcurveModel): + filters: list[str] = None, + model_type: str = "MLP"): + self.model_type = model_type # TODO: make this switch nicer somehow maybe + super().__init__(name, directory) - def __init__(self, - name: str, - directory: str, - filters: list[str] = None, - times: Array = None): - - super().__init__(name = name, directory= directory, filters = filters, times = times) + # Load the filters and networks + self.load_filters(filters) + self.load_networks() def load_filters(self, filters: list[str] = None) -> None: self.nus = self.metadata['nus'] @@ -316,18 +302,20 @@ def load_filters(self, filters: list[str] = None) -> None: raise ValueError(f"No filters found that match the trained frequency range {self.nus[0]:.3e} Hz to {self.nus[-1]:.3e} Hz.") print(f"Loaded SurrogateLightcurveModel with filters {self.filters}.") - - def load_scalers(self): - self.X_scaler = self.metadata["X_scaler"] - self.y_scaler = self.metadata["y_scaler"] - #self.pca = self.metadata["pca"] def load_networks(self) -> None: filename = os.path.join(self.directory, f"{self.name}.pkl") - state, _ = fiesta_nn.load_model(filename) + if self.model_type == "MLP": + state, _ = fiesta_nn.MLP.load_model(filename) + latent_dim = 0 + elif self.model_type == "CVAE": + state, _ = fiesta_nn.CVAE.load_model(filename) + latent_dim = state.params["layers_0"]["kernel"].shape[0] - len(self.parameter_names) + else: + raise ValueError(f"Model type must be either 'MLP' or 'CVAE'.") + self.latent_vector = jnp.array(jnp.zeros(latent_dim)) # TODO: how to get latent vector? self.models = state - def project_input(self, x: Array) -> Array: """ Project the given input to whatever preprocessed input space we are in. 
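Taken together, the hunks above refactor the surrogates so that predict() chains project_input -> compute_output -> project_output -> convert_to_mag and returns a (times, magnitudes-per-filter) tuple. A hedged usage sketch of the refactored API follows; the model name, directory, filter, and parameter values are placeholders rather than part of this patch, and the distance unit is assumed to be Mpc:

from fiesta.inference.lightcurve_model import AfterglowFlux

# hypothetical trained surrogate; point name/directory/filters at an actual model
model = AfterglowFlux("gaussian", directory="./model/", filters=["X-ray-1keV"], model_type="MLP")

x = {name: 0.1 for name in model.parameter_names}  # dummy values; keep them inside the trained ranges
x["luminosity_distance"] = 40.0                    # assumed unit: Mpc (1e-5 Mpc = 10 pc)
x["redshift"] = 0.0

times, mags = model.predict(x)              # apparent magnitudes per filter
times, abs_mags = model.predict_abs_mag(x)  # absolute magnitudes (d_L fixed to 10 pc)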
@@ -351,6 +339,7 @@ def compute_output(self, x: Array) -> Array:
 
         Returns:
             dict[str, Array]: _description_
         """
+        x = jnp.concatenate((self.latent_vector, x))
         output = self.models.apply_fn({'params': self.models.params}, x)
         return output
 
@@ -364,25 +353,25 @@ def project_output(self, y: Array) -> dict[str, Array]:
 
         Returns:
             dict[str, Array]: Output array transformed to the preprocessed space.
         """
-        #y = self.pca.inverse_transform(y)
         y = self.y_scaler.inverse_transform(y)
-
-        y = jnp.reshape(y, (len(self.metadata["nus"]), len(self.times)))
-        y = jnp.exp(y)
+        y = jnp.reshape(y, (len(self.nus), len(self.times)))
 
-        def compute_mag_single_filter(nu):
-            # TODO: get a check here that the filt.nu is in range of the meta data
-            lambda_interp = lambda column: jnp.interp(nu, self.metadata["nus"], column)
-            mJys = jax.vmap(lambda_interp)(y.T)
-            mag = -48.6 + -1 * jnp.log10(mJys) * 2.5 + -1 * (-26) * 2.5
-            return mag
+        return y
+
+    def convert_to_mag(self, y: Array, x: dict[str, Array]) -> tuple[Array, dict[str, Array]]:
+
+        mJys = jnp.exp(y)
+
+        mJys_obs, times_obs, nus_obs = apply_redshift(mJys, self.times, self.nus, x["redshift"])
+        # TODO: Add EBL table here at some point
+
+        mag_abs = jax.tree.map(lambda Filter: Filter.get_mag(mJys_obs, nus_obs),
+                               self.Filters)
+        mag_abs = jnp.array(mag_abs)
 
-        filter_nus = jnp.array([filt.nu for filt in self.Filters])
-        output_array = jax.vmap(compute_mag_single_filter)(filter_nus)
-        output = dict(zip(self.filters, output_array))
+        mag_app = mag_app_from_mag_abs(mag_abs, x["luminosity_distance"])
 
-        return output
-
+        return times_obs, dict(zip(self.filters, mag_app))
 
     def predict_log_flux(self, x: Array) -> Array:
         """
@@ -394,21 +383,34 @@ def predict_log_flux(self, x: Array) -> Array:
 
         Returns:
             log_flux [Array]: Array of log-fluxes.
         """
         x_tilde = self.X_scaler.transform(x)
+        x_tilde = jnp.concatenate((self.latent_vector, x_tilde))
 
         y = self.models.apply_fn({'params': self.models.params}, x_tilde)
 
         logflux = self.y_scaler.inverse_transform(y)
+        logflux = logflux.reshape(len(self.nus), len(self.times))
         return logflux
 
-class AfterglowpyPCA(PCALightcurveModel):
+
+#################
+# MODEL CLASSES #
+#################
+
+class BullaLightcurveModel(LightcurveModel):
+
+    def __init__(self,
+                 name: str,
+                 directory: str,
+                 filters: list[str] = None):
+
+        super().__init__(name=name, directory=directory, filters=filters)
+
+class AfterglowFlux(FluxModel):
 
     def __init__(self,
                  name: str,
                  directory: str,
                  filters: list[str] = None,
-                 times: Array = None):
-        super().__init__(name=name, directory=directory, filters=filters, times=times)
-
-    def load_parameter_names(self) -> None:
-        self.parameter_names = self.metadata["parameter_names"]
\ No newline at end of file
+                 model_type: str = "MLP"):
+        super().__init__(name=name, directory=directory, filters=filters, model_type=model_type)
+
\ No newline at end of file
diff --git a/src/fiesta/inference/likelihood.py b/src/fiesta/inference/likelihood.py
index 8e7caf8..dba1464 100644
--- a/src/fiesta/inference/likelihood.py
+++ b/src/fiesta/inference/likelihood.py
@@ -8,7 +8,6 @@
 from fiesta.inference.lightcurve_model import LightcurveModel
 from fiesta.utils import truncated_gaussian
-from fiesta.conversions import mag_app_from_mag_abs
 
 class EMLikelihood:
 
@@ -79,7 +78,6 @@ def __init__(self,
         print("Loading and preprocessing observations in likelihood . .
.") processed_data = copy.deepcopy(data) - processed_data = {k.replace(":", "_"): v for k, v in processed_data.items()} for filt in self.filters: if filt not in processed_data: @@ -135,15 +133,13 @@ def evaluate(self, """ theta = {**theta, **self.fixed_params} - mag_abs: dict[str, Array] = self.model.predict(theta) - mag_app = jax.tree_util.tree_map(lambda x: mag_app_from_mag_abs(x, theta["luminosity_distance"]), - mag_abs) + times, mag_app = self.model.predict(theta) # Interpolate the mags to the times of interest - mag_est_det = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m), + mag_est_det = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, times, m, left = "extrapolate", right = "extrapolate"), # TODO extrapolation is maybe problematic here self.times_det, mag_app) - mag_est_nondet = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, self.model.times, m), + mag_est_nondet = jax.tree_util.tree_map(lambda t, m: jnp.interp(t, times, m, left = "extrapolate", right = "extrapolate"), self.times_nondet, mag_app) # Get chisq diff --git a/src/fiesta/train/AfterglowData.py b/src/fiesta/train/AfterglowData.py index 80b1c6d..1fb1d64 100644 --- a/src/fiesta/train/AfterglowData.py +++ b/src/fiesta/train/AfterglowData.py @@ -110,10 +110,19 @@ def _read_file(self,): self.times = f["times"][:] self.nus = f["nus"][:] self.parameter_names = f["parameter_names"][:].astype(str).tolist() - self.n_training_exists = (f["train"]["X"].shape)[0] - self.n_val_exists = (f["val"]["X"].shape)[0] - self.n_test_exists = (f["test"]["X"].shape)[0] self.parameter_distributions = ast.literal_eval(f["parameter_distributions"][()].decode('utf-8')) + try: + self.n_training_exists = (f["train"]["X"].shape)[0] + except KeyError: + self.n_training_exists = 0 + try: + self.n_val_exists = (f["val"]["X"].shape)[0] + except KeyError: + self.n_val_exists = 0 + try: + self.n_test_exists = (f["test"]["X"].shape)[0] + except KeyError: + self.n_test_exists = 0 def create_raw_data(self, n: int, training: bool = True): """ @@ -174,12 +183,12 @@ def create_special_data(self, X_raw, label:str, comment: str = None): X, y = self.fix_nans(X,y) self._save_to_file(X, y, "special_train", label = label, comment= comment) - def run_afterglow_model(X): + def run_afterglow_model(self, X): raise NotImplementedError def _save_to_file(self, X, y, group: str, label: str = None, comment: str = None): with h5py.File(self.outfile, "a") as f: - if "y" in f[group]: + if "y" in f[group]: # checks if the dataset already exists Xset = f[group]["X"] Xset.resize(Xset.shape[0]+X.shape[0], axis = 0) Xset[-X.shape[0]:] = X @@ -188,7 +197,7 @@ def _save_to_file(self, X, y, group: str, label: str = None, comment: str = None yset.resize(yset.shape[0]+y.shape[0], axis = 0) yset[-y.shape[0]:] = y - elif label is not None: # when we have special training data + elif label is not None: # or if we have special training data if label in f["special_train"]: Xset = f["special_train"][label]["X"] Xset.resize(Xset.shape[0]+X.shape[0], axis = 0) @@ -205,7 +214,7 @@ def _save_to_file(self, X, y, group: str, label: str = None, comment: str = None f["special_train"][label].create_dataset("X", data = X, maxshape=(None, len(self.parameter_names)), chunks = (self.chunk_size, len(self.parameter_names))) f["special_train"][label].create_dataset("y", data = y, maxshape=(None, len(self.times)*len(self.nus)), chunks = (self.chunk_size, len(self.times)*len(self.nus))) - else: + else: # or if we need to create a new data set f[group].create_dataset("X", data = X, 
maxshape=(None, len(self.parameter_names)), chunks = (self.chunk_size, len(self.parameter_names)))
                f[group].create_dataset("y", data = y, maxshape=(None, len(self.times)*len(self.nus)), chunks = (self.chunk_size, len(self.times)*len(self.nus)))
@@ -236,34 +245,74 @@ def run_afterglow_model(self, X):
 
 class PyblastafterglowData(AfterglowData):
-    def __init__(self, path_to_exec: str, rank: int = 0, grb_resolution: int = 12, *args, **kwargs):
+    def __init__(self, path_to_exec: str, pbag_kwargs: dict = None, rank: int = 0, *args, **kwargs):
         self.outfile = f"pyblastafterglow_raw_data_{rank}.h5"
         self.chunk_size = 10
         self.rank = rank
         self.path_to_exec = path_to_exec
-        self.grb_resolution = grb_resolution
+        self.pbag_kwargs = pbag_kwargs if pbag_kwargs is not None else {} # an empty dict can be unpacked below, None cannot
         super().__init__(*args, **kwargs)
 
     def run_afterglow_model(self, X):
         """Should be run in parallel with different mpi processes to run pyblastafterglow on the parameters in the array X."""
         y = np.empty((len(X), len(self.times)*len(self.nus)))
-        pbag = RunPyblastafterglow(self.jet_type, self.times, self.nus, X, self.parameter_names, self.fixed_parameters, rank=self.rank, path_to_exec = self.path_to_exec, grb_resolution = self.grb_resolution)
-        for j in tqdm.tqdm(range(len(X)), desc = f"Computing {len(X)} pyblastafterglow calculations.", leave = False):
+
+        pbag = RunPyblastafterglow(self.jet_type,
+                                   self.times,
+                                   self.nus,
+                                   X,
+                                   self.parameter_names,
+                                   self.fixed_parameters,
+                                   rank=self.rank,
+                                   path_to_exec=self.path_to_exec,
+                                   **self.pbag_kwargs)
+
+        for j in range(len(X)):
             try:
                 idx, out = pbag(j)
                 y[idx] = out
             except:
                 try:
-                    pbag.n_tb = 3000 # increase blast wave evolution time grid if there is an error
+                    # increase blast wave evolution time grid if there is an error
+                    pbag.ntb = 3000
                     idx, out = pbag(j)
                     y[idx] = out
-                    pbag.n_tb = 1000
+                    pbag.ntb = self.pbag_kwargs.get("ntb", 1000) # restore the default grid size
                 except:
-                    y[j] = np.full(len(self.times)*len(self.nus), np.nan)
-
+                    y[j] = np.full(len(self.times)*len(self.nus), np.nan)
         return X, y
+
+    def supplement_time(self, t_supp):
+        self.times = t_supp
+
+        for group in ["train", "val", "test"]:
+            with h5py.File(self.outfile) as f:
+                if "y" not in f[group].keys():
+                    continue
+                if f[group]["y"].shape[1]>f["times"].shape[0] * f["nus"].shape[0]:
+                    continue
+                X = f[group]["X"][:]
+
+            _, y_new = self.run_afterglow_model(X)
+            y_new = y_new.reshape(-1, len(self.nus), len(self.times))
+
+            with h5py.File(self.outfile, "r+") as f:
+                y_old = f[group]["y"][:]
+                y_old = y_old.reshape(-1, f["nus"].shape[0], f["times"].shape[0])
+                y = np.concatenate((y_new, y_old), axis=-1)
+
+                new_time_shape = len(self.times) + f["times"].shape[0]
+                y = y.reshape(-1, new_time_shape * len(self.nus))
+                del f[group]["y"]
+                f[group].create_dataset("y", data=y, maxshape=(None, new_time_shape*len(self.nus)), chunks = (self.chunk_size, new_time_shape*len(self.nus)) )
+
+        with h5py.File(self.outfile,"r+") as f:
+            t_old = f["times"][:]
+            del f["times"]
+            time = np.concatenate((t_supp, t_old))
+            f.create_dataset("times", data=time)
 
 class RunAfterglowpy:
     def __init__(self, jet_type, times, nus, X, parameter_names, fixed_parameters = {}):
@@ -290,8 +339,9 @@ def _call_afterglowpy(self,
 
         Z["jetType"]  = params_dict.get("jetType", self.jet_type)
         Z["specType"] = params_dict.get("specType", 0)
-        Z["z"] = params_dict.get("z", 0.0)
+        Z["z"] = params_dict.get("redshift", 0.0)
         Z["xi_N"] = params_dict.get("xi_N", 1.0)
+        Z["counterjet"] = True
 
         Z["E0"] = 10 ** params_dict["log10_E0"]
         Z["n0"] = 10 ** params_dict["log10_n0"]
@@ -334,11 +384,26 @@ def __call__(self, idx):
 
 class RunPyblastafterglow:
-    def 
__init__(self, jet_type, times, nus, X, parameter_names, fixed_parameters = {}, rank = 0, path_to_exec = "./pba.out", grb_resolution = 12):
-        self.jet_type = jet_type
+    def __init__(self,
+                 jet_type: int,
+                 times,
+                 nus,
+                 X,
+                 parameter_names,
+                 fixed_parameters={},
+                 rank = 0,
+                 path_to_exec: str="./pba.out",
+                 grb_resolution: int=12,
+                 ntb: int=1000,
+                 tb0: float=1e1,
+                 tb1: float=1e11,
+                 rtol: float=1e-1,
+                 loglevel: str="err",
+                 ):
+
         jet_conversion = {"-1": "tophat", "0": "gaussian"}
-        self.jet_type = jet_conversion[str(self.jet_type)]
+        self.jet_type = jet_conversion.get(str(jet_type), str(jet_type)) # accept the integer code or the pyblastafterglow jet string directly
         times_seconds = times * days_to_seconds # pyblastafterglow takes seconds as input
 
         # preparing the pyblastafterglow string argument for time array
@@ -360,7 +425,11 @@ def __init__(self, jet_type, times, nus, X, parameter_names, fixed_parameters =
         self.rank = rank
         self.path_to_exec = path_to_exec
         self.grb_resolution = grb_resolution
-        self.n_tb = 1000 # set default blast wave evolution timegrid to 1000
+        self.ntb = ntb
+        self.tb0 = tb0
+        self.tb1 = tb1
+        self.rtol = rtol
+        self.loglevel = loglevel
 
     def _call_pyblastafterglow(self,
                                params_dict: dict[str, float]):
@@ -401,12 +470,12 @@ def _call_pyblastafterglow(self,
         # main model parameters; Uniform ISM -- 2 free parameters
         main=dict(
             d_l= 3.086e19, # luminosity distance to the source [cm], fix at 10 pc, so that AB magnitude equals absolute magnitude
-            z = 0.0, # redshift of the source (used in Doppler shifting and EBL table)
+            z = params_dict.get("redshift", 0.0), # redshift of the source (used in Doppler shifting and EBL table)
             n_ism=np.power(10, params_dict["log10_n0"]), # ISM density [cm^-3] (assuming uniform)
             theta_obs= params_dict["inclination_EM"], # observer angle [rad] (from pol to jet axis)
             lc_freqs= self.lc_freqs, # frequencies for light curve calculation
             lc_times= self.lc_times, # times for light curve calculation
-            tb0=1e1, tb1=1e11, ntb=self.n_tb, # burster frame time grid boundary, resolution, for the simulation
+            tb0=self.tb0, tb1=self.tb1, ntb=self.ntb, # burster frame time grid boundary, resolution, for the simulation
         ),
 
         # ejecta parameters; FS only -- 3 free parameters
@@ -416,7 +485,7 @@ def _call_pyblastafterglow(self,
             eps_b_fs=np.power(10, params_dict["log10_epsilon_B"]), # microphysics - FS - frac. energy in magnetic fields
             p_fs= params_dict["p"], # microphysics - FS - slope of the injection electron spectrum
             do_lc='yes', # task - compute light curves
-            rtol_theta = 1e-1,
+            rtol_theta = self.rtol,
             # save_spec='yes' # save comoving spectra
             # method_synchrotron_fs = 'Joh06',
             # method_ne_fs = 'usenprime',
@@ -428,9 +497,10 @@ def _call_pyblastafterglow(self,
             P=P, # all parameters
             run=True, # run code itself (if False, it will try to load results)
             path_to_cpp=self.path_to_exec, # absolute path to the C++ executable of the code
-            loglevel="err", # logging level of the code (info or err)
+            loglevel=self.loglevel, # logging level of the code (info or err)
             process_skymaps=False # process unstructured sky maps. 
Only used if `do_skymap = yes`
         )
+
         mJys = pba_run.GRB.get_lc()
         return mJys
 
diff --git a/src/fiesta/train/Benchmarker.py b/src/fiesta/train/Benchmarker.py
new file mode 100644
index 0000000..00691e8
--- /dev/null
+++ b/src/fiesta/train/Benchmarker.py
@@ -0,0 +1,215 @@
+from fiesta.inference.lightcurve_model import LightcurveModel, FluxModel
+
+import os
+import ast
+import h5py
+import numpy as np
+import matplotlib.pyplot as plt
+import matplotlib.colors as colors
+from matplotlib.cm import ScalarMappable
+
+from scipy.integrate import trapezoid
+from scipy.interpolate import interp1d
+
+class Benchmarker:
+
+    def __init__(self,
+                 model: LightcurveModel,
+                 data_file: str,
+                 filters: list = None,
+                 outdir: str = "./benchmarks",
+                 metric_name: str = "Linf",
+                 ) -> None:
+
+        self.model = model
+        self.times = self.model.times
+        self.file = data_file
+        self.outdir = outdir
+
+        # Load filters
+        if filters is None:
+            self.Filters = model.Filters
+        else:
+            self.Filters = [Filt for Filt in model.Filters if Filt.name in filters]
+        print(f"Loaded filters are: {[Filt.name for Filt in self.Filters]}.")
+
+        # Load metric
+        if metric_name == "L2":
+            self.metric_name = "$\\mathcal{L}_2$"
+            self.metric = lambda y: np.sqrt(trapezoid(x=np.log(self.times), y=y**2, axis = -1)) / (np.log(self.times[-1]) - np.log(self.times[0]))
+            self.metric2d = lambda y: np.sqrt(trapezoid(x = self.nus, y = trapezoid(x = self.times, y = (y**2).reshape(-1, len(self.nus), len(self.times)) ) ))
+            self.file_ending = "L2"
+        else:
+            self.metric_name = "$\\mathcal{L}_\\infty$"
+            self.metric = lambda y: np.max(np.abs(y), axis = -1)
+            self.metric2d = lambda y: np.max(np.abs(y), axis = (1,2))
+            self.file_ending = "Linf"
+
+        self.get_data()
+        self.calculate_error()
+        self.get_error_distribution()
+
+    def get_data(self,):
+
+        # get the test data
+        self.test_mag = {}
+        with h5py.File(self.file, "r") as f:
+            self.parameter_distributions = ast.literal_eval(f["parameter_distributions"][()].decode('utf-8'))
+            self.parameter_names = f["parameter_names"][:].astype(str).tolist()
+            nus = f["nus"][:]
+
+            self.test_X_raw = f["test"]["X"][:]
+            test_y_raw = f["test"]["y"][:]
+            test_y_raw = test_y_raw.reshape(len(self.test_X_raw), len(f["nus"]), len(f["times"]) )
+            test_y_raw = interp1d(f["times"][:], test_y_raw, axis = 2)(self.times) # interpolate the test data over the time range of the model
+            self.test_log_flux = test_y_raw # keep the true log-fluxes for the 2D flux error metric
+            mJys = np.exp(test_y_raw)
+
+            for Filt in self.Filters:
+                self.test_mag[Filt.name] = Filt.get_mags(mJys, nus)
+
+        # get the model prediction on the test data
+        param_dict = dict(zip(self.parameter_names, self.test_X_raw.T))
+        param_dict["luminosity_distance"] = np.ones(len(self.test_X_raw)) * 1e-5
+        param_dict["redshift"] = np.zeros(len(self.test_X_raw))
+        _, self.pred_mag = self.model.vpredict(param_dict)
+
+    def calculate_error(self,):
+        self.error = {}
+
+        for Filt in self.Filters:
+            test_y = self.test_mag[Filt.name]
+            pred_y = self.pred_mag[Filt.name]
+            self.error[Filt.name] = self.metric(test_y - pred_y)
+
+        if isinstance(self.model, FluxModel):
+            self.nus = self.model.nus
+            log_mJys = np.array([self.model.predict_log_flux(self.test_X_raw[j]) for j in range(len(self.test_X_raw))])
+            self.error["total"] = self.metric2d(log_mJys - self.test_log_flux) # mismatch between predicted and true log-fluxes; assumes the model and the data file share the same frequency grid
+        else:
+            max_errors = {key: np.max(value) for key, value in self.error.items()}
+            max_key = max(max_errors, key=max_errors.get)
+            self.error["total"] = self.error[max_key]
+
+    def get_error_distribution(self,):
+        error_distribution = {}
+        for j, p in enumerate(self.parameter_names):
+            p_array = self.test_X_raw[:,j]
+            bins = 
np.linspace(self.parameter_distributions[p][0], self.parameter_distributions[p][1], 12) + # calculate the error histogram with mismatch as weights + error_distribution[p] = np.histogram(p_array, weights = self.error["total"], bins = bins, density = True) + + self.error_distribution = error_distribution + + def benchmark(self,): + self.print_correlations() + self.plot_worst_lightcurves() + self.plot_error_over_time() + self.plot_error_distribution() + + def plot_lightcurves_mismatch(self, + parameter_labels: list[str] = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"] + ): + if self.metric_name == "$\\mathcal{L}_2$": + vline = self.metric(np.ones(len(self.times))) + vmin, vmax = 0, vline*2 + bins = np.linspace(vmin, vmax, 25) + else: + vline = 1. + vmin, vmax = 0, 2*vline + bins = np.linspace(vmin, vmax, 20) + + cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")]) + label_dic = {p: label for p, label in zip(self.parameter_names, parameter_labels)} + + for Filt in self.Filters: + + mismatch = self.error[Filt.name] + colored_mismatch = cmap(mismatch/vmax) + + + fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1) + fig.suptitle(f"{Filt.name}: {self.metric_name} norm") + + for j, p in enumerate(self.parameter_names[1:]): + for k, pp in enumerate(self.parameter_names[:j+1]): + sort = np.argsort(mismatch) + + ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1, rasterized = True) + + ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max())) + ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max())) + + + if k!=0: + ax[j,k].set_yticklabels([]) + + if j!=len(self.parameter_names)-2: + ax[j,k].set_xticklabels([]) + + ax[-1,k].set_xlabel(label_dic[pp]) + ax[j,0].set_ylabel(label_dic[p]) + + for cax in ax[j, j+1:]: + cax.set_axis_off() + + ax[0,-1].set_axis_on() + ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,) + ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed") + ax[0,-1].set_yticks([]) + + fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1]) + outfile = f"benchmark_{Filt.name}_{self.file_ending}.pdf" + + fig.savefig(os.path.join(self.outdir, outfile)) + + def plot_worst_lightcurves(self,): + + fig, ax = plt.subplots(len(self.Filters) , 1, figsize = (5, 15)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) + + for cax, filt in zip(ax, self.Filters): + ind = np.argmax(self.error[filt.name]) + prediction = self.pred_mag[filt.name][ind] + cax.plot(self.times, prediction, color = "blue") + cax.fill_between(self.times, prediction-1, prediction+1, color = "blue", alpha = 0.2) + + cax.plot(self.times, self.test_mag[filt.name][ind], color = "red") + cax.invert_yaxis() + cax.set(xlabel = "$t$ in days", ylabel = "mag", xscale = "log", xlim = (self.times[0], self.times[-1])) + cax.set_title(f"{filt.name}", loc = "right", pad = -20) + cax.text(0, 0.05, np.array_str(self.test_X_raw[ind], precision = 2), transform = cax.transAxes, fontsize = 7) + + fig.savefig(os.path.join(self.outdir, f"worst_lightcurves_{self.file_ending}.pdf"), dpi = 200) + + def plot_error_over_time(self,): + + fig, ax = plt.subplots(len(self.Filters) , 1, figsize = (5, 15)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, 
left = 0.14, right = 0.95) + + for cax, filt in zip(ax, self.Filters): + error = np.abs(self.pred_mag[filt.name] - self.test_mag[filt.name]) + indices = np.linspace(5, len(self.times)-1, 10).astype(int) + cax.violinplot(error[:, indices], positions = self.times[indices], widths = self.times[indices]/3) + cax.set(xlabel = "$t$ in days", ylabel = "error in mag", xscale = "log", xlim = (self.times[0], self.times[-1]), ylim = (0,1.5)) + cax.set_title(f"{filt.name}", loc = "right", pad = -20) + + fig.savefig(os.path.join(self.outdir, f"error_over_time.pdf"), dpi = 200) + + def print_correlations(self, ): + for Filt in self.Filters: + error = self.error[Filt.name] + print(f"\n \n \nCorrelations for filter {Filt.name}:\n") + for j, p in enumerate(self.parameter_names): + print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], error)[0,1]}") + + def plot_error_distribution(self,): + + fig, ax = plt.subplots(len(self.parameter_names), 1, figsize = (4, 18)) + fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.09, right = 0.95) + + for p, cax in zip(self.parameter_names, ax): + cax.bar(self.error_distribution[p][1][:-1], self.error_distribution[p][0], width = 1, color = "blue") + cax.set_xlabel(p) + cax.set_yticks([]) + + fig.savefig(os.path.join(self.outdir, f"error_distribution.pdf"), dpi = 200) diff --git a/src/fiesta/train/BenchmarkerFluxes.py b/src/fiesta/train/BenchmarkerFluxes.py deleted file mode 100644 index 5b5e70a..0000000 --- a/src/fiesta/train/BenchmarkerFluxes.py +++ /dev/null @@ -1,245 +0,0 @@ -from fiesta.inference.lightcurve_model import LightcurveModel -import afterglowpy as grb -from fiesta.constants import days_to_seconds, c -from fiesta import conversions -from fiesta import utils - -from jaxtyping import Array, Float, Int - -import tqdm -import os -import ast -import h5py -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.colors as colors -from matplotlib.cm import ScalarMappable - -from scipy.integrate import trapezoid -from scipy.interpolate import interp1d - -class Benchmarker: - - def __init__(self, - name: str, - model_dir: str, - filters: list[str], - MODEL = LightcurveModel, - metric_name: str = "$\\mathcal{L}_\\inf$" - ) -> None: - - self.name = name - self.model_dir = model_dir - self.model = MODEL(name = self.name, - directory = self.model_dir, - filters = filters) - - self.times = self.model.times - self.nus = self.model.metadata["nus"] - self.load_filters(filters) - - self.get_test_data() - self.lightcurve_test_data() - - - self.metric_name = metric_name - if metric_name == "$\\mathcal{L}_2$": - self.metric = lambda y: np.sqrt(trapezoid(x= self.times,y=y**2)) - self.metric2d = lambda y: np.sqrt(trapezoid(x = self.nus, y =trapezoid(x = self.times, y = (y**2).reshape(len(self.nus), len(self.times)) ) )) - else: - self.metric = lambda y: np.max(np.abs(y)) - self.metric2d = self.metric - - self.calculate_mismatch() - self.get_error_distribution() - - def __repr__(self) -> str: - return f"Surrogte_Benchmarker(name={self.name}, model_dir={self.model_dir})" - - def load_filters(self, filters: list[str]): - self.filters = [] - for filter in filters: - try: - self.filters.append(utils.Filter(filter)) - except: - raise Exception(f"Filter {filter} not available.") - - def get_test_data(self,): - - file = [f for f in os.listdir(self.model_dir) if f.endswith("_raw_data.h5")][0] - - with h5py.File(os.path.join(self.model_dir, file), "r") as f: - self.parameter_distributions = ast.literal_eval(f["parameter_distributions"][()].decode('utf-8')) - 
self.parameter_names = f["parameter_names"][:].astype(str).tolist() - self.test_X_raw = f["test"]["X"][:] - y_raw = f["test"]["y"][:] - y_raw = y_raw.reshape(len(self.test_X_raw), len(f["nus"]), len(f["times"]) ) - y_raw = interp1d(f["times"][:], y_raw, axis = 2)(self.times) # interpolate the test data over the time range of the model - y_raw = interp1d(f["nus"][:], y_raw, axis = 1)(self.nus) # interpolate the test data over the frequency range of the model - self.fluxes_raw = y_raw.reshape(len(self.test_X_raw), len(self.nus) * len(self.times) ) - - def lightcurve_test_data(self, ): - - self.n_test_data = len(self.test_X_raw) - self.prediction_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters} - self.test_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters} - self.prediction_log_fluxes = np.empty((self.n_test_data, len(self.nus) * len(self.times))) - for j, X in enumerate(self.test_X_raw): - param_dict = {name: x for name, x in zip(self.parameter_names, X)} - prediction = self.model.predict(param_dict) - self.prediction_log_fluxes[j] = self.model.predict_log_flux(X) - for filt in self.filters: - self.prediction_y_raw[filt.name][j] = prediction[filt.name] - self.test_y_raw[filt.name][j] = self.convert_to_mag(filt.nu, self.fluxes_raw[j]) - - def convert_to_mag(self, nu, flux): - flux = flux.reshape(len(self.model.metadata["nus"]), len(self.model.times)) - flux = np.exp(flux) - flux = np.array([np.interp(nu, self.model.metadata["nus"], column) for column in flux.T]) - mag = -48.6 + -1 * np.log10(flux*1e-3 / 1e23) * 2.5 - return mag - - - ### Diagnostics ### - - def calculate_mismatch(self): - mismatch = {} - for filt in self.filters: - array = np.empty(self.n_test_data) - for j in range(self.n_test_data): - array[j] = self.metric(self.prediction_y_raw[filt.name][j] - self.test_y_raw[filt.name][j]) - mismatch[filt.name] = array - - array = np.empty(self.n_test_data) - for j in range(self.n_test_data): - array[j] = self.metric2d(self.prediction_log_fluxes[j]-self.fluxes_raw[j]) - - mismatch["total"] = array - self.mismatch = mismatch - - def plot_lightcurves_mismatch(self, - filter: str, - parameter_labels: list[str] = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"] - ): - if self.metric_name == "$\\mathcal{L}_2$": - vline = np.sqrt(trapezoid(x = self.times, y = 0.2*np.ones(len(self.times)))) - vmin, vmax = 0, vline*2 - bins = np.linspace(vmin, vmax, 25) - else: - vline = 1. 
- vmin, vmax = 0, 2*vline - bins = np.linspace(vmin, vmax, 20) - - mismatch = self.mismatch[filter] - - cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")]) - colored_mismatch = cmap(mismatch/vmax) - - label_dic = {p: label for p, label in zip(self.parameter_names, parameter_labels)} - - fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1) - fig.suptitle(f"{filter}: {self.metric_name} norm") - - for j, p in enumerate(self.parameter_names[1:]): - for k, pp in enumerate(self.parameter_names[:j+1]): - sort = np.argsort(mismatch) - - ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1, rasterized = True) - - ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max())) - ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max())) - - - if k!=0: - ax[j,k].set_yticklabels([]) - - if j!=len(self.parameter_names)-2: - ax[j,k].set_xticklabels([]) - - ax[-1,k].set_xlabel(label_dic[pp]) - ax[j,0].set_ylabel(label_dic[p]) - - for cax in ax[j, j+1:]: - cax.set_axis_off() - - ax[0,-1].set_axis_on() - ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,) - ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed") - ax[0,-1].set_yticks([]) - - fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1]) - return fig, ax - - def print_correlations(self, - filter: str,): - mismatch = self.mismatch[filter] - print(f"\n \n \nCorrelations for filter {filter}:\n") - corrcoeff = [] - for j, p in enumerate(self.parameter_names): - print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], mismatch)[0,1]}") - - def get_error_distribution(self): - error_distribution = {} - for j, p in enumerate(self.parameter_names): - p_array = self.test_X_raw[:,j] - #bins = (np.array(self.parameter_distributions[p][:-1]) + np.array(self.parameter_grid[p][1:]))/2 - #bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]] - bins = np.linspace(self.parameter_distributions[p][0], self.parameter_distributions[p][1], 12) - # calculate the error histogram with mismatch as weights - error_distribution[p], _ = np.histogram(p_array, weights = self.mismatch["total"], bins = bins, density = True) - error_distribution[p] = error_distribution[p]/np.sum(error_distribution[p]) - - self.error_distribution = error_distribution - - - def plot_worst_lightcurves(self,): - - fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) - fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) - - for cax, filt in zip(ax, self.filters): - ind = np.argmax(self.mismatch[filt.name]) - prediction = self.prediction_y_raw[filt.name][ind] - cax.plot(self.times, prediction, color = "blue") - cax.fill_between(self.times, prediction-1, prediction+1, color = "blue", alpha = 0.2) - cax.plot(self.times, self.test_y_raw[filt.name][ind], color = "red") - cax.invert_yaxis() - cax.set(xlabel = "$t$ in days", ylabel = "mag", xscale = "log", xlim = (self.times[0], self.times[-1])) - cax.set_title(f"{filt.name}", loc = "right", pad = -20) - cax.text(0, 0.05, np.array_str(self.test_X_raw[ind], precision = 2), transform = cax.transAxes, fontsize = 7) - - return fig, ax - - def plot_error_over_time(self,): - - fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) - fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) - - for 
cax, filt in zip(ax, self.filters): - error = np.abs(self.prediction_y_raw[filt.name] - self.test_y_raw[filt.name]) - indices = np.linspace(5, len(self.times)-1, 10).astype(int) - cax.violinplot(error[:, indices], positions = self.times[indices], widths = self.times[indices]/3) - cax.set(xlabel = "$t$ in days", ylabel = "error in mag", xscale = "log", xlim = (self.times[0], self.times[-1]), ylim = (0,1.5)) - cax.set_title(f"{filt.name}", loc = "right", pad = -20) - return fig, ax - - def plot_error_distribution(self,): - mismatch = self.mismatch["total"] - - fig, ax = plt.subplots(len(self.parameter_names), 1, figsize = (4, 18)) - fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.09, right = 0.95) - - for j, cax in enumerate(ax): - p_array = self.test_X_raw[:,j] - p = self.parameter_names[j] - bins = np.linspace(self.parameter_distributions[p][0], self.parameter_distributions[p][1], 12) - - cax.hist(p_array, weights = self.mismatch["total"], color = "blue", bins = bins, density = True, histtype = "step") - cax.set_xlabel(self.parameter_names[j]) - cax.set_yticks([]) - - return fig, ax - - - - diff --git a/src/fiesta/train/BenchmarkerLightcurves.py b/src/fiesta/train/BenchmarkerLightcurves.py deleted file mode 100644 index dd6b117..0000000 --- a/src/fiesta/train/BenchmarkerLightcurves.py +++ /dev/null @@ -1,293 +0,0 @@ -from fiesta.inference.lightcurve_model import LightcurveModel -import afterglowpy as grb -from fiesta.constants import days_to_seconds -from fiesta import conversions -from fiesta import utils -from fiesta.utils import Filter - -from jaxtyping import Array, Float - -import tqdm -import os -import numpy as np -import matplotlib.pyplot as plt -import matplotlib.colors as colors -from matplotlib.cm import ScalarMappable - -from scipy.integrate import trapezoid - -# TODO: get a benchmarker class for all surrogate model -class Benchmarker: - - name: str - model_dir: str - filters: list[Filter] - n_test_data: int - metric_name: str - jet_type: int - model: AfterglowpyLightcurvemodel - - def __init__(self, - name: str, - model_dir: str, - filters: list[str], - parameter_grid: dict, - MODEL = LightcurveModel, - n_test_data: int = 3000, - remake_test_data: bool = False, - metric_name: str = "$\\mathcal{L}_\\inf$", - jet_type = -1 - ) -> None: - - self.name = name - self.model_dir = model_dir - self.load_filters(filters) - self.model = MODEL(name = self.name, - directory = self.model_dir, - filters = filters) - self.times = self.model.times - self._times_afterglowpy = self.times * days_to_seconds - self.jet_type = jet_type - - self.parameter_names = self.model.metadata["parameter_names"] - self.parameter_grid = parameter_grid - - if os.path.exists(self.model_dir+"/raw_data_test.npz") and not remake_test_data: - self.load_test_data() - else: - self.get_test_data(n_test_data) - - self.metric_name = metric_name - mask = np.logical_and(self.times>1, self.times<1000) - if metric_name == "$\\mathcal{L}_2$": - self.metric = lambda y: np.sqrt(trapezoid(x= self.times[mask],y=y[mask]**2)) - else: - self.metric = lambda y: np.max(np.abs(y[mask])) - - self.calculate_mismatch() - self.get_error_distribution() - - def __repr__(self) -> str: - return f"Surrogate_Benchmarker(name={self.name}, model_dir={self.model_dir})" - - def load_filters(self, filters: list[str]): - self.filters = [] - for filter in filters: - try: - self.filters.append(utils.Filter(filter)) - except: - raise Exception(f"Filter {filter} not available.") - - def get_test_data(self, n_test_data): - test_X_raw = 
np.empty((n_test_data, len(self.parameter_names))) - test_y_raw = {filter.name: np.empty((n_test_data, len(self.times))) for filter in self.filters} - prediction_y_raw = {filter.name: np.empty((n_test_data, len(self.times))) for filter in self.filters} - - print(f"Determining test data for {n_test_data} random points within parameter grid.") - for j in tqdm.tqdm(range(n_test_data)): - test_X_raw[j] = np.random.uniform(low = [self.parameter_grid[p][0] for p in self.parameter_names], high = [self.parameter_grid[p][-1] for p in self.parameter_names]) - param_dict = {name: x for name, x in zip(self.parameter_names, test_X_raw[j])} - - prediction = self.model.predict(param_dict) - - for filt in self.filters: - param_dict["nu"] = filt.nu - prediction_y_raw[filt.name][j] = prediction[filt.name] - mJys = self._call_afterglowpy(param_dict) - test_y_raw[filt.name][j] = conversions.mJys_to_mag_np(mJys) - - self.test_X_raw = test_X_raw - self.test_y_raw = test_y_raw - self.prediction_y_raw = prediction_y_raw - self.n_test_data = n_test_data - - #for saving - test_saver = {"test_"+key: test_y_raw[key] for key in test_y_raw.keys()} - np.savez(os.path.join(self.model_dir, "raw_data_test.npz"), X = test_X_raw, **test_saver) - - def load_test_data(self, ): - - test_data = np.load(self.model_dir+"/raw_data_test.npz") - self.test_X_raw = test_data["X"] - self.test_y_raw = {filt.name: test_data["test_"+filt.name] for filt in self.filters} - self.n_test_data = len(self.test_X_raw) - - self.prediction_y_raw = {filt.name: np.empty((self.n_test_data, len(self.times))) for filt in self.filters} - for j, X in enumerate(self.test_X_raw): - param_dict = {name: x for name, x in zip(self.parameter_names, X)} - prediction = self.model.predict(param_dict) - for filt in self.filters: - self.prediction_y_raw[filt.name][j] = prediction[filt.name] - - def _call_afterglowpy(self, - params_dict: dict[str, Float]) -> Float[Array, "n_times"]: - """ - Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires, as well as the nu value. - The output will be a set of mJys. - - Args: - Float[Array, "n_times"]: The flux density in mJys at the given times. 
- """ - - # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z - Z = {} - - Z["jetType"] = params_dict.get("jetType", self.jet_type) - Z["specType"] = params_dict.get("specType", 0) - Z["z"] = params_dict.get("z", 0.0) - Z["xi_N"] = params_dict.get("xi_N", 1.0) - - Z["E0"] = 10 ** params_dict["log10_E0"] - Z["n0"] = 10 ** params_dict["log10_n0"] - Z["p"] = params_dict["p"] - Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] - Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"] - Z["d_L"] = 3.086e19 # fix at 10 pc, so that AB magnitude equals absolute magnitude - if "inclination_EM" in list(params_dict.keys()): - Z["thetaObs"] = params_dict["inclination_EM"] - else: - Z["thetaObs"] = params_dict["thetaObs"] - - if self.jet_type == -1: - Z["thetaCore"] = params_dict["thetaCore"] - - if "thetaWing" in list(params_dict.keys()): #for Gaussian and power law jets - Z["thetaWing"] = params_dict["thetaWing"] - Z["thetaCore"] = params_dict["xCore"]*params_dict["thetaWing"] - - if self.jet_type == 4: - Z["b"] = params_dict["b"] - - # Afterglowpy returns flux in mJys - mJys = grb.fluxDensity(self._times_afterglowpy, params_dict["nu"], **Z) - return mJys - - def calculate_mismatch(self): - mismatch = {} - for filt in self.filters: - array = np.empty(self.n_test_data) - for j in range(self.n_test_data): - array[j] = self.metric(self.prediction_y_raw[filt.name][j] - self.test_y_raw[filt.name][j]) - mismatch[filt.name] = array - - self.mismatch = mismatch - - def plot_lightcurves_mismatch(self, - filter: str, - parameter_labels: list[str] = ["$\\iota$", "$\log_{10}(E_0)$", "$\\theta_c$", "$\log_{10}(n_{\mathrm{ism}})$", "$p$", "$\\epsilon_E$", "$\\epsilon_B$"] - ): - if self.metric_name == "$\\mathcal{L}_2$": - bins = np.arange(0, 100, 5) - vmin, vmax = 0, 50 - vline = np.sqrt(trapezoid(x = self.times, y = np.ones(len(self.times)))) - else: - bins = np.arange(0, 3, 0.5) - vmin, vmax = 0, 2 - vline = 1. 
- - mismatch = self.mismatch[filter] - - cmap = colors.LinearSegmentedColormap.from_list(name = "mymap", colors = [(0, "lightblue"), (1, "darkred")]) - colored_mismatch = cmap(mismatch/vmax) - - label_dic = {p: label for p, label in zip(self.parameter_names, parameter_labels)} - - fig, ax = plt.subplots(len(self.parameter_names)-1, len(self.parameter_names)-1) - fig.suptitle(f"{filter}: {self.metric_name} norm") - - for j, p in enumerate(self.parameter_names[1:]): - for k, pp in enumerate(self.parameter_names[:j+1]): - sort = np.argsort(mismatch) - - ax[j,k].scatter(self.test_X_raw[sort,k], self.test_X_raw[sort,j+1], c = colored_mismatch[sort], s = 1) - - ax[j,k].set_xlim((self.test_X_raw[:,k].min(), self.test_X_raw[:,k].max())) - ax[j,k].set_ylim((self.test_X_raw[:,j+1].min(), self.test_X_raw[:,j+1].max())) - - - if k!=0: - ax[j,k].set_yticklabels([]) - - if j!=len(self.parameter_names)-2: - ax[j,k].set_xticklabels([]) - - ax[-1,k].set_xlabel(label_dic[pp]) - ax[j,0].set_ylabel(label_dic[p]) - - for cax in ax[j, j+1:]: - cax.set_axis_off() - - ax[0,-1].set_axis_on() - ax[0,-1].hist(mismatch, density = True, histtype = "step", bins = bins,) - ax[0,-1].vlines([vline], *ax[0,-1].get_ylim(), colors = ["lightgrey"], linestyles = "dashed") - ax[0,-1].set_yticks([]) - - fig.colorbar(ScalarMappable(norm=colors.Normalize(vmin = vmin, vmax = vmax), cmap = cmap), ax = ax[1:-1, -1]) - return fig, ax - - def print_correlations(self, - filter: str,): - - mismatch = self.mismatch[filter] - - - print(f"\n \n \nCorrelations for filter {filter}:\n") - corrcoeff = [] - for j, p in enumerate(self.parameter_names): - print(f"{p}: {np.corrcoef(self.test_X_raw[:,j], mismatch)[0,1]}") - - def get_error_distribution(self): - - error_distribution = {filt.name: {} for filt in self.filters} - - for filt in self.filters: - for j, p in enumerate(self.parameter_names): - p_array = self.test_X_raw[:,j] - bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 - bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]] - # calculate the error histogram with mismatch as weights - error_distribution[filt.name][p], _ = np.histogram(p_array, weights = self.mismatch[filt.name], bins = bins, density = True) - error_distribution[filt.name][p] = error_distribution[filt.name][p]/np.sum(error_distribution[filt.name][p]) - - self.error_distribution = error_distribution - - - def plot_worst_lightcurves(self,): - - fig, ax = plt.subplots(len(self.filters) , 1, figsize = (5, 15)) - fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.14, right = 0.95) - - for cax, filt in zip(ax, self.filters): - ind = np.argmax(self.mismatch[filt.name]) - prediction = self.prediction_y_raw[filt.name][ind] - cax.plot(self.times, prediction, color = "blue") - cax.fill_between(self.times, prediction-1, prediction+1, color = "blue", alpha = 0.2) - cax.plot(self.times, self.test_y_raw[filt.name][ind], color = "red") - cax.invert_yaxis() - cax.set(xlabel = "$t$ in days", ylabel = "mag") - cax.set_title(f"{filt.name}", loc = "right", pad = -20) - - return fig, ax - - def plot_error_distribution(self, filter): - mismatch = self.mismatch[filter] - - fig, ax = plt.subplots(len(self.parameter_names), 1, figsize = (4, 18)) - fig.subplots_adjust(hspace = 0.5, bottom = 0.08, top = 0.98, left = 0.09, right = 0.95) - - for j, cax in enumerate(ax): - p_array = self.test_X_raw[:,j] - p = self.parameter_names[j] - bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 - bins = [self.parameter_grid[p][0] ,*bins, 
self.parameter_grid[p][-1]]
-
-            cax.hist(p_array, weights = self.mismatch[filter], color = "blue", bins = bins, density = True, histtype = "step")
-            cax.set_xlabel(self.parameter_names[j])
-            cax.set_yticks([])
-
-
-        return fig, ax
-
-
-
diff --git a/src/fiesta/train/DataManager.py b/src/fiesta/train/DataManager.py
new file mode 100644
index 0000000..70daeef
--- /dev/null
+++ b/src/fiesta/train/DataManager.py
@@ -0,0 +1,325 @@
+from fiesta.utils import MinMaxScalerJax, StandardScalerJax, PCADecomposer, ImageScaler, SVDDecomposer
+
+import numpy as np
+import jax.numpy as jnp
+import h5py
+import gc
+from jaxtyping import Array, Float, Int
+
+
+def array_mask_from_interval(sorted_array, amin, amax):
+    indmin = max(0, np.searchsorted(sorted_array, amin, side='right') - 1)
+    indmax = min(len(sorted_array)-1, np.searchsorted(sorted_array, amax))
+    mask = np.logical_and(sorted_array>=sorted_array[indmin], sorted_array<=sorted_array[indmax])
+    return mask
+
+
+###################
+# DATA MANAGEMENT #
+###################
+
+class DataManager:
+
+    def __init__(self,
+                 file: str,
+                 n_training: Int,
+                 n_val: Int,
+                 tmin: Float,
+                 tmax: Float,
+                 numin: Float = 1e9,
+                 numax: Float = 2.5e18,
+                 special_training: list = [],
+                 ) -> None:
+        """
+        DataManager class used to handle and preprocess the raw data from the physical model computations stored in an .h5 file.
+        Initializing an instance of this class only reads in the metadata; the actual training and validation data are only loaded once one of the preprocessing methods is called.
+
+        The .h5 file must contain the following data sets:
+            - "times": times in days associated with the spectral flux densities
+            - "nus": frequencies in Hz associated with the spectral flux densities
+            - "parameter_names": list of the parameter names that are present in the training data.
+            - "parameter_distributions": UTF-8 string of a dict containing the boundaries and distribution of the parameters.
+        Additionally, it must contain three data groups "train", "val", "test". Each of these groups contains two data sets, namely "X" and "y".
+        The X arrays contain the model parameters with columns in the order of "parameter_names" and thus have shape (-1, #parameters). The y arrays contain the associated log of the spectral flux densities in mJys and have shape (-1, #nus * #times).
+        To get the full 2D log spectral flux density arrays, one needs to reshape the 1D entries of y to (#nus, #times).
+
+        Args:
+            file (str): Path to the .h5 file that contains the raw data.
+            n_training (int): Number of training data points that will be read in and preprocessed. If used with a FluxTrainer, this is also the number of training data points used to train the model.
+                Will raise a ValueError if n_training is larger than the number of training data points stored in the file.
+            n_val (int): Number of validation data points that will be read in and preprocessed. If used with a FluxTrainer, this is also the number of validation data points used to monitor the training progress.
+                Will raise a ValueError if n_val is larger than the number of validation data points stored in the file.
+            tmin (float): Minimum time for which the data will be read in. Fluxes earlier than this time will not be loaded. Clamped to the minimum time of the stored data, if smaller than that value.
+            tmax (float): Maximum time for which the data will be read in. Fluxes later than this time will not be loaded. Clamped to the maximum time of the stored data, if larger than that value.
+            numin (float): Minimum frequency for which the data will be read in. Fluxes with frequencies lower than this frequency will not be loaded. Clamped to the minimum frequency of the stored data, if smaller than that value. Defaults to 1e9 Hz (1 GHz).
+            numax (float): Maximum frequency for which the data will be read in. Fluxes with frequencies higher than this frequency will not be loaded. Clamped to the maximum frequency of the stored data, if larger than that value. Defaults to 2.5e18 Hz (about 10 keV).
+            special_training (list[str]): Batch of 'special' training data to be added. This can be custom-designed training data that covers a certain area of the parameter space more densely and should be stored in the .h5 file as f['special_train'][label]['X'] and f['special_train'][label]['y'], where label is an entry in this special_training. Defaults to [].
+        """
+
+        self.file = file
+        self.n_training = n_training
+        self.n_val = n_val
+
+        self.tmin = tmin
+        self.tmax = tmax
+        self.numin = numin
+        self.numax = numax
+
+        self.special_training = special_training
+
+        self.read_metadata_from_file()
+        self.set_up_domain_mask()
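For concreteness, the layout this docstring describes can be sketched as a toy file. This is a minimal, hypothetical example (file name, shapes, parameter names, and chunking are all toy choices, not taken from the diff):

```python
# Minimal sketch of the .h5 layout DataManager expects (all values are toy choices).
import h5py
import numpy as np

times = np.geomspace(0.1, 100.0, 50)        # days
nus = np.geomspace(1e9, 2.5e18, 64)         # Hz
parameter_names = ["log10_E0", "log10_n0"]  # hypothetical parameter set

rng = np.random.default_rng(0)
with h5py.File("toy_data.h5", "w") as f:
    f["times"] = times
    f["nus"] = nus
    f["parameter_names"] = np.array(parameter_names, dtype="S")
    f["parameter_distributions"] = str({"log10_E0": [50.0, 55.0, "uniform"],
                                        "log10_n0": [-3.0, 2.0, "uniform"]})
    f.create_group("special_train")  # may stay empty
    for group, n in [("train", 128), ("val", 32), ("test", 32)]:
        g = f.create_group(group)
        g["X"] = rng.uniform(size=(n, len(parameter_names)))
        # flattened log spectral flux densities; chunked row-wise so the
        # chunk-by-chunk preprocessing below can iterate over the dataset
        g.create_dataset("y", data=rng.normal(size=(n, len(nus) * len(times))),
                         chunks=(16, len(nus) * len(times)))
```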
+    def read_metadata_from_file(self,) -> None:
+        """
+        Reads in the metadata of the raw data, i.e., times, frequencies, and parameter names.
+        Also determines how many training and validation data points are available.
+        """
+        with h5py.File(self.file, "r") as f:
+            self.times_data = f["times"][:]
+            self.nus_data = f["nus"][:]
+            self.parameter_names = f["parameter_names"][:].astype(str).tolist()
+            self.n_training_exists = f["train"]["X"].shape[0]
+            self.n_val_exists = f["val"]["X"].shape[0]
+            self.parameter_distributions = f['parameter_distributions'][()].decode('utf-8')
+
+        # check if there is enough data
+        if self.n_training > self.n_training_exists:
+            raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.")
+        if self.n_val > self.n_val_exists:
+            raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to validate with {self.n_val} data points.")
+
+    def set_up_domain_mask(self,) -> None:
+        """Trims the stored data down to the time and frequency range desired for training. Sets the mask attribute, a boolean mask used when loading the data arrays."""
+
+        if self.tmin < self.times_data.min() or self.tmax > self.times_data.max():
+            print(f"\nWarning: provided time range {self.tmin, self.tmax} is too wide for the data stored in file. Using range {max(self.times_data.min(), self.tmin), min(self.times_data.max(), self.tmax)} instead.\n")
+        time_mask = array_mask_from_interval(self.times_data, self.tmin, self.tmax)
+        self.times = self.times_data[time_mask]
+        self.n_times = len(self.times)
+
+        if self.numin < self.nus_data.min() or self.numax > self.nus_data.max():
+            print(f"\nWarning: provided frequency range {self.numin, self.numax} is too wide for the data stored in file. Using range {max(self.nus_data.min(), self.numin), min(self.nus_data.max(), self.numax)} instead.\n")
+        nu_mask = array_mask_from_interval(self.nus_data, self.numin, self.numax)
+        self.nus = self.nus_data[nu_mask]
+        self.n_nus = len(self.nus)
+
+        mask = nu_mask[:, None] & time_mask
+        self.mask = mask.flatten()
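The bracketing behaviour of the module-level mask helper added above is easiest to see on a toy grid (a small sketch, assuming the module is importable as added in this diff):

```python
# Demo of the domain mask: the helper keeps one grid point on either side
# of the requested interval, so the interval edges remain covered.
import numpy as np
from fiesta.train.DataManager import array_mask_from_interval

times = np.array([0.1, 1.0, 10.0, 100.0])
mask = array_mask_from_interval(times, amin=0.5, amax=5.0)
print(times[mask])  # -> [ 0.1  1.  10.]
```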
+ """ + with h5py.File(self.file, "r") as f: + print(f"Times: {f['times'][0]} {f['times'][-1]}") + print(f"Nus: {f['nus'][0]} {f['nus'][-1]}") + print(f"Parameter distributions: {f['parameter_distributions'][()].decode('utf-8')}") + print("\n") + print(f"Training data: {self.n_training_exists}") + print(f"Validation data: {self.n_val_exists}") + print(f"Test data: {f['test']['X'].shape[0]}") + print("Special data:") + for key in f['special_train'].keys(): + print(f"\t {key}: {f['special_train'][key]['X'].shape[0]} description: {f['special_train'][key].attrs['comment']}") + print("\n \n") + + def load_raw_data_from_file(self,) -> None: + """Loads raw data for training and validation as attributes to the instance.""" + with h5py.File(self.file, "r") as f: + if self.n_training>self.n_training_exists: + raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.") + self.train_X_raw = f["train"]["X"][:self.n_training] + self.train_y_raw = f["train"]["y"][:self.n_training, self.mask] + + for label in self.special_training: + self.train_X_raw = np.concatenate((self.train_X_raw, f["special_train"][label]["X"][:])) + self.train_y_raw = np.concatenate((self.train_y_raw, f["special_train"][label]["y"][:, self.mask])) + + if self.n_val>self.n_val_exists: + raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to validate with {self.n_val} data points.") + self.val_X_raw = f["val"]["X"][:self.n_val] + self.val_y_raw = f["val"]["y"][:self.n_val, self.mask] + + def preprocess_pca(self, n_components: int) -> tuple[Array, Array, Array, Array, object, object]: + """ + Loads in the training and validation data and performs PCA decomposition using fiesta.utils.PCADecomposer. + Because of memory issues, the training data set is loaded in chunks. + The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax. + + Args: + n_components(int): Number of PCA components to keep. + Returns: + train_X (Array): Standardized training parameters. + train_y (Array): PCA coefficients of the training data. + val_X (Array): Standardized validation parameters + val_y (Array): PCA coefficients of the validation data. + Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points. + yscaler (PCAdecomposer): PCADecomposer object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities. + """ + Xscaler, yscaler = StandardScalerJax(), PCADecomposer(n_components=n_components) + + # preprocess the training data + with h5py.File(self.file, "r") as f: + train_X_raw = f["train"]["X"][:self.n_training] + train_X = Xscaler.fit_transform(train_X_raw) # fit the Xscaler and transform the train_X_raw + + y_set = f["train"]["y"] + + loaded = y_set[: min(20_000, self.n_training), self.mask].astype(np.float16) # only load max. 20k cause otherwise we might run out of memory at this step + assert not np.any(np.isinf(loaded)), f"Found inftys in training data." 
+    def preprocess_cVAE(self, image_size: Int[Array, "shape=(2,)"]) -> tuple[Array, Array, Array, Array, object, object]:
+        """
+        Loads in the training and validation data and performs data preprocessing for the CVAE using fiesta.utils.ImageScaler.
+        To avoid memory issues, the training data set is loaded in chunks.
+        The X arrays (parameter values) are standardized with fiesta.utils.StandardScalerJax.
+
+        Args:
+            image_size (Array[Int]): Image size the 2D flux arrays are down-sampled to with jax.image.resize.
+        Returns:
+            train_X (Array): Standardized training parameters.
+            train_y (Array): Down-sampled and standardized flux arrays of the training data, flattened per entry.
+            val_X (Array): Standardized validation parameters.
+            val_y (Array): Down-sampled and standardized flux arrays of the validation data, flattened per entry.
+            Xscaler (StandardScalerJax): Standardizer object fitted to the mean and sigma of the raw training data. Can be used to transform and inverse transform parameter points.
+            yscaler (ImageScaler): ImageScaler object fitted to part of the raw training data. Can be used to transform and inverse transform log spectral flux densities.
+        """
+        Xscaler, yscaler = StandardScalerJax(), ImageScaler(downscale = image_size, upscale = (self.n_nus, self.n_times), scaler = StandardScalerJax())
+
+        # preprocess the training data
+        with h5py.File(self.file, "r") as f:
+            train_X_raw = f["train"]["X"][:self.n_training]
+            train_X = Xscaler.fit_transform(train_X_raw) # fit the Xscaler and transform train_X_raw
+
+            y_set = f["train"]["y"]
+
+            train_y = np.empty((self.n_training, jnp.prod(image_size)), dtype=jnp.float16)
+
+            chunk_size = y_set.chunks[0] # load raw data in chunks of chunk_size
+            nchunks, rest = divmod(self.n_training, chunk_size)
+            for j, chunk in enumerate(y_set.iter_chunks()):
+                loaded = y_set[chunk][:, self.mask].astype(jnp.float16)
+                assert not np.any(np.isinf(loaded)), "Found infs in training data."
+                train_y[j*chunk_size:(j+1)*chunk_size] = yscaler.resize_image(loaded).reshape(-1, jnp.prod(image_size))
+                if j >= nchunks-1:
+                    break
+            if rest > 0:
+                loaded = y_set[-rest:, self.mask].astype(jnp.float16)
+                assert not np.any(np.isinf(loaded)), "Found infs in training data."
+                train_y[-rest:] = yscaler.resize_image(loaded).reshape(-1, jnp.prod(image_size))
+
+        train_y = yscaler.fit_transform_scaler(train_y) # this now standardizes the down-sampled fluxes
+
+        # preprocess the special training data as well as the validation data
+        train_X, train_y, val_X, val_y = self.__preprocess__special_and_val_data(train_X, train_y, Xscaler, yscaler)
+        return train_X, train_y, val_X, val_y, Xscaler, yscaler
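The autoencoder path works analogously, but returns flattened, down-sampled flux images instead of PCA coefficients (a sketch with a hypothetical image size, same toy file as above):

```python
# Hypothetical usage of the CVAE preprocessing.
import jax.numpy as jnp
from fiesta.train.DataManager import DataManager

dm = DataManager(file="toy_data.h5", n_training=128, n_val=32,
                 tmin=0.1, tmax=100.0)
image_size = jnp.array([32, 25])  # (#nus, #times) after down-sampling
train_X, train_y, val_X, val_y, Xscaler, yscaler = dm.preprocess_cVAE(image_size)
print(train_y.shape)  # -> (128, 800), i.e. flattened 32x25 images
```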
+    def __preprocess__special_and_val_data(self, train_X, train_y, Xscaler, yscaler) -> tuple[Array, Array, Array, Array]:
+        """Submethod that applies the scaling transforms to the validation and special training data."""
+        with h5py.File(self.file, "r") as f:
+            # preprocess the special training data
+            for label in self.special_training:
+                special_train_X = Xscaler.transform(f["special_train"][label]["X"][:])
+                train_X = np.concatenate((train_X, special_train_X))
+
+                special_train_y = yscaler.transform(f["special_train"][label]["y"][:, self.mask])
+                train_y = np.concatenate(( train_y, special_train_y.astype(jnp.float16) ))
+
+            # preprocess validation data
+            val_X_raw = f["val"]["X"][:self.n_val]
+            val_X = Xscaler.transform(val_X_raw)
+            val_y_raw = f["val"]["y"][:self.n_val, self.mask]
+            val_y = yscaler.transform(val_y_raw)
+
+        return train_X, train_y, val_X, val_y
+
+    def preprocess_svd(self, svd_ncoeff: Int, filters: list) -> tuple[Array, dict[Array], Array, dict[Array], object, dict[object]]:
+        """
+        Loads in the training and validation data and performs data preprocessing for the SVD decomposition using fiesta.utils.SVDDecomposer.
+        This is done *per filter* supplied in the filters argument, which is equivalent to the old NMMA procedure.
+        The X arrays (parameter values) are scaled to [0,1] with MinMaxScalerJax.
+
+        Args:
+            svd_ncoeff (Int): Number of SVD coefficients to keep.
+            filters (list[Filter]): List of fiesta.utils.Filter instances that are used to convert the fluxes to magnitudes.
+        Returns:
+            train_X (Array): Scaled training parameters.
+            train_y (dict[Array]): Dictionary of the SVD coefficients of the training magnitude lightcurves with the filter names as keys.
+            val_X (Array): Scaled validation parameters.
+            val_y (dict[Array]): Dictionary of the SVD coefficients of the validation magnitude lightcurves with the filter names as keys.
+            Xscaler (MinMaxScalerJax): MinMaxScaler object fitted to the minimum and maximum of the training data parameters. Can be used to transform and inverse transform parameter points.
+            yscaler (dict[SVDDecomposer]): Dictionary of SVDDecomposer objects with the filter names as keys. The SVDDecomposer objects are fitted to the magnitude training data. Can be used to transform and inverse transform magnitudes in this filter.
+        """
+        #TODO: dealing with redshift at this step
+        Xscaler, yscaler = MinMaxScalerJax(), {filt.name: SVDDecomposer(svd_ncoeff) for filt in filters}
+        train_y = {}
+        val_y = {}
+
+        # preprocess the training data
+        with h5py.File(self.file, "r") as f:
+            train_X_raw = f["train"]["X"][:self.n_training]
+            train_X = Xscaler.fit_transform(train_X_raw) # fit the Xscaler and transform train_X_raw
+
+            for label in self.special_training:
+                special_train_X = Xscaler.transform(f["special_train"][label]["X"][:])
+                train_X = np.concatenate((train_X, special_train_X))
+
+            val_X_raw = f["val"]["X"][:self.n_val]
+            val_X = Xscaler.transform(val_X_raw)
+
+            train_y_raw = f["train"]["y"][:self.n_training, self.mask].reshape(-1, self.n_nus, self.n_times)
+            mJys_train = np.exp(train_y_raw)
+            val_y_raw = f["val"]["y"][:self.n_val, self.mask].reshape(-1, self.n_nus, self.n_times)
+            mJys_val = np.exp(val_y_raw)
+
+            for filt in filters:
+                mag = filt.get_mags(mJys_train, self.nus) # convert to magnitudes
+                train_data = yscaler[filt.name].fit_transform(mag)
+
+                # preprocess the special training data
+                for label in self.special_training:
+                    special_train_y = np.exp(f["special_train"][label]["y"][:, self.mask].reshape(-1, self.n_nus, self.n_times))
+                    special_mag = filt.get_mags(special_train_y, self.nus) # convert to magnitudes
+                    special_train_data = yscaler[filt.name].transform(special_mag)
+                    train_data = np.concatenate((train_data, special_train_data))
+
+                train_y[filt.name] = train_data
+
+                # preprocess validation data
+                mag = filt.get_mags(mJys_val, self.nus) # convert to magnitudes
+                val_data = yscaler[filt.name].transform(mag)
+                val_y[filt.name] = val_data
+
+        return train_X, train_y, val_X, val_y, Xscaler, yscaler
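The per-filter SVD path then produces dictionaries keyed by filter name (a sketch; the filter name follows the GHz/keV naming convention mentioned for SVDTrainer further down):

```python
# Hypothetical usage of the per-filter SVD preprocessing.
from fiesta.train.DataManager import DataManager
from fiesta.utils import Filter

dm = DataManager(file="toy_data.h5", n_training=128, n_val=32,
                 tmin=0.1, tmax=100.0)
filters = [Filter("radio-3GHz")]  # frequency encoded in the name
train_X, train_y, val_X, val_y, Xscaler, yscaler = dm.preprocess_svd(
    svd_ncoeff=10, filters=filters)
print(train_y["radio-3GHz"].shape)  # -> (128, 10)
```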
+ """ + #TODO: dealing with redshift at this step + Xscaler, yscaler = MinMaxScalerJax(), {filt.name: SVDDecomposer(svd_ncoeff) for filt in filters} + train_y = {} + val_y = {} + + # preprocess the training data + with h5py.File(self.file, "r") as f: + train_X_raw = f["train"]["X"][:self.n_training] + train_X = Xscaler.fit_transform(train_X_raw) # fit the Xscaler and transform the train_X_raw + + for label in self.special_training: + special_train_X = Xscaler.transform(f["special_train"[label]["X"][:]]) + train_X = np.concatenate((train_X, special_train_X)) + + val_X_raw = f["val"]["X"][:self.n_val] + val_X = Xscaler.transform(val_X_raw) + + train_y_raw = f["train"]["y"][:, self.mask].reshape(-1, self.n_nus, self.n_times) + mJys_train = np.exp(train_y_raw) + val_y_raw = f["val"]["y"][:self.n_val, self.mask].reshape(-1, self.n_nus, self.n_times) + mJys_val = np.exp(val_y_raw) + + for filt in filters: + mag = filt.get_mags(mJys_train, self.nus) # convert to magnitudes + train_data = yscaler[filt.name].fit_transform(mag) + + # preprocess the special training data + for label in self.special_training: + special_train_y = np.exp(f["special_train"][label]["y"][:, self.mask].reshape(-1, self.n_nus, self.n_times)) + special_mag = filt.get_mags(special_train_y, self.nus) # convert to magnitudes + special_train_data = yscaler[filt.name].transform(special_mag) + train_data = np.concatenate((train_data, special_train_data)) + + train_y[filt.name] = train_data + + # preprocess validation data + mag = filt.get_mags(mJys_val, self.nus) # convert to magnitudes + val_data = yscaler[filt.name].transform(mag) + val_y[filt.name] = val_data + + return train_X, train_y, val_X, val_y, Xscaler, yscaler + + def pass_meta_data(self, object) -> None: + """Pass training data meta data to another object. 
Used for the FluxTrainers.""" + object.parameter_names = self.parameter_names + object.times = self.times + object.nus = self.nus + object.parameter_distributions = self.parameter_distributions \ No newline at end of file diff --git a/src/fiesta/train/FluxTrainer.py b/src/fiesta/train/FluxTrainer.py index 5943307..90366cd 100644 --- a/src/fiesta/train/FluxTrainer.py +++ b/src/fiesta/train/FluxTrainer.py @@ -1,333 +1,262 @@ """Method to train the surrogate models""" +import fiesta.train.neuralnets as fiesta_nn +from fiesta.train.DataManager import DataManager + import os import numpy as np import jax -import jax.numpy as jnp from jaxtyping import Array, Float, Int -from fiesta.utils import MinMaxScalerJax, StandardScalerJax, PCAdecomposer -from fiesta import utils -from fiesta import conversions -from fiesta import models_utilities -import fiesta.train.neuralnets as fiesta_nn - import matplotlib.pyplot as plt import pickle -import h5py -from typing import Callable +################ +# TRAINING API # +################ class FluxTrainer: - """Abstract class for training a collection of surrogate""" + """Abstract class for training a surrogate model that predicts a spectral flux density array.""" name: str outdir: str parameter_names: list[str] - - preprocessing_metadata: dict[str, dict[str, float]] - - X_raw: Float[Array, "n_batch n_params"] - y_raw: dict[str, Float[Array, "n_batch n_times"]] - - X: Float[Array, "n_batch n_input_surrogate"] - y: dict[str, Float[Array, "n_batch n_output_surrogate"]] - - trained_states: dict[str, fiesta_nn.TrainState] + + train_X: Float[Array, "n_train"] + train_y: Float[Array, "n_train"] + val_X: Float[Array, "n_val"] + val_y: Float[Array, "n_val"] def __init__(self, name: str, outdir: str, - plots_dir: str = None, - ) -> None: + plots_dir: str = None, + save_preprocessed_data: bool = False) -> None: self.name = name - self.outdir = outdir # Check if directories exists, otherwise, create: + self.outdir = outdir if not os.path.exists(self.outdir): os.makedirs(self.outdir) - self.plots_dir = plots_dir - if not os.path.exists(self.plots_dir): + if self.plots_dir is not None and not os.path.exists(self.plots_dir): os.makedirs(self.plots_dir) - - + + self.save_preprocessed_data = save_preprocessed_data + # To be loaded by child classes self.parameter_names = None - - self.preprocessing_metadata = {} - self.train_X_raw = None - self.train_y_raw = None + self.train_X = None + self.train_y = None - self.val_X_raw = None - self.val_y_raw = None + self.val_X = None + self.val_y = None def __repr__(self) -> str: return f"FluxTrainer(name={self.name})" def preprocess(self): - - print("Preprocessing data by scaling to mean 0 and std 1. . .") - self.X_scaler = StandardScalerJax() - self.X = self.X_scaler.fit_transform(self.train_X_raw) - - self.y_scaler = StandardScalerJax() - self.y = self.y_scaler.fit_transform(self.train_y_raw) - - # Save the metadata - self.preprocessing_metadata["X_scaler"] = self.X_scaler - self.preprocessing_metadata["y_scaler"] = self.y_scaler - print("Preprocessing data . . . done") + raise NotImplementedError - def fit(self, + def fit(self, config: fiesta_nn.NeuralnetConfig = None, key: jax.random.PRNGKey = jax.random.PRNGKey(0), - verbose: bool = True): - """ - The config controls which architecture is built and therefore should not be specified here. - - Args: - config (nn.NeuralnetConfig, optional): _description_. Defaults to None. 
- """ - - # Get default choices if no config is given - if config is None: - config = fiesta_nn.NeuralnetConfig() - self.config = config - - input_ndim = len(self.parameter_names) - - - # Create neural network and initialize the state - net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) - key, subkey = jax.random.split(key) - state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) - - # Perform training loop - state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y, self.val_X, self.val_y, verbose=verbose) - # Plot and save the plot if so desired - if self.plots_dir is not None: - plt.figure(figsize=(10, 5)) - ls = "-o" - ms = 3 - plt.plot([i+1 for i in range(len(train_losses))], train_losses, ls, markersize=ms, label="Train", color="red") - plt.plot([i+1 for i in range(len(val_losses))], val_losses, ls, markersize=ms, label="Validation", color="blue") - plt.legend() - plt.xlabel("Epoch") - plt.ylabel("MSE loss") - plt.yscale('log') - plt.title("Learning curves") - plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{self.name}.png")) - plt.close() - - self.trained_state = state - - def save(self): + verbose: bool = True) -> None: + raise NotImplementedError + + def plot_learning_curve(self, train_losses, val_losses): + plt.figure(figsize=(10, 5)) + ls = "-o" + ms = 3 + plt.plot([i+1 for i in range(len(train_losses))], train_losses, ls, markersize=ms, label="Train", color="red") + plt.plot([i+1 for i in range(len(val_losses))], val_losses, ls, markersize=ms, label="Validation", color="blue") + plt.legend() + plt.xlabel("Epoch") + plt.ylabel("Loss") + plt.yscale('log') + plt.title("Learning curves") + plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{self.name}.png")) + plt.close() + + def save(self) -> None: """ - Save the trained model and all the used metadata to the outdir. + Save the trained model and all the metadata to the outdir. + The meta data is saved as a pickled dict to be read by fiesta.inference.lightcurve_model.SurrogateLightcurveModel. + The NN is saved as a pickled serialized dict using the NN.save_model method. """ # Save the metadata - if not os.path.exists(self.outdir): - os.makedirs(self.outdir) meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl") save = {} save["times"] = self.times save["nus"] = self.nus save["parameter_names"] = self.parameter_names - save.update(self.preprocessing_metadata) + save["parameter_distributions"] = self.parameter_distributions + save["X_scaler"] = self.X_scaler + save["y_scaler"] = self.y_scaler with open(meta_filename, "wb") as meta_file: pickle.dump(save, meta_file) # Save the NN - model = self.trained_state - fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{self.name}.pkl") + self.network.save_model(outfile = os.path.join(self.outdir, f"{self.name}.pkl")) - def _save_preprocessed_data(self): + def _save_preprocessed_data(self) -> None: print("Saving preprocessed data . . .") - np.savez(os.path.join(self.outdir, "afterglow_preprocessed_data.npz"), train_X=self.train_X, train_y= self.train_y, val_X = self.val_X, val_y = self.val_y) + np.savez(os.path.join(self.outdir, f"{self.name}_preprocessed_data.npz"), train_X=self.train_X, train_y= self.train_y, val_X = self.val_X, val_y = self.val_y) print("Saving preprocessed data . . . 
done") class PCATrainer(FluxTrainer): - + def __init__(self, name: str, outdir: str, - data_manager, + data_manager_args: dict, n_pca: Int = 100, plots_dir: str = None, - save_preprocessed_data: bool = False): + save_preprocessed_data: bool = False) -> None: + """ + FluxTrainer for training a feed-forward neural network on the PCA coefficients of the training data to predict the full 2D spectral flux density array. + Initializing will read the data and preprocess it with the DataManager class. It can then be fit with the fit() method. + To write the surrogate model to file, the save() method is to be used, which will create two pickle files (one for the metadata, one for the neural network). + + Args: + name (str): Name of the model to be trained. Will be used when saving metadata and model to file. + outdir (str): Directory where the NN and its metadata will be written to file. + data_manager_args (dict): Arguments for the DataManager class instance that will be used to read the data from the .h5 file in outdir and preprocess it. + n_pca (int): Number of PCA components that will be kept when performing data preprocessing. Defaults to 100. + plots_dir (str): Directory where the loss curves will be plotted. If None, the plot will not be created. Defaults to None. + save_preprocessed_data (bool): Whether the preprocessed (i.e. PCA decomposed) training and validation data will be written to file. Defaults to False. + """ super().__init__(name = name, - outdir = outdir, - plots_dir = plots_dir) + outdir = outdir, + plots_dir = plots_dir, + save_preprocessed_data = save_preprocessed_data) self.n_pca = n_pca - self.save_preprocessed_data = save_preprocessed_data - self.data_manager = data_manager - self.parameter_names = data_manager.parameter_names - self.times = data_manager.times - self.nus = data_manager.nus - - self.plots_dir = plots_dir - if self.plots_dir is not None and not os.path.exists(self.plots_dir): - os.makedirs(self.plots_dir) + + self.data_manager = DataManager(**data_manager_args) + self.data_manager.print_file_info() + self.data_manager.pass_meta_data(self) self.preprocess() - - if save_preprocessed_data: + if self.save_preprocessed_data: self._save_preprocessed_data() def preprocess(self): + """ + Preprocessing method to get the PCA coefficients of the standardized training data. + It assigns the attributes self.train_X, self.train_y, self.val_X, self.val_y that are passed to the fitting method. + """ print(f"Fitting PCA model with {self.n_pca} components to the provided data.") - self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_data_from_file(self.n_pca) + self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_pca(self.n_pca) + if np.any(np.isnan(self.train_y)) or np.any(np.isnan(self.val_y)): + raise ValueError(f"Data preprocessing introduced nans. Check raw data for nans of infs or vanishing variance in a specific entry.") print(f"PCA model accounts for a share {np.sum(self.y_scaler.explained_variance_ratio_)} of the total variance in the training data. This value is hopefully close to 1.") - self.preprocessing_metadata["X_scaler"] = self.X_scaler - self.preprocessing_metadata["y_scaler"] = self.y_scaler print("Preprocessing data . . . 
done") - - def load_parameter_names(self): - raise NotImplementedError - - def load_times(self): - raise NotImplementedError - - def load_raw_data(self): - raise NotImplementedError - -class DataManager: - def __init__(self, - file: str, - n_training: Int, - n_val: Int, - tmin: Float, - tmax: Float, - numin: Float = 1e9, - numax: Float = 2.5e18, - special_training: list = [], - ): + def fit(self, + config: fiesta_nn.NeuralnetConfig, + key: jax.random.PRNGKey = jax.random.PRNGKey(0), + verbose: bool = True): + """ + Method used to initialize a NN based on the architecture specified in config and then fit it based on the learning rate and epoch number specified in config. + The config controls which architecture is built through config.hidden_layers. - self.file = file - self.n_training = n_training - self.n_val = n_val - - self.tmin = tmin - self.tmax = tmax - self.numin = numin - self.numax = numax + Args: + config (fiesta.train.neuralnets.NeuralnetConfig): config that needs to specify at least the network output, hidden_layers, learning rate, and learning epochs. Its output_size must be equal to n_pca. + key (jax.random.PRNGKey, optional): jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0). + verbose (bool, optional): Whether the train and validation loss is printed to terminal in certain intervals. Defaults to True. + """ + + self.config = config + self.config.output_size = self.n_pca # the config.output_size has to be equal to the number of PCA components + input_ndim = len(self.parameter_names) - self.special_training = special_training - self.read_metadata_from_file() - self.set_up_domain_mask() + # Create neural network and initialize the state + self.network = fiesta_nn.MLP(config = config, input_ndim = input_ndim, key = key) + + # Perform training loop + state, train_losses, val_losses = self.network.train_loop(self.train_X, self.train_y, self.val_X, self.val_y, verbose=verbose) - def read_metadata_from_file(self,)->None: - with h5py.File(self.file, "r") as f: - self.times_data = f["times"][:] - self.nus_data = f["nus"][:] - self.parameter_names = f["parameter_names"][:].astype(str).tolist() - self.n_training_exists = f["train"]["X"].shape[0] - self.n_val_exists = f["val"]["X"].shape[0] - - def set_up_domain_mask(self,)->None: - """Trims the stored data down to the time and frequency range desired for training.""" + # Plot and save the plot if so desired + if self.plots_dir is not None: + self.plot_learning_curve(train_losses, val_losses) - if self.tminself.times_data.max(): - print(f"\nWarning: provided time range {self.tmin, self.tmax} is too wide for the data stored in file. Using range {max(self.times_data.min(), self.tmin), min(self.times_data.max(), self.tmax)} instead.\n") - time_mask = np.logical_and(self.times_data>=self.tmin, self.times_data<=self.tmax) - self.times = self.times_data[time_mask] - self.n_times = len(self.times) - if self.numinself.nus_data.max(): - print(f"\nWarning: provided frequency range {self.numin, self.numax} is too wide for the data stored in file. 
Using range {max(self.nus_data.min(), self.numin), min(self.nus_data.max(), self.numax)} instead.\n")
-        nu_mask = np.logical_and(self.nus_data>=self.numin, self.nus_data<=self.numax)
-        self.nus = self.nus_data[nu_mask]
-        self.n_nus = len(self.nus)

+class CVAETrainer(FluxTrainer):

-        mask = nu_mask[:, None] & time_mask
-        self.mask = mask.flatten()
-
-    def get_data_from_file(self,):
-        with h5py.File(self.file, "r") as f:
-            if self.n_training>self.n_training_exists:
-                raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.")
-            self.train_X_raw = f["train"]["X"][:self.n_training]
-            self.train_y_raw = f["train"]["y"][:self.n_training, self.mask]
+    def __init__(self,
+                 name: str,
+                 outdir,
+                 data_manager_args,
+                 image_size: tuple[Int],
+                 plots_dir: str = None,
+                 save_preprocessed_data=False)->None:
+        """
+        FluxTrainer for training a conditional variational autoencoder on the log fluxes of the training data to predict the full 2D spectral flux density array.
+        Initializing will read the data and preprocess it with the DataManager class. It can then be fit with the fit() method.
+        To write the surrogate model to file, the save() method is to be used, which will create two pickle files (one for the metadata, one for the neural network).
+
+        Args:
+            name (str): Name of the model to be trained. Will be used when saving metadata and model to file.
+            outdir (str): Directory where the NN and its metadata will be written to file.
+            data_manager_args (dict): Arguments for the DataManager class instance that will be used to read the data from the .h5 file in outdir and preprocess it.
+            image_size (tuple[Int]): Size the 2D flux array will be down-sampled to with jax.image.resize when performing data preprocessing.
+            plots_dir (str): Directory where the loss curves will be plotted. If None, the plot will not be created. Defaults to None.
+            save_preprocessed_data (bool): Whether the preprocessed (i.e. down-sampled and standardized) training and validation data will be written to file. Defaults to False.
+        """
+
+        super().__init__(name = name,
+                         outdir = outdir,
+                         plots_dir = plots_dir,
+                         save_preprocessed_data = save_preprocessed_data)
+
+        self.data_manager = DataManager(**data_manager_args)
+        self.data_manager.print_file_info()
+        self.data_manager.pass_meta_data(self)

-            for label in self.special_training:
-                self.train_X_raw = np.concatenate((self.train_X_raw, f["special_train"][label]["X"][:]))
-                self.train_y_raw = np.concatenate((self.train_y_raw, f["special_train"][label]["y"][:, self.mask]))
+        self.image_size = image_size

-            if self.n_val>self.n_val_exists:
-                raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to validate with {self.n_val} data points.")
-            self.val_X_raw = f["val"]["X"][:self.n_val]
-            self.val_y_raw = f["val"]["y"][:self.n_val, self.mask]
+        self.preprocess()
+        if self.save_preprocessed_data:
+            self._save_preprocessed_data()
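The CVAE trainer follows the same pattern as the PCA sketch above, with the config output sized to the flattened image (again a sketch under the same assumptions about NeuralnetConfig):

```python
# Hypothetical use of CVAETrainer, mirroring the PCATrainer sketch above.
import fiesta.train.neuralnets as fiesta_nn
from fiesta.train.FluxTrainer import CVAETrainer

data_manager_args = dict(file="toy_data.h5", n_training=128, n_val=32,
                         tmin=0.1, tmax=100.0)
trainer = CVAETrainer(name="toy_cvae", outdir="./model/",
                      data_manager_args=data_manager_args,
                      image_size=(32, 25))
config = fiesta_nn.NeuralnetConfig(output_size=32 * 25,   # product of image_size
                                   hidden_layers=[256, 128])
trainer.fit(config=config)
trainer.save()
```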
+ """ + print(f"Preprocessing data by resampling flux array to {self.image_size} and standardizing.") + self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_cVAE(self.image_size) + if np.any(np.isnan(self.train_y)) or np.any(np.isnan(self.val_y)): + raise ValueError(f"Data preprocessing introduced nans. Check raw data for nans of infs or vanishing variance in a specific entry.") + print("Preprocessing data . . . done") - def preprocess_data_from_file(self, n_components: int)->None: - Xscaler, yscaler = StandardScalerJax(), PCAdecomposer(n_components=n_components) - with h5py.File(self.file, "r") as f: - # preprocess the training data - if self.n_training>self.n_training_exists: - raise ValueError(f"Only {self.n_training_exists} entries in file, not enough to train with {self.n_training} data points.") - - train_X_raw = f["train"]["X"][:self.n_training] - for label in self.special_training: - train_X_raw = np.concatenate((train_X_raw, f["special_train"][label]["X"][:])) - train_X = Xscaler.fit_transform(train_X_raw) - - loaded = f["train"]["y"][:15_000, self.mask] - if np.any(np.isinf(loaded)): - raise ValueError(f"Found inftys in training data.") - yscaler.fit(loaded) # only load 15k cause otherwise the array might get too large - train_y = np.empty((self.n_training, n_components)) - n_loaded = 0 - for chunk in f["train"]["y"].iter_chunks(): - loaded = f["train"]["y"][chunk][:, self.mask] - if np.any(np.isinf(loaded)): - raise ValueError(f"Found inftys in training data.") - train_y[n_loaded:n_loaded+len(loaded)] = yscaler.transform(loaded) - n_loaded += len(loaded) - if n_loaded >= self.n_training: - break - for label in self.special_training: - special_train_y = yscaler.transform(f["special_train"][label]["y"][:, self.mask]) - train_y = np.concatenate((train_y, special_train_y)) + def fit(self, + config: fiesta_nn.NeuralnetConfig = None, + key: jax.random.PRNGKey = jax.random.PRNGKey(0), + verbose: bool = True) -> None: + """ + Method used to initialize the autoencoder based on the architecture specified in config and then fit it based on the learning rate and epoch number specified in config. + The config controls which architecture is built through config.hidden_layers. The encoder and decoder share the hidden_layers argument, though the layers for the decoder are implemented in reverse order. + + Args: + config (fiesta.train.neuralnets.NeuralnetConfig): config that needs to specify at least the network output, hidden_layers, learning rate, and learning epochs. Its output_size must be equal to the product of self.image_size. + key (jax.random.PRNGKey, optional): jax.random.PRNGKey used to initialize the parameters of the network. Defaults to jax.random.PRNGKey(0). + verbose (bool, optional): Whether the train and validation loss is printed to terminal in certain intervals. Defaults to True. + """ - # preprocess validation data - if self.n_val>self.n_val_exists: - raise ValueError(f"Only {self.n_val_exists} entries in file, not enough to train with {self.n_val} data points.") - val_X_raw = f["val"]["X"][:self.n_val] - val_X = Xscaler.transform(val_X_raw) - val_y_raw = f["val"]["y"][:self.n_val, self.mask] - val_y = yscaler.transform(val_y_raw) + self.config = config + config.output_size = int(np.prod(self.image_size)) # Output must be equal to the product of self.image_size. 
- return train_X, train_y, val_X, val_y, Xscaler, yscaler - - - def pass_data(self, object): - object.parameter_names = self.parameter_names - object.train_X_raw = self.train_X_raw - object.train_y_raw = self.train_y_raw - object.val_X_raw = self.val_X_raw - object.val_y_raw = self.val_y_raw - object.times = self.times - object.nus = self.nus - + self.network = fiesta_nn.CVAE(config = self.config, conditional_dim = len(self.parameter_names), key = key) + state, train_losses, val_losses = self.network.train_loop(self.train_X, self.train_y, self.val_X, self.val_y, verbose = verbose) - def print_file_info(self,): - with h5py.File(self.file, "r") as f: - print(f"Times: {f['times'][0]} {f['times'][-1]}") - print(f"Nus: {f['nus'][0]} {f['nus'][-1]}") - print(f"Parameter distributions: {f['parameter_distributions'][()].decode('utf-8')}") - print("\n") - print(f"Training data: {self.n_training_exists}") - print(f"Validation data: {self.n_val_exists}") - print(f"Test data: {f['test']['X'].shape[0]}") - print("Special data:") - for key in f['special_train'].keys(): - print(f"\t {key}: {f['special_train'][key]['X'].shape[0]} description: {f['special_train'][key].attrs['comment']}") - print("\n \n") \ No newline at end of file + # Plot and save the plot if so desired + if self.plots_dir is not None: + self.plot_learning_curve(train_losses, val_losses) diff --git a/src/fiesta/train/LightcurveTrainer.py b/src/fiesta/train/LightcurveTrainer.py index e0f486a..5da0dbf 100644 --- a/src/fiesta/train/LightcurveTrainer.py +++ b/src/fiesta/train/LightcurveTrainer.py @@ -4,83 +4,62 @@ import numpy as np import jax -import jax.numpy as jnp from jaxtyping import Array, Float, Int +from typing import Dict from fiesta.utils import MinMaxScalerJax -from fiesta import utils from fiesta.utils import Filter -from fiesta import conversions -from fiesta.constants import days_to_seconds, c -from fiesta import models_utilities +from fiesta.train.DataManager import DataManager import fiesta.train.neuralnets as fiesta_nn import matplotlib.pyplot as plt import pickle -from typing import Callable -import tqdm -import afterglowpy as grb +################ +# TRAINING API # +################ -class SurrogateTrainer: +class LightcurveTrainer: """Abstract class for training a collection of surrogate models per filter""" - + name: str outdir: str filters: list[Filter] parameter_names: list[str] - - preprocessing_metadata: dict[str, dict[str, float]] - - # TODO: why do we have so many datasets? 
- X: Float[Array, "n_batch n_input_surrogate"] - y: dict[str, Float[Array, "n_batch n_output_surrogate"]] - - X_raw: Float[Array, "n_batch n_params"] - y_raw: dict[str, Float[Array, "n_batch n_times"]] - - train_X: Float[Array, "n_batch n_params"] - train_y: dict[str, Float[Array, "n_batch n_times"]] - - val_X: Float[Array, "n_batch n_params"] - val_y: dict[str, Float[Array, "n_batch n_times"]] - - train_X_raw: Float[Array, "n_batch n_params"] - train_y_raw: dict[str, Float[Array, "n_batch n_times"]] - - val_X_raw: Float[Array, "n_batch n_params"] - val_y_raw: dict[str, Float[Array, "n_batch n_times"]] - - trained_states: dict[str, fiesta_nn.TrainState] + + train_X: Float[Array, "n_train"] + train_y: Dict[str, Float[Array, "n"]] + val_X: Float[Array, "n_val"] + val_y: Dict[str, Float[Array, "n"]] def __init__(self, name: str, outdir: str, - save_raw_data: bool = False, + plots_dir: str = None, save_preprocessed_data: bool = False) -> None: self.name = name - self.outdir = outdir # Check if directories exists, otherwise, create: + self.outdir = outdir if not os.path.exists(self.outdir): os.makedirs(self.outdir) - - self.save_raw_data = save_raw_data + self.plots_dir = plots_dir + if self.plots_dir is not None and not os.path.exists(self.plots_dir): + os.makedirs(self.plots_dir) + self.save_preprocessed_data = save_preprocessed_data - + # To be loaded by child classes self.filters = None self.parameter_names = None - self.preprocessing_metadata = {} - - self.X_raw = None - self.y_raw = None - self.X = None - self.y = None - self.weights = None + self.train_X = None + self.train_y = None + + self.val_X = None + self.val_y = None def __repr__(self) -> str: - return f"SurrogateTrainer(name={self.name})" + return f"LightcurveTrainer(name={self.name})" def preprocess(self): @@ -88,51 +67,40 @@ def preprocess(self): self.X_scaler = MinMaxScalerJax() self.X = self.X_scaler.fit_transform(self.X_raw) - self.y_scalers: dict[str, MinMaxScalerJax] = {} + self.y_scaler: dict[str, MinMaxScalerJax] = {} self.y = {} for filt in self.filters: y_scaler = MinMaxScalerJax() self.y[filt.name] = y_scaler.fit_transform(self.y_raw[filt.name]) - self.y_scalers[filt.name] = y_scaler - - # Save the metadata - self.preprocessing_metadata["X_scaler_min"] = self.X_scaler.min_val - self.preprocessing_metadata["X_scaler_max"] = self.X_scaler.max_val - self.preprocessing_metadata["y_scaler_min"] = {filt.name: self.y_scalers[filt.name].min_val for filt in self.filters} - self.preprocessing_metadata["y_scaler_max"] = {filt.name: self.y_scalers[filt.name].max_val for filt in self.filters} + self.y_scaler[filt.name] = y_scaler print("Preprocessing data . . . done") def fit(self, - config: fiesta_nn.NeuralnetConfig = None, + config: fiesta_nn.NeuralnetConfig, key: jax.random.PRNGKey = jax.random.PRNGKey(0), - verbose: bool = True): + verbose: bool = True) -> None: """ The config controls which architecture is built and therefore should not be specified here. Args: config (nn.NeuralnetConfig, optional): _description_. Defaults to None. """ - - # Get default choices if no config is given - if config is None: - config = fiesta_nn.NeuralnetConfig() - self.config = config - - trained_states = {} + self.config = config + self.models = {} input_ndim = len(self.parameter_names) + for filt in self.filters: print(f"\n\n Training {filt.name}... 
\n\n") # Create neural network and initialize the state - net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) - key, subkey = jax.random.split(key) - state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) - - # Perform training loop - state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) + net = fiesta_nn.MLP(config = config, input_ndim = input_ndim, key = key) + # Perform training loop + state, train_losses, val_losses = net.train_loop(self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) + self.models[filt.name] = net + # Plot and save the plot if so desired if self.plots_dir is not None: plt.figure(figsize=(10, 5)) @@ -147,576 +115,89 @@ def fit(self, plt.title("Learning curves") plt.savefig(os.path.join(self.plots_dir, f"learning_curves_{filt.name}.png")) plt.close() - - trained_states[filt.name] = state - - self.trained_states = trained_states def save(self): """ Save the trained model and all the used metadata to the outdir. """ - if not os.path.exists(self.outdir): - os.makedirs(self.outdir) - + # Save the metadata meta_filename = os.path.join(self.outdir, f"{self.name}_metadata.pkl") - # FIXME: this should not be in this class - if os.path.exists(meta_filename): - with open(meta_filename, "rb") as meta_file: - save = pickle.load(meta_file) - if not np.array_equal(save["times"], self.times): # check whether the metadata from previously trained filters agrees - raise Exception(f"The time array needs to coincide with the time array for previous filters: {save['times']}") - if not np.array_equal(save["parameter_names"], self.parameter_names): - raise Exception(f"The parameters need to coincide with the parameters for previous filters: {save['parameter_names']}") - else: - save = {} + save = {} save["times"] = self.times save["parameter_names"] = self.parameter_names - # TODO: see if we can save the jet_type here somewhat more self-consistently + save["parameter_distributions"] = self.parameter_distributions + save["X_scaler"] = self.X_scaler + save["y_scaler"] = self.y_scaler - for filt in self.filters: - model = self.trained_states[filt.name] - fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl") - save[filt.name] = self.preprocessing_metadata[filt.name] - with open(meta_filename, "wb") as meta_file: pickle.dump(save, meta_file) - def _save_raw_data(self): - print("Saving raw data . . .") - np.savez(os.path.join(self.outdir, "raw_data_training.npz"), X_raw=self.train_X_raw, **self.train_y_raw) - np.savez(os.path.join(self.outdir, "raw_data_validation.npz"), X_raw=self.val_X_raw, **self.val_y_raw) - print("Saving raw data . . . done") - - def _save_preprocessed_data(self): + # Save the NN + for filt in self.filters: + model = self.models[filt.name] + model.save_model(outfile = os.path.join(self.outdir, f"{self.name}_{filt.name}.pkl")) + + def _save_preprocessed_data(self) -> None: print("Saving preprocessed data . . .") - np.savez(os.path.join(self.outdir, "preprocessed_data_training.npz"), X=self.train_X, **self.train_y) - np.savez(os.path.join(self.outdir, "preprocessed_data_validation.npz"), X=self.val_X, **self.val_y) + np.savez(os.path.join(self.outdir, f"{self.name}_preprocessed_data.npz"), train_X=self.train_X, train_y = self.train_y, val_X = self.val_X, val_y = self.val_y) print("Saving preprocessed data . . . 
done") -class SVDSurrogateTrainer(SurrogateTrainer): - - outdir: str - svd_ncoeff: Int - tmin: Float - tmax: Float - dt: Float - times: Float[Array, "n_times"] - plots_dir: str - save_raw_data: bool - save_preprocessed_data: bool +class SVDTrainer(LightcurveTrainer): def __init__(self, name: str, outdir: str, - filters: list[str] = None, - svd_ncoeff: Int = 10, - validation_fraction: Float = 0.2, - tmin: Float = None, - tmax: Float = None, - dt: Float = None, + filters: list[str], + data_manager_args: dict, + svd_ncoeff: Int = 50, plots_dir: str = None, - save_raw_data: bool = False, - save_preprocessed_data: bool = False - ): - + save_preprocessed_data: bool = False) -> None: """ - Initialize the surrogate model trainer that uses an SVD. The initialization also takes care of reading data and preprocessing it, but does not automatically fit the model. Users may want to inspect the data before fitting the model. - - Note: currently, only models of Bulla type .dat files are supported + Initialize the surrogate model trainer that decomposes the training data into its SVD coefficients. The initialization also takes care of reading data and preprocessing it, but does not automatically fit the model. Users may want to inspect the data before fitting the model. Args: name (str): Name of the surrogate model. Will be used - lc_dir (list[str]): Directory where all the raw light curve files, to be read and processed into a surrogate model. - outdir (str): Directory where the trained surrogate model has to be saved. - filters (list[str], optional): List of all the filters used in the light curve files and for which surrogate has to be trained. If None, all the filters will be used. Defaults to None. - svd_ncoeff: int : Number of SVD coefficients to use in data reduction during training. Defaults to 10. - validation_fraction (Float, optional): Fraction of the data to be used for validation. Defaults to 0.2. - tmin (Float, optional): Minimum time in days of the light curve, all data before is discarded. Defaults to 0.05. - tmax (Float, optional): Maximum time in days of the light curve, all data after is discarded. Defaults to 14.0. - dt (Float, optional): Time step in the light curve. Defaults to 0.1. + outdir (str): Directory where the trained surrogate model is to be saved. + filters (list[str]): List of the filters for which the surrogate has to be trained. These have to be either bandpasses from sncosmo or specifiy the frequency through endign with GHz or keV. + data_manager_args (dict): data_manager_args (dict): Arguments for the DataManager class instance that will be used to read the data from the .h5 file in outdir and preprocess it. + svd_ncoeff (int, optional) : Number of SVD coefficients to use in data reduction during training. Defaults to 50. plots_dir (str, optional): Directory where the plots of the training process will be saved. Defaults to None, which means no plots will be generated. - save_raw_data (bool, optional): If True, the raw data will be saved in the outdir. Defaults to False. - save_preprocessed_data: If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False. + save_preprocessed_data (bool, optional): If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False. 
""" - - super().__init__(name=name, outdir=outdir) - self.plots_dir = plots_dir - - if self.plots_dir is not None and not os.path.exists(self.plots_dir): - os.makedirs(self.plots_dir) - self.validation_fraction = validation_fraction - + super().__init__(name = name, + outdir = outdir, + plots_dir = plots_dir, + save_preprocessed_data = save_preprocessed_data) + self.svd_ncoeff = svd_ncoeff - self.tmin = tmin - self.tmax = tmax - self.dt = dt - self.save_raw_data = save_raw_data - self.save_preprocessed_data = save_preprocessed_data + self.data_manager = DataManager(**data_manager_args) + self.data_manager.print_file_info() + self.data_manager.pass_meta_data(self) self.load_filters(filters) - self.load_times() - self.load_parameter_names() - - self.load_raw_data() self.preprocess() - - if save_preprocessed_data: + if self.save_preprocessed_data: self._save_preprocessed_data() - if save_raw_data: - self._save_raw_data() - - def load_parameter_names(self): - raise NotImplementedError - - def load_times(self): - raise NotImplementedError - - def load_filters(self, filters: list[str] = None): - raise NotImplementedError - - def load_raw_data(self): - raise NotImplementedError - - def preprocess(self): - """ - Preprocess the data. This includes scaling the inputs and outputs, performing SVD decomposition, and saving the necessary metadata for later use. - """ - # Scale inputs - X_scaler = MinMaxScalerJax() - self.train_X = X_scaler.fit_transform(self.train_X_raw) # fit the scaler to the training data - self.val_X = X_scaler.transform(self.val_X_raw) # transform the val data - - # Scale outputs, do SVD and save into y - self.train_y = {filt.name: [] for filt in self.filters} - self.val_y = {filt.name: [] for filt in self.filters} - - print(f"Rescaling the training and validation data for filters {[filter.name for filter in self.filters]}") - for filt in tqdm.tqdm(self.filters): - - y_scaler = MinMaxScalerJax() - - completely_problematic = np.where(np.all(np.isinf(self.train_y_raw[filt.name]), axis = 1))[0] - problematic = np.unique(np.where(np.isinf(self.train_y_raw[filt.name]))[0]) - - if len(problematic)!=0: - raise Exception(f"There were infs in the magnitudes for filter {filt.name}.") - - data = y_scaler.fit_transform(self.train_y_raw[filt.name]) - - # Do SVD decomposition on the training data - UA, _, VA = np.linalg.svd(data, full_matrices=True) - VA = VA.T - - n, n = UA.shape - m, m = VA.shape - # This is taken over from NMMA - cAmat = np.zeros((self.svd_ncoeff, n)) - cAvar = np.zeros((self.svd_ncoeff, n)) - for i in range(n): - ErrorLevel = 1.0 - cAmat[:, i] = np.dot( - data[i, :], VA[:, : self.svd_ncoeff] - ) - errors = ErrorLevel * np.ones_like(data[i, :]) - cAvar[:, i] = np.diag( - np.dot( - VA[:, : self.svd_ncoeff].T, - np.dot(np.diag(np.power(errors, 2.0)), VA[:, : self.svd_ncoeff]), - ) - ) - - self.train_y[filt.name] = cAmat.T # Transpose to get the shape (n_batch, n_svd_coeff) - - # Do SVD decomposition on the validation data - val_data = y_scaler.transform(self.val_y_raw[filt.name]) - cAmat = np.zeros((self.svd_ncoeff, self.n_val_data)) - for i in range(self.n_val_data): - cAmat[:,i] = np.dot( - val_data[i,:], VA[:, : self.svd_ncoeff] - ) - - self.val_y[filt.name] = cAmat.T # Transpose to get the shape (n_val, n_svd_coeff) - - #Save the scalers - self.preprocessing_metadata[filt.name] = {"VA": VA, "X_scaler_max": X_scaler.max_val, "X_scaler_min": X_scaler.min_val, "y_scaler": y_scaler, "svd_ncoeff": self.svd_ncoeff} - - - def __repr__(self) -> str: - return 
f"SVDSurrogateTrainer(name={self.name}, lc_dir={self.lc_dir}, outdir={self.outdir}, filters={self.filters})" - - -class BullaSurrogateTrainer(SVDSurrogateTrainer): - - _times_grid: Float[Array, "n_times"] - extract_parameters_function: Callable - data_dir: str - - # Check if supported - def __init__(self, - name: str, - outdir: str, - filters: list[str] = None, - data_dir: list[str] = None, - svd_ncoeff: Int = 10, - validation_fraction: Float = 0.2, - tmin: Float = None, - tmax: Float = None, - dt: Float = None, - plots_dir: str = None, - save_raw_data: bool = False, - save_preprocessed_data: bool = False): - - # Check if this version of Bulla is supported - supported_models = list(models_utilities.SUPPORTED_BULLA_MODELS) - if name not in supported_models: - raise ValueError(f"Bulla model version {name} is not supported yet. Supported models are: {supported_models}") - - # Get the function to extract parameters - self.extract_parameters_function = models_utilities.EXTRACT_PARAMETERS_FUNCTIONS[name] - self.data_dir=data_dir - - super().__init__(name=name, - outdir=outdir, - filters=filters, - svd_ncoeff=svd_ncoeff, - validation_fraction=validation_fraction, - tmin=tmin, - tmax=tmax, - dt=dt, - plots_dir=plots_dir, - save_raw_data=save_raw_data, - save_preprocessed_data=save_preprocessed_data) - - - def load_times(self): - """ - Fetch the time grid from the Bulla .dat files or create from given input - """ - self._times_grid = utils.get_times_bulla_file(self.lc_files[0]) - if self.tmin is None or self.tmax is None or self.dt is None: - print("No time range given, using grid times") - self.times = self._times_grid - self.tmin = self.times[0] - self.tmax = self.times[-1] - self.dt = self.times[1] - self.times[0] - else: - self.times = np.arange(self.tmin, self.tmax + self.dt, self.dt) - - def load_parameter_names(self): - self.parameter_names = models_utilities.BULLA_PARAMETER_NAMES[self.name] - - def load_filters(self, filters: list[str] = None): - """ - If no filters are given, we will read the filters from the first Bulla lightcurve file and assume all files have the same filters - - Args: - filters (list[str], optional): List of filters to be used in the training. Defaults to None. - """ - filenames: list[str] = os.listdir(self.data_dir) - self.lc_files = [os.path.join(self.data_dir, f) for f in filenames if f.endswith(".dat")] - if filters is None: - filters = utils.get_filters_bulla_file(self.lc_files[0], drop_times=True) - self.filters = [] - - # Create Filters objects for each filter - for filter in filters: - self.filters.append(Filter(filter)) - - def _read_files(self) -> tuple[dict[str, Float[Array, " n_batch n_params"]], Float[Array, "n_batch n_times"]]: - """ - Read the photometry files and interpolate the NaNs. - Output will be an array of shape (n_filters, n_batch, n_times) - - Args: - lc_files (list[str]): List of all the raw light curve files, to be read and processed into a surrogate model. - - Returns: - tuple[dict[str, Float[Array, " n_batch n_times"]], Float[Array, "n_batch n_params"]]: First return value is an array of all the parameter values extracted from the files. Second return value is a dictionary containing the filters and corresponding light curve data which has shape (n_batch, n_times). 
- """ - - # Fetch the result for each filter and add it to already existing dataset - data = {filt: [] for filt in self.filters} - for i, filename in enumerate(tqdm.tqdm(self.lc_files)): - # Get a dictionary with keys being the filters and values being the light curve data - lc_data = utils.read_single_bulla_file(filename) - for filt in self.filters: - # TODO: improve this cumbersome thing - this_data = lc_data[filt.name] - if i == 0: - data[filt.name] = this_data - else: - data[filt.name] = np.vstack((data[filt.name], this_data)) - - # Fetch the parameter values of this file - params = self.extract_parameters_function(filename) - # TODO: improve this cumbersome thing - if i == 0: - parameter_values = params - else: - parameter_values = np.vstack((parameter_values, params)) - - return parameter_values, data - - def load_raw_data(self): - print("Reading data files and interpolating NaNs . . .") - X_raw, y = self._read_files() - y_raw = utils.interpolate_nans(y, self._times_grid, self.times) - if self.save_raw_data: - np.savez(os.path.join(self.outdir, "raw_data.npz"), X_raw=X_raw, times=self.times, times_grid=self._times_grid, **y_raw) - - # split here into training and validating data - self.n_val_data = int(self.validation_fraction*len(X_raw)) - self.n_training_data = len(X_raw) - self.n_val_data - mask = np.zeros(len(X_raw) ,dtype = bool) - mask[np.random.choice(len(X_raw), self.n_val_data, replace = False)] = True - - self.train_X_raw, self.val_X_raw = X_raw[~mask], X_raw[mask] - self.train_y_raw, self.val_y_raw = {}, {} - - print("self.filters") - print(self.filters) - - for filt in self.filters: - self.train_y_raw[filt.name] = y_raw[filt.name][~mask] - self.val_y_raw[filt.name] = y_raw[filt.name][mask] - - -# TODO: perhaps rename to *_1D, since it is only for 1D light curves, and we might want to get support for 2D by incorporating the frequencies... Unsure about the approach here - -class AfterglowpyTrainer(SVDSurrogateTrainer): - - parameter_grid: dict[str, list[Float]] - n_training_data: Int - fixed_parameters: dict[str, Float] - jet_type: Int - use_log_spacing: bool - _times_afterglowpy: Float[Array, "n_times"] - nus: dict[str, Float] - - def __init__(self, - name: str, - outdir: str, - filters: list[str], - parameter_grid: dict[str, list[float]], - weight_grids: dict = None, - n_training_data: Int = 5000, - jet_type: Int = -1, - fixed_parameters: dict[str, Float] = {}, - tmin: Float = 0.1, - tmax: Float = 1000, - n_times: Int = 100, - use_log_spacing: bool = True, - validation_fraction: float = 0.2, - plots_dir: str = None, - svd_ncoeff: Int = 10, - save_raw_data: bool = False, - save_preprocessed_data: bool = False, - remake_training_data = False, - ): - """ - Initialize the surrogate model trainer. The initialization also takes care of reading data and preprocessing it, but does not automatically fit the model. Users may want to inspect the data before fitting the model. - - Args: - name (str): Name given to the model - outdir (str): Output directory to save the trained model - parameter_grid (dict[str, list[Float]]): Dictionary containing the grid points for each parameter, i.e., the parameter values on which the surrogate will be trained. The keys should be the parameter names and the values should be a list.. 
- jet_type (Int): Type of jet for the afterglowpy, -1 is tophat, 0 is Gaussian, 4 is PowerLaw - fixed_parameters (dict[str, Float]) : values of the afterglowpy parameters that should be kept fixed for the surrogate model - tmin (Float, optional): Minimum time in days of the light curve, all data before is discarded. Defaults to 0.1. - tmax (Float, optional): Maximum time in days of the light curve, all data after is discarded. Defaults to 1000. - n_times: number of time nodes for the training light curve data - use_log_spacing: bool : whether the time nodes of the training light curve data should be log10 spaced - validation_fraction (Float, optional): Fraction of the data to be used for validation. Defaults to 0.2. - plots_dir : str : outdir for the plots - svd_ncoeff: int : Number of SVD coefficients to use in data reduction during training. Defaults to 10. - save_raw_data (bool, optional): If True, the raw data will be saved in the outdir. Defaults to False. - save_preprocessed_data: If True, the preprocessed data (reduced, rescaled) will be saved in the outdir. Defaults to False. - """ - - - self.n_times = n_times - dt = (tmax - tmin) / n_times - self.parameter_grid = parameter_grid - self.weight_grids = weight_grids - if self.weight_grids is None: - self.weight_grids = {p: np.full_like(self.parameter_grid[p], 1/len(self.parameter_grid[p])) for p in self.parameter_grid.keys()} - - self.fixed_parameters = fixed_parameters - self.use_log_spacing = use_log_spacing - - # Check jet type before saving - supported_jet_types = [-1, 0, 4] - if jet_type not in supported_jet_types: - raise ValueError(f"Jet type {jet_type} is not supported. Supported jet types are: {supported_jet_types}") - self.jet_type = jet_type - self.remake_training_data = remake_training_data - - self.n_training_data = n_training_data - self.validation_fraction = validation_fraction - self.n_val_data = int(self.n_training_data * self.validation_fraction/(1-self.validation_fraction)) - - super().__init__(name=name, - outdir=outdir, - filters=filters, - svd_ncoeff=svd_ncoeff, - validation_fraction=validation_fraction, - tmin=tmin, - tmax=tmax, - dt=dt, - plots_dir=plots_dir, - save_raw_data=save_raw_data, - save_preprocessed_data=save_preprocessed_data) - - - def load_filters(self, filters: list[str]): + def load_filters(self, filters): self.filters = [] - for filter in filters: - try: - self.filters.append(Filter(filter)) - except: - raise Exception(f"Filter {filter} not available.") - - def load_times(self): - if self.use_log_spacing: - times = np.logspace(np.log10(self.tmin), np.log10(self.tmax), num=self.n_times) - else: - times = np.linspace(self.tmin, self.tmax, num=self.n_times) - self.times = times - self._times_afterglowpy = self.times * days_to_seconds # afterglowpy takes seconds as input - - def load_parameter_names(self): - self.parameter_names = list(self.parameter_grid.keys()) - - def load_raw_data(self): - data_files_exist = os.path.exists(self.outdir+"/raw_data_training.npz") and os.path.exists(self.outdir+"/raw_data_validation.npz") - if data_files_exist and not self.remake_training_data: - self.train_X_raw, self.train_y_raw, self.val_X_raw, self.val_y_raw = self._read_files() - else: - self.create_raw_data() - - def create_raw_data(self): - """ - Create a grid of training data with specified settings and generate the output files for them. - - TODO: for now we train per filter, but best to change this! 
- """ - # Create training data - nus = jnp.array([filt.nu for filt in self.filters]) - X_raw = np.empty((self.n_training_data, len(self.parameter_names))) - y_raw = {filt.name: np.empty((self.n_training_data, len(self.times))) for filt in self.filters} - - for j, key in enumerate(self.parameter_grid.keys()): - X_raw[:,j] = np.random.choice(self.parameter_grid[key], size = self.n_training_data, replace = True, p = self.weight_grids[key]) - - - print(f"Creating the afterglowpy training dataset on grid with {self.n_training_data} points.") - for i in tqdm.tqdm(range(self.n_training_data)): - param_dict = dict(zip(self.parameter_names, X_raw[i])) - param_dict.update(self.fixed_parameters) - mJys = self._call_afterglowpy(param_dict, nus) - for k, filt in enumerate(self.filters): - y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys[k]) - - self.train_X_raw = X_raw - self.train_y_raw = y_raw - - - # Create validation data - X_raw = np.empty((self.n_val_data, len(self.parameter_names))) - y_raw = {filt.name: np.empty((self.n_val_data, len(self.times))) for filt in self.filters} - - print(f"Creating the afterglowpy validation dataset on {self.n_val_data} random points within grid.") - for i in tqdm.tqdm(range(self.n_val_data)): - X_raw[i] = [np.random.uniform(self.parameter_grid[p][0], self.parameter_grid[p][-1]) for p in self.parameter_names] - param_dict = dict(zip(self.parameter_names, X_raw[i])) - param_dict.update(self.fixed_parameters) - mJys = self._call_afterglowpy(param_dict, nus) - - for k, filt in enumerate(self.filters): - y_raw[filt.name][i] = conversions.mJys_to_mag_np(mJys[k]) - - self.val_X_raw = X_raw - self.val_y_raw = y_raw - - - - def _read_files(self,): - raw_data_train = np.load(self.outdir+"/raw_data_training.npz") - raw_data_validation = np.load(self.outdir+'/raw_data_validation.npz') - - training_y_raw = {} - val_y_raw = {} - - for filt in self.filters: - training_y_raw[filt.name] = raw_data_train[filt.name] - val_y_raw[filt.name] = raw_data_validation[filt.name] - - self.n_training_data = len(raw_data_train["X_raw"]) - self.n_val_data = len(raw_data_validation["X_raw"]) - return raw_data_train["X_raw"], training_y_raw, raw_data_validation["X_raw"], val_y_raw - - def _set_weights(self,): # TODO: dev legacy - weights = {filt.name: jnp.ones(self.n_training_data) for filt in self.filters} - - - for j, p in enumerate(self.parameter_names): - bins = (self.parameter_grid[p][:-1] + self.parameter_grid[p][1:])/2 - bins = [self.parameter_grid[p][0] ,*bins, self.parameter_grid[p][-1]+0.1] # the right boundary of the bins needs to be a little larger because of digitize - indices = np.digitize(self.train_X_raw[:,j], bins) - 1 - for filt in self.filters: - weights[filt.name] *= self.weight_grids[filt.name][p][indices] - - self.weights = {filt.name: weights[filt.name]/np.sum(weights[filt.name]) for filt in self.filters} - + for filt in filters: + Filt = Filter(filt) + if Filt.nus[0] < self.nus[0] or Filt.nus[-1] > self.nus[-1]: + raise ValueError(f"Filter {filt} exceeds the frequency range of the training data.") + self.filters.append(Filt) - - def _call_afterglowpy(self, - params_dict: dict[str, Float], - nus) -> Float[Array, "n_times"]: + def preprocess(self): """ - Call afterglowpy to generate a single flux density output, for a given set of parameters. Note that the parameters_dict should contain all the parameters that the model requires, as well as the nu value. - The output will be a set of mJys. - - Args: - Float[Array, "n_times"]: The flux density in mJys at the given times. 
+ Preprocessing method to get the SVD coefficients of the training and validation data. This includes scaling the inputs and outputs, as well as performing SVD decomposition. """ - - # Preprocess the params_dict into the format that afterglowpy expects, which is usually called Z - Z = {} - - Z["jetType"] = params_dict.get("jetType", self.jet_type) - Z["specType"] = params_dict.get("specType", 0) - Z["z"] = params_dict.get("z", 0.0) - Z["xi_N"] = params_dict.get("xi_N", 1.0) - - Z["E0"] = 10 ** params_dict["log10_E0"] - Z["n0"] = 10 ** params_dict["log10_n0"] - Z["p"] = params_dict["p"] - Z["epsilon_e"] = 10 ** params_dict["log10_epsilon_e"] - Z["epsilon_B"] = 10 ** params_dict["log10_epsilon_B"] - Z["d_L"] = 3.086e19 # fix at 10 pc, so that AB magnitude equals absolute magnitude - if "inclination_EM" in list(params_dict.keys()): - Z["thetaObs"] = params_dict["inclination_EM"] - else: - Z["thetaObs"] = params_dict["thetaObs"] - - if self.jet_type == -1: - Z["thetaCore"] = params_dict["thetaCore"] - - if "thetaWing" in list(params_dict.keys()): #for Gaussian and power law jets - Z["thetaWing"] = params_dict["thetaWing"] - Z["thetaCore"] = params_dict["xCore"]*params_dict["thetaWing"] - - if self.jet_type == 4: - Z["b"] = params_dict["b"] - - # Afterglowpy returns flux in mJys - tt, nunu = np.meshgrid(self._times_afterglowpy, nus) - mJys = grb.fluxDensity(tt, nunu, **Z) - return mJys \ No newline at end of file + print("Decomposing training data to SVD coefficients.") + self.train_X, self.train_y, self.val_X, self.val_y, self.X_scaler, self.y_scaler = self.data_manager.preprocess_svd(self.svd_ncoeff, self.filters) + for key in self.train_y.keys(): + if np.any(np.isnan(self.train_y[key])) or np.any(np.isnan(self.val_y[key])): + raise ValueError(f"Data preprocessing for {key} introduced NaNs. Check the raw data for NaNs or infs, or for vanishing variance in a specific entry.") + print("Preprocessing data . . . done") diff --git a/src/fiesta/train/__init__.py b/src/fiesta/train/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/src/fiesta/train/neuralnets.py b/src/fiesta/train/neuralnets.py index 177b04d..f649569 100644 --- a/src/fiesta/train/neuralnets.py +++ b/src/fiesta/train/neuralnets.py @@ -1,4 +1,3 @@ -from typing import Sequence, Callable import time import jax @@ -12,6 +11,8 @@ import optax import pickle +import fiesta.train.nn_architectures as nn + ############### ### CONFIGS ### ############### @@ -20,6 +21,7 @@ class NeuralnetConfig(ConfigDict): """Configuration for a neural network model. For type hinting""" name: str output_size: Int + hidden_layer_sizes: list[int] layer_sizes: list[int] learning_rate: Float batch_size: Int @@ -28,7 +30,7 @@ class NeuralnetConfig(ConfigDict): def __init__(self, name: str = "MLP", - output_size: Int = 10, + output_size: int = 10, hidden_layer_sizes: list[int] = [64, 128, 64], learning_rate: Float = 1e-3, batch_size: int = 128, @@ -38,165 +40,36 @@ def __init__(self, super().__init__() self.name = name self.output_size = output_size - hidden_layer_sizes.append(self.output_size) - self.layer_sizes = hidden_layer_sizes + self.hidden_layer_sizes = hidden_layer_sizes + self.layer_sizes = [*hidden_layer_sizes, output_size] self.learning_rate = learning_rate self.batch_size = batch_size self.nb_epochs = nb_epochs if nb_report is None: nb_report = self.nb_epochs // 10 self.nb_report = nb_report - -##################### -### ARCHITECTURES ### -##################### - -class BaseNeuralnet(nn.Module): - """Abstract base class.
Needs layer sizes and activation function used""" - layer_sizes: Sequence[int] - act_func: Callable = nn.relu - - def setup(self): - raise NotImplementedError - - def __call__(self, x): - raise NotImplementedError - -class MLP(BaseNeuralnet): - """Basic multi-layer perceptron: a feedforward neural network with multiple Dense layers.""" - - def setup(self): - self.layers = [nn.Dense(n) for n in self.layer_sizes] - @nn.compact - def __call__(self, x: Array): - """_summary_ +############# +### UTILS ### +############# - Args: - x (Array): Input data of the neural network. - """ - - for i, layer in enumerate(self.layers): - # Apply the linear part of the layer's operation - x = layer(x) - # If not the output layer, apply the given activation function - if i != len(self.layer_sizes) - 1: - x = self.act_func(x) - - return x - -################ -### TRAINING ### -################ - -def create_train_state(model: BaseNeuralnet, - test_input: Array, - rng: jax.random.PRNGKey, - config: NeuralnetConfig): +def kld(mean, logvar): """ - Creates an initial `TrainState` from NN model and optimizer and initializes the parameters by passing dummy input. - - Args: - model (BaseNeuralnet): Neural network model to be trained. - test_input (Array): A test input used to initialize the parameters of the model. - rng (jax.random.PRNGKey): Random number generator key used for initialization of the model. - config (NeuralnetConfig): Configuration for the neural network training. - - Returns: - TrainState: TrainState object for training + Kullback-Leibler divergence of a normal distribution with arbitrary mean and log variance to the standard normal distribution with mean 0 and unit variance. """ - params = model.init(rng, test_input)['params'] - tx = optax.adam(config.learning_rate) - state = TrainState.create(apply_fn = model.apply, params = params, tx = tx) - return state + return 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - logvar -1) -def apply_model(state: TrainState, - x_batched: Float[Array, "n_batch ndim_input"], - y_batched: Float[Array, "n_batch ndim_output"]): +def bce(y, pred): """ - Apply the model to a batch of data and compute the loss and gradients. - - Args: - state (TrainState): TrainState object for training. - x_batched (Float[Array, "n_batch ndim_input"]): Batch of input - y_batched (Float[Array, "n_batch ndim_output"]): Batch of output + binary cross entropy between y and the predicted array pred """ + return -jnp.sum(y * jnp.log(pred) + (1-y) * jnp.log(1-pred)) - def loss_fn(params): - def squared_error(x, y): - # For a single datapoint - pred = state.apply_fn({'params': params}, x) - return jnp.inner(y - pred, y - pred) / 2.0 - # Vectorize the previous to compute the average of the loss on all samples. - return jnp.mean(jax.vmap(squared_error)(x_batched, y_batched)) - - grad_fn = jax.value_and_grad(loss_fn) - loss, grads = grad_fn(state.params) - return loss, grads - -@jax.jit -def train_step(state: TrainState, - train_X: Float[Array, "n_batch_train ndim_input"], - train_y: Float[Array, "n_batch_train ndim_output"], - val_X: Float[Array, "n_batch_val ndim_output"] = None, - val_y: Float[Array, "n_batch_val ndim_output"] = None) -> tuple[TrainState, Float[Array, "n_batch_train"], Float[Array, "n_batch_val"]]: +def mse(y, pred): """ - Train for a single step. Note that this function is functionally pure and hence suitable for jit. 
- - Args: - state (TrainState): TrainState object - train_X (Float[Array, "n_batch_train ndim_input"]): Training input data - train_y (Float[Array, "n_batch_train ndim_output"]): Training output data - val_X (Float[Array, "n_batch_val ndim_input"], optional): Validation input data. Defaults to None. - val_y (Float[Array, "n_batch_val ndim_output"], optional): Valdiation output data. Defaults to None. - - Returns: - tuple[TrainState, Float, Float]: TrainState with updated weights, and arrays of training and validation losses + Sum of squared errors between y and the predicted array pred """ - - # Compute losses - train_loss, grads = apply_model(state, train_X, train_y) - if val_X is not None: - val_loss, _ = apply_model(state, val_X, val_y) - else: - val_loss = jnp.zeros_like(train_loss) - - # Update parameters - state = state.apply_gradients(grads=grads) - - return state, train_loss, val_loss - -def train_loop(state: TrainState, - config: NeuralnetConfig, - train_X: Float[Array, "n_batch_train ndim_input"], - train_y: Float[Array, "n_batch_train ndim_output"], - val_X: Float[Array, "n_batch_val ndim_output"] = None, - val_y: Float[Array, "n_batch_val ndim_output"] = None, - verbose: bool = True): - - train_losses, val_losses = [], [] - - start = time.time() - - for i in range(config.nb_epochs): - # Do a single step - - state, train_loss, val_loss = train_step(state, train_X, train_y, val_X, val_y) - # Save the losses - train_losses.append(train_loss) - val_losses.append(val_loss) - # Report once in a while - if i % config.nb_report == 0 and verbose: - print(f"Train loss at step {i+1}: {train_loss}") - print(f"Valid loss at step {i+1}: {val_loss}") - print(f"Learning rate: {config.learning_rate}") - print("---") - - end = time.time() - if verbose: - print(f"Training for {config.nb_epochs} took {end-start} seconds.") - - return state, train_losses, val_losses + return jnp.sum((y-pred)**2) def serialize(state: TrainState, config: NeuralnetConfig = None) -> dict: @@ -219,55 +92,266 @@ def serialize(state: TrainState, return serialized_dict -# TODO: add support for various activation functions and different model architectures to be loaded from serialized models -def save_model(state: TrainState, - config: ConfigDict = None, - out_name: str = "my_flax_model.pkl"): - """ - Serialize and save the model to a file.
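For orientation, the kld helper above implements the closed-form KL divergence of a diagonal Gaussian N(mean, exp(logvar)) from the standard normal, KL = 0.5 * sum(mean**2 + exp(logvar) - logvar - 1), summed over the latent dimensions. A quick standalone sanity check (a sketch; kld is redefined locally so the snippet runs on its own):

import jax.numpy as jnp

def kld(mean, logvar):
    # KL divergence of N(mean, exp(logvar)) from N(0, 1), summed over dimensions
    return 0.5 * jnp.sum(mean**2 + jnp.exp(logvar) - logvar - 1)

print(kld(jnp.zeros(3), jnp.zeros(3)))  # 0.0: the standard normal has zero divergence from itself
print(kld(jnp.ones(3), jnp.zeros(3)))   # 1.5: each of the 3 dimensions contributes 0.5 * mean**2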
+################ +### TRAINING ### +################ + + +class CVAE: + def __init__(self, + config: NeuralnetConfig, + conditional_dim: Int, + key: jax.random.PRNGKey = jax.random.key(21)): + self.config = config + net = nn.CVAE(hidden_layer_sizes= config.hidden_layer_sizes, output_size= config.output_size) + key, subkey, subkey2 = jax.random.split(key, 3) + + params = net.init(subkey, jnp.ones(config.output_size), jnp.ones(conditional_dim), subkey2)['params'] + tx = optax.adam(config.learning_rate) + self.state = TrainState.create(apply_fn = net.apply, params = params, tx = tx) # initialize the training state + + @staticmethod + @jax.jit + def train_step(state: TrainState, + train_X: Float[Array, "n_batch_train ndim_input"], + train_y: Float[Array, "n_batch_train ndim_output"], + rng: jax.random.PRNGKey, + val_X: Float[Array, "n_batch_val ndim_output"] = None, + val_y: Float[Array, "n_batch_val ndim_output"] = None, + ) -> tuple[TrainState, Float[Array, "n_batch_train"], Float[Array, "n_batch_val"]]: + def apply_model(state, X, y, z_rng): + def loss_fn(params): + reconstructed_y, mean, logvar = state.apply_fn({'params': params}, y, X, z_rng) + mse_loss = jnp.mean(jax.vmap(mse)(y, reconstructed_y)) # mean squared error loss + kld_loss = jnp.mean(jax.vmap(kld)(mean, logvar)) # KLD loss + return mse_loss + kld_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + return loss, grads + rng, z_rng = jax.random.split(rng) + train_loss, grads = apply_model(state, train_X, train_y, z_rng) + if val_X is not None: + rng, z_rng = jax.random.split(rng) + val_loss, _ = apply_model(state, val_X, val_y, z_rng) + else: + val_loss = jnp.zeros_like(train_loss) + + # Update parameters + state = state.apply_gradients(grads=grads) + + return state, train_loss, val_loss, rng + + def train_loop(self, + train_X: Float[Array, "n_batch_train ndim_input"], + train_y: Float[Array, "n_batch_train ndim_output"], + val_X: Float[Array, "n_batch_val ndim_output"] = None, + val_y: Float[Array, "n_batch_val ndim_output"] = None, + verbose: bool = True): - Raises: - ValueError: If the provided file extension is not .pkl or .pickle. + train_losses, val_losses = [], [] + rng = jax.random.key(2025) + state = self.state + + start = time.time() + + for i in range(self.config.nb_epochs): + # Do a single step + rng, subkey = jax.random.split(rng) + state, train_loss, val_loss, rng = self.train_step(state, train_X, train_y, subkey, val_X, val_y) + # Save the losses + train_losses.append(train_loss) + val_losses.append(val_loss) + # Report once in a while + if i % self.config.nb_report == 0 and verbose: + print(f"Train loss at step {i+1}: {train_loss}") + print(f"Valid loss at step {i+1}: {val_loss}") + print(f"Learning rate: {self.config.learning_rate}") + print("---") + + end = time.time() + if verbose: + print(f"Training for {self.config.nb_epochs} took {end-start} seconds.") + + self.trained_state = state + + return self.trained_state, train_losses, val_losses + + def save_model(self, outfile: str = "my_flax_model.pkl"): + """ + Serialize and save the model to a file. + + Raises: + ValueError: If the provided file extension is not .pkl or .pickle. + + Args: + outfile (str, optional): The pickle file to which we save the serialized model. Defaults to "my_flax_model.pkl". 
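+ + Example (illustrative usage; train_loop must run first so that trained_state exists): + cvae = CVAE(config, conditional_dim=9) + cvae.train_loop(train_X, train_y) + cvae.save_model("my_cvae.pkl")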
+ """ + + if not outfile.endswith(".pkl") and not outfile.endswith(".pickle"): + raise ValueError("For now, only .pkl or .pickle extensions are supported.") + + serialized_dict = serialize(self.trained_state, self.config) + with open(outfile, 'wb') as handle: + pickle.dump(serialized_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) - Args: - state (TrainState): The TrainState object to be saved. - config (ConfigDict, optional): The config to be saved.. Defaults to None. - out_name (str, optional): The pickle file to which we save the serialized model. Defaults to "my_flax_model.pkl". - """ + @staticmethod + def load_model(filename: str) -> tuple[TrainState, NeuralnetConfig]: + """ + Load a model from a file. + TODO: this is very cumbersome now and must be massively improved in the future - if not out_name.endswith(".pkl") and not out_name.endswith(".pickle"): - raise ValueError("For now, only .pkl or .pickle extensions are supported.") + Args: + filename (str): Filename of the model to be loaded. - serialized_dict = serialize(state, config) - with open(out_name, 'wb') as handle: - pickle.dump(serialized_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) + Raises: + ValueError: If there is something wrong with loading, since lots of things can go wrong here. -def load_model(filename: str) -> tuple[TrainState, NeuralnetConfig]: - """ - Load a model from a file. - TODO: this is very cumbersome now and must be massively improved in the future + Returns: + tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object. + """ + with open(filename, 'rb') as handle: + loaded_dict = pickle.load(handle) + + config: NeuralnetConfig = loaded_dict["config"] + params = loaded_dict["params"] + + net = nn.Decoder(layer_sizes = [*config.hidden_layer_sizes[::-1], config.output_size]) + # Create train state without optimizer + state = TrainState.create(apply_fn = net.apply, params = params["decoder"], tx = optax.adam(config.learning_rate)) + + return state, config + + @staticmethod + def load_full_model(filename: str) -> tuple[TrainState, NeuralnetConfig]: - Args: - filename (str): Filename of the model to be loaded. + with open(filename, "rb") as handle: + loaded_dict = pickle.load(handle) + + config: NeuralnetConfig = loaded_dict["config"] + params = loaded_dict["params"] - Raises: - ValueError: If there is something wrong with loading, since lots of things can go wrong here. + net = nn.CVAE(hidden_layer_sizes=config.hidden_layer_sizes, output_size= config.output_size) + # Create train state without optimizer + state = TrainState.create(apply_fn = net.apply, params = params, tx = optax.adam(config.learning_rate)) - Returns: - tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object. 
- """ + return state, config + + +class MLP: + def __init__(self, + config: NeuralnetConfig, + input_ndim: Int, + key: jax.random.PRNGKey = jax.random.key(21)): + self.config = config + net = nn.MLP(layer_sizes= config.layer_sizes) + key, subkey = jax.random.split(key) + params = net.init(subkey, jnp.ones(input_ndim))['params'] + tx = optax.adam(config.learning_rate) + self.state = TrainState.create(apply_fn = net.apply, params = params, tx = tx) # initialize the training state + + @staticmethod + @jax.jit + def train_step(state: TrainState, + train_X: Float[Array, "n_batch_train ndim_input"], + train_y: Float[Array, "n_batch_train ndim_output"], + val_X: Float[Array, "n_batch_val ndim_output"] = None, + val_y: Float[Array, "n_batch_val ndim_output"] = None, + ) -> tuple[TrainState, Float[Array, "n_batch_train"], Float[Array, "n_batch_val"]]: + def apply_model(state, X, y): + def loss_fn(params): + reconstructed_y = state.apply_fn({'params': params}, X) + mse_loss = jnp.mean(jax.vmap(mse)(y, reconstructed_y)) # mean squared error loss + return mse_loss + + grad_fn = jax.value_and_grad(loss_fn) + loss, grads = grad_fn(state.params) + return loss, grads + train_loss, grads = apply_model(state, train_X, train_y) + if val_X is not None: + val_loss, _ = apply_model(state, val_X, val_y) + else: + val_loss = jnp.zeros_like(train_loss) + + # Update parameters + state = state.apply_gradients(grads=grads) + + return state, train_loss, val_loss + + def train_loop(self, + train_X: Float[Array, "n_batch_train ndim_input"], + train_y: Float[Array, "n_batch_train ndim_output"], + val_X: Float[Array, "n_batch_val ndim_output"] = None, + val_y: Float[Array, "n_batch_val ndim_output"] = None, + verbose: bool = True): + + train_losses, val_losses = [], [] + state = self.state + + start = time.time() + + for i in range(self.config.nb_epochs): + # Do a single step + state, train_loss, val_loss = self.train_step(state, train_X, train_y, val_X, val_y) + # Save the losses + train_losses.append(train_loss) + val_losses.append(val_loss) + # Report once in a while + if i % self.config.nb_report == 0 and verbose: + print(f"Train loss at step {i+1}: {train_loss}") + print(f"Valid loss at step {i+1}: {val_loss}") + print(f"Learning rate: {self.config.learning_rate}") + print("---") + + end = time.time() + if verbose: + print(f"Training for {self.config.nb_epochs} took {end-start} seconds.") + + self.trained_state = state - with open(filename, 'rb') as handle: - loaded_dict = pickle.load(handle) + return self.trained_state, train_losses, val_losses + + def save_model(self, outfile: str = "my_flax_model.pkl"): + """ + Serialize and save the model to a file. + + Raises: + ValueError: If the provided file extension is not .pkl or .pickle. + + Args: + outfile (str, optional): The pickle file to which we save the serialized model. Defaults to "my_flax_model.pkl". + """ - config: NeuralnetConfig = loaded_dict["config"] - layer_sizes = config.layer_sizes - act_func = nn.relu - params = loaded_dict["params"] + if not outfile.endswith(".pkl") and not outfile.endswith(".pickle"): + raise ValueError("For now, only .pkl or .pickle extensions are supported.") - model = MLP(layer_sizes, act_func) + serialized_dict = serialize(self.trained_state, self.config) + with open(outfile, 'wb') as handle: + pickle.dump(serialized_dict, handle, protocol=pickle.HIGHEST_PROTOCOL) + + @staticmethod + def load_model(filename: str) -> tuple[TrainState, NeuralnetConfig]: + """ + Load a model from a file. 
+ TODO: this is very cumbersome now and must be massively improved in the future + + Args: + filename (str): Filename of the model to be loaded. + + Raises: + ValueError: If there is something wrong with loading, since lots of things can go wrong here. - Returns: - tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object. + Returns: + tuple[TrainState, NeuralnetConfig]: The TrainState object loaded from the file and the NeuralnetConfig object. + """ + with open(filename, 'rb') as handle: + loaded_dict = pickle.load(handle) + + config: NeuralnetConfig = loaded_dict["config"] + params = loaded_dict["params"] + + net = nn.MLP(config.layer_sizes) + # Create train state without optimizer + state = TrainState.create(apply_fn = net.apply, params = params, tx = optax.adam(config.learning_rate)) - return state, config + return state, config \ No newline at end of file diff --git a/src/fiesta/train/nn_architectures.py b/src/fiesta/train/nn_architectures.py new file mode 100644 index 0000000..087d7bc --- /dev/null +++ b/src/fiesta/train/nn_architectures.py @@ -0,0 +1,142 @@ +from typing import Sequence, Callable + +import jax +import jax.numpy as jnp +from jaxtyping import Array, Int +from flax import linen as nn # Linen API + + + +##################### +### ARCHITECTURES ### +##################### + +class BaseNeuralnet(nn.Module): + """Abstract base class. Needs layer sizes and activation function used""" + layer_sizes: Sequence[int] + act_func: Callable[[Array], Array] = nn.relu + + def setup(self): + raise NotImplementedError + + def __call__(self, x): + raise NotImplementedError + +class MLP(BaseNeuralnet): + """Basic multi-layer perceptron: a feedforward neural network with multiple Dense layers.""" + + def setup(self): + self.layers = [nn.Dense(n) for n in self.layer_sizes] + + @nn.compact + def __call__(self, x: Array): + """Apply the network to the input data. + + Args: + x (Array): Input data of the neural network.
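+ + Returns: + Array: Output of the network; the activation function is applied after every Dense layer except the last.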
+ """ + + for layer in self.layers[:-1]: + # Apply the linear part of the layer's operation + x = layer(x) + # Apply the given activation function + x = self.act_func(x) + + x = self.layers[-1](x) # for the output layer only apply the linear part + return x + +class Encoder(nn.Module): + layer_sizes: Sequence[int] + act_func: Callable[[Array], Array] = nn.relu + + def setup(self): + self.mu_layers = [nn.Dense(n) for n in self.layer_sizes] + self.logvar_layers = [nn.Dense(n) for n in self.layer_sizes] + + @nn.compact + def __call__(self, y: Array): + + mu = y.copy() + for layer in self.mu_layers[:-1]: + mu = layer(mu) + mu = self.act_func(mu) + mu = self.mu_layers[-1](mu) + + logvar = y.copy() + for layer in self.logvar_layers[:-1]: + logvar = layer(logvar) + logvar = self.act_func(logvar) + logvar = self.logvar_layers[-1](logvar) + return mu, logvar + +class Decoder(MLP): + + @nn.compact + def __call__(self, z: Array): + for layer in self.layers[:-1]: + # Apply the linear part of the layer's operation + z = layer(z) + # Apply the given activation function + z = self.act_func(z) + + z = self.layers[-1](z) # for the output layer only apply the linear part + return z + + +class CVAE(nn.Module): + """Conditional Variational Autoencoder consisting of an Encoder and a Decoder.""" + hidden_layer_sizes: Sequence[Int] # used for both the encoder and decoder + output_size: Int + z_dim: Int = 20 + + def setup(self): + self.encoder = Encoder([*self.hidden_layer_sizes, self.z_dim]) + self.decoder = Decoder(layer_sizes = [*self.hidden_layer_sizes[::-1], self.output_size], act_func=nn.relu) + + def __call__(self, y: Array, x: Array, z_rng: jax.random.PRNGKey): + y = jnp.concatenate([y, x.copy()], axis = -1) + mu, logvar = self.encoder(y) + + # Reparametrize + std = jnp.exp(0.5* logvar) + eps = jax.random.normal(z_rng, logvar.shape) + z = mu + eps * std + + z_x = jnp.concatenate([z, x.copy()], axis = -1) + reconstructed_y = self.decoder(z_x) + return reconstructed_y, mu, logvar + +class CNN(nn.Module): + """Convolutional Neural Network""" + dense_layer_sizes: Sequence[Int] + kernel_sizes: Sequence[Int] + conv_layer_sizes: Sequence[Int] + output_shape: tuple[Int, Int] + spatial: Int = 32 + act_func: Callable[[Array], Array] = nn.relu + + def setup(self): + if self.dense_layer_sizes[-1] != self.conv_layer_sizes[0]: + raise ValueError(f"Final dense layer must be equally large as first convolutional layer.") + if self.conv_layer_sizes[-1] != 1: + raise ValueError(f"Last convolutional layer must be of size 1 to predict 2D array.") + + self.dense_layers = [nn.Dense(n) for n in self.dense_layer_sizes[:-1]] + self.dense_layers += (nn.Dense(self.dense_layer_sizes[-1] * self.spatial**2), ) # the last dense layer should create an array that can be reshaped into spatial and chanel parts + self.conv_layers = [nn.Conv(features = f, kernel_size = (k,k)) for f, k in zip(self.conv_layer_sizes, self.kernel_sizes)] + + def __call__(self, x: Array): + # Apply the dense layers + for layer in self.dense_layers: + x = layer(x) + x = self.act_func(x) + + x = x.reshape((-1, self.spatial, self.spatial, self.dense_layer_sizes[-1])) + for layer in self.conv_layers[:-1]: + x = layer(x) + x = self.act_func(x) + + x = self.conv_layers[-1](x) # only apply convolution part of last convolutional layer + x = x[:,:,:,0] + x = jax.image.resize(x, shape = (x.shape[0], *self.output_shape), method = "bilinear") # resize the NN output to the desired output + return x diff --git a/src/fiesta/utils.py b/src/fiesta/utils.py index 9b3347f..c408799 
100644 --- a/src/fiesta/utils.py +++ b/src/fiesta/utils.py @@ -1,17 +1,21 @@ -import jax.numpy as jnp -from jax.scipy.stats import truncnorm -from jaxtyping import Array, Float -import jax -import numpy as np -import pandas as pd -import scipy.interpolate as interp import copy import re + +import numpy as np +import pandas as pd from astropy.time import Time -import astropy -import scipy from sncosmo.bandpasses import _BANDPASSES, _BANDPASS_INTERPOLATORS -import sncosmo +from sncosmo import get_bandpass +import scipy.interpolate as interp + +import jax.numpy as jnp +from jax.scipy.stats import truncnorm +from jaxtyping import Array, Float, Int +import jax + +from fiesta.conversions import monochromatic_AB_mag, bandpass_AB_mag, integrated_AB_mag +import fiesta.constants as constants + #################### @@ -46,7 +50,6 @@ def fit_transform(self, x: Array) -> Array: self.fit(x) return self.transform(x) - class StandardScalerJax(object): """ StandardScaler like sklearn does it, but for JAX arrays since sklearn might not be JAX-compatible? @@ -75,7 +78,7 @@ def fit_transform(self, x: Array) -> Array: self.fit(x) return self.transform(x) -class PCAdecomposer(object): +class PCADecomposer(object): """ PCA decomposer like sklearn does it. Based on https://github.com/alonfnt/pcax/tree/main. """ @@ -85,7 +88,7 @@ def __init__(self, n_components: int, solver: str = "randomized"): def fit(self, x: Array)-> None: if self.solver == "full": - self._fit_full(x, self.n_components) + self._fit_full(x) elif self.solver == "randomized": rng = jax.random.PRNGKey(self.n_components) self._fit_randomized(x, rng) @@ -141,7 +144,90 @@ def inverse_transform(self, x: Array)->Array: def fit_transform(self, x: Array)-> Array: self.fit(x) return self.transform(x) + +class SVDDecomposer(object): + """ + SVDDecomposer that uses the old NMMA approach to decompose lightcurves into SVD coefficients. + """ + def __init__(self, + svd_ncoeff: Int): + self.svd_ncoeff = svd_ncoeff + self.scaler = MinMaxScalerJax() + + def fit(self, x: Array): + xcopy = x.copy() + xcopy = self.scaler.fit_transform(xcopy) + + # Do SVD decomposition on the training data + UA, _, VA = jnp.linalg.svd(xcopy, full_matrices=True) + self.VA = VA[:self.svd_ncoeff] + + def transform(self, x: Array) -> Array: + x = self.scaler.transform(x) + x = jnp.dot(x, self.VA.T) + return x + + def inverse_transform(self, x: Array) -> Array: + x = jnp.dot(x, self.VA) + x = self.scaler.inverse_transform(x) + return x + + def fit_transform(self, x: Array)-> Array: + self.fit(x) + return self.transform(x) + +class ImageScaler(object): + """ + Scaler that downsamples 2D arrays from shape upscale to shape downscale, and upsamples them back in the inverse transform. + Note that the methods always assume that the input array x is flattened along the last axis, i.e. the input is reshaped via x.reshape(-1, *upscale). + The downsampled image is additionally scaled with a scaler object. + Attention: this object has no proper fit method, because of its application in FluxTrainerCVAE and the way the data is loaded there to avoid memory issues.
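+ In transform, an input of shape (n, upscale[0]*upscale[1]) is thus mapped to an array of shape (n, downscale[0]*downscale[1]).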
+ """ + def __init__(self, + downscale: Int[Array, "shape=(2,)"], + upscale: Int[Array, "shape=(2,)"], + scaler: object): + self.downscale = downscale + self.upscale = upscale + self.scaler = scaler + + def resize_image(self, x: Array) -> Array: + x = x.reshape(-1, *self.upscale) + return jax.image.resize(x, shape = (x.shape[0], *self.downscale), method = "cubic") + + def transform(self, x: Array)-> Array: + x = x.reshape(-1, *self.upscale) + x = jax.image.resize(x, shape = (x.shape[0], *self.downscale), method = "cubic") + x = x.reshape(-1, jnp.prod(self.downscale)) + x = self.scaler.transform(x) + return x + + def inverse_transform(self, x: Array)-> Array: + x = self.scaler.inverse_transform(x) + x = x.reshape(-1, *self.downscale) + x = jax.image.resize(x, shape = (x.shape[0], *self.upscale), method = "cubic") + out = jax.vmap(self.fix_edges)(x[:, :, 4:-4]) # this is necessary because jax.image.resize produces artefacts at the edges when upsampling + return out + + def fit_transform_scaler(self, x: Array) -> Array: + """Method that will fit the scaling object. Here, the array already has to be down sampled.""" + out = self.scaler.fit_transform(x) + return out + @staticmethod + @jax.vmap + def fix_edges(yp: Array): + """Extrapolate at early and late times from the reconstructed array to avoid artefacts at the edges from jax.image.resize.""" + xp = jnp.arange(4, yp.shape[0]+4) + xl = jnp.arange(0,4) + xr = jnp.arange(yp.shape[0]+4, yp.shape[0]+8) + yl = jnp.interp(xl, xp, yp, left = "extrapolate", right = "extrapolate") + yr = jnp.interp(xr, xp, yp, left = "extrapolate", right = "extrapolate") + out = jnp.concatenate([yl, yp, yr]) + return out + + +# TODO: Remove this def inverse_svd_transform(x: Array, VA: Array, nsvd_coeff: int = 10) -> Array: @@ -154,6 +240,8 @@ def inverse_svd_transform(x: Array, ### BULLA UTILITIES ### ####################### +# TODO: place that somewhere else? + def get_filters_bulla_file(filename: str, drop_times: bool = False) -> list[str]: @@ -298,29 +386,39 @@ def truncated_gaussian(mag_det: Array, return logpdf def load_event_data(filename): - # TODO: polish? - lines = [line.rstrip("\n") for line in open(filename)] - lines = filter(None, lines) + """ + Takes a file and outputs a magnitude dict with filters as keys. + + Args: + filename (str): path to file to be read in + + Returns: + data (dict[str, Array]): Data dictionary with filters as keys. The array has the structure [[mjd, mag, err]]. 
- sncosmo_filts = [val["name"] for val in _BANDPASSES.get_loaders_metadata()] - sncosmo_maps = {name: name.replace(":", "_") for name in sncosmo_filts} + """ + mjd, filters, mags, mag_errors = [], [], [], [] - data = {} - for line in lines: - lineSplit = line.split(" ") - lineSplit = list(filter(None, lineSplit)) - mjd = Time(lineSplit[0], format="isot").mjd - filt = lineSplit[1] + with open(filename, "r") as input: - if filt in sncosmo_maps: - filt = sncosmo_maps[filt] + for line in input: + line = line.rstrip("\n") + t, filter, mag, mag_err = line.split(" ") - mag = float(lineSplit[2]) - dmag = float(lineSplit[3]) + mjd.append(Time(t, format="isot").mjd) # convert to mjd + filters.append(filter) + mags.append(float(mag)) + mag_errors.append(float(mag_err)) + + mjd = np.array(mjd) + filters = np.array(filters) + mags = np.array(mags) + mag_errors = np.array(mag_errors) + data = {} - if filt not in data: - data[filt] = np.empty((0, 3), float) - data[filt] = np.append(data[filt], np.array([[mjd, mag, dmag]]), axis=0) + unique_filters = np.unique(filters) + for filt in unique_filters: + filt_inds = np.where(filters==filt)[0] + data[filt] = np.array([ mjd[filt_inds], mags[filt_inds], mag_errors[filt_inds] ]).T return data @@ -346,155 +444,80 @@ class Filter: def __init__(self, name: str,): + """ + Filter class that uses the bandpass properties from sncosmo or just a simple monochromatic filter based on the name. + The necessary attributes are stored as jnp arrays. + + Args: + name (str): Name of the filter. Will be either passed to sncosmo to get the optical bandpass, or the unit at the end will be used to create a monochromatic filter. Supported units are keV and GHz. + """ self.name = name if (self.name, None) in _BANDPASSES._primary_loaders: - bandpass = sncosmo.get_bandpass(self.name) - self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10) + bandpass = get_bandpass(self.name) # sncosmo bandpass + self.nu = constants.c / (bandpass.wave_eff*1e-10) + self.nus = constants.c / (bandpass.wave[::-1]*1e-10) + self.trans = bandpass.trans[::-1] # reverse the array to get the transmission as function of frequency (not wavelength) + self.filt_type = "bandpass" + elif (self.name, None) in _BANDPASS_INTERPOLATORS._primary_loaders: - # FIXME: val undefined - bandpass = sncosmo.get_bandpass(val["name"], 3) - self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10) + bandpass = get_bandpass(self.name, 0) # these bandpass interpolators require a radius (here by default 0 cm) + self.nu = constants.c/(bandpass.wave_eff*1e-10) + self.nus = constants.c / (bandpass.wave[::-1]*1e-10) + self.trans = bandpass.trans[::-1] # reverse the array to get the transmission as function of frequency (not wavelength) + self.filt_type = "bandpass" + elif self.name.endswith("GHz"): freq = re.findall(r"[-+]?(?:\d*\.*\d+)", self.name.replace("-","")) freq = float(freq[-1]) self.nu = freq*1e9 + self.nus = jnp.array([self.nu]) + self.trans = jnp.ones(1) + self.filt_type = "monochromatic" + elif self.name.endswith("keV"): energy = re.findall(r"[-+]?(?:\d*\.*\d+)", self.name.replace("-","")) energy = float(energy[-1]) - self.nu = energy*1000*scipy.constants.eV / scipy.constants.h - else: - print(f"Warning: Filter {self.name} not recognized") - self.nu = jnp.nan - - self.wavelength = scipy.constants.c/self.nu + self.nu = energy*1000*constants.eV / constants.h + self.nus = jnp.array([self.nu]) + self.trans = jnp.ones(1) + self.filt_type = "monochromatic" + + elif self.name.startswith("XRT"): + energy1, energy2 = 
re.findall(r"\d+\.\d+|\d+", self.name) + nu1 = float(energy1)*1000*constants.eV / constants.h + nu2 = float(energy2)*1000*constants.eV / constants.h + self.nus = jnp.linspace(nu1, nu2, 10) + self.trans = jnp.ones_like(self.nus) + self.nu = jnp.mean(self.nus) + self.filt_type = "integrated" - -def get_all_bandpass_metadata(): - # TODO: taken over from NMMA, improve - """ - Retrieves and combines the metadata for all registered bandpasses and interpolators. - - Returns: - list: Combined list of metadata dictionaries from bandpasses and interpolators for sncosmo. - """ - - bandpass_metadata = _BANDPASSES.get_loaders_metadata() - interpolator_metadata = _BANDPASS_INTERPOLATORS.get_loaders_metadata() - - combined_metadata = bandpass_metadata + interpolator_metadata - - return combined_metadata - -def get_default_filts_lambdas(filters: list[str]=None): - - filts = [ - "u", - "g", - "r", - "i", - "z", - "y", - "J", - "H", - "K", - "U", - "B", - "V", - "R", - "I", - "radio-1.25GHz", - "radio-3GHz", - "radio-5.5GHz", - "radio-6GHz", - "X-ray-1keV", - "X-ray-5keV", - ] - lambdas_sloan = 1e-10 * np.array( - [3561.8, 4866.46, 6214.6, 7687.0, 7127.0, 7544.6, 8679.5, 9633.3, 12350.0] - ) - lambdas_bessel = 1e-10 * np.array([3605.07, 4413.08, 5512.12, 6585.91, 8059.88]) - lambdas_radio = scipy.constants.c / np.array([1.25e9, 3e9, 5.5e9, 6e9]) - lambdas_Xray = scipy.constants.c / ( - np.array([1e3, 5e3]) * scipy.constants.eV / scipy.constants.h - ) - - lambdas = np.concatenate( - [lambdas_sloan, lambdas_bessel, lambdas_radio, lambdas_Xray] - ) - - bandpasses = [] - for val in get_all_bandpass_metadata(): - if val["name"] in [ - "ultrasat", - "megacampsf::u", - "megacampsf::g", - "megacampsf::r", - "megacampsf::i", - "megacampsf::z", - "megacampsf::y", - ]: - bandpass = sncosmo.get_bandpass(val["name"], 3) - bandpass.name = bandpass.name.split()[0] else: - bandpass = sncosmo.get_bandpass(val["name"]) - - bandpasses.append(bandpass) - - filts = filts + [band.name for band in bandpasses] - lambdas = np.concatenate([lambdas, [1e-10 * band.wave_eff for band in bandpasses]]) - - if filters is not None: - filts_slice = [] - lambdas_slice = [] - transmittance_slice = [] - - for filt in filters: - if filt.startswith("radio") and filt not in filts: - # for additional radio filters that not in the list - # calculate the lambdas based on the filter name - # split the filter name - freq_string = filt.replace("radio-", "") - freq_unit = freq_string[-3:] - freq_val = float(freq_string.replace(freq_unit, "")) - # make use of the astropy.units to be more flexible - freq = astropy.units.Quantity(freq_val, unit=freq_unit) - freq = freq.to("Hz").value - # adding to the list - filts_slice.append(filt) - lambdas_slice.append([scipy.constants.c / freq]) - transmittance_slice.append([1]) - - elif filt.startswith("X-ray-") and filt not in filts: - # for additional X-ray filters that not in the list - # calculate the lambdas based on the filter name - # split the filter name - energy_string = filt.replace("X-ray-", "") - energy_unit = energy_string[-3:] - energy_val = float(energy_string.replace(energy_unit, "")) - # make use of the astropy.units to be more flexible - energy = astropy.units.Quantity(energy_val, unit=energy_unit) - freq = energy.to("eV").value * scipy.constants.eV / scipy.constants.h - # adding to the list - filts_slice.append(filt) - lambdas_slice.append([scipy.constants.c / freq]) - transmittance_slice.append([1]) - - else: - try: - ii = filts.index(filt) - filts_slice.append(filts[ii]) - 
lambdas_slice.append([lambdas[ii]]) - except ValueError: - ii = filts.index(filt.replace("_", ":")) - filts_slice.append(filts[ii].replace(":", "_")) - lambdas_slice.append(lambdas[ii]) - - filts = filts_slice - lambdas = np.array(lambdas_slice) - - # FIXME: transmittance undefined - return filts, lambdas, transmittance - - -def mJys_to_mag(): - pass \ No newline at end of file + raise ValueError(f"Filter {self.name} not recognized") + + self.wavelength = constants.c/self.nu*1e10 + self._calculate_ref_flux() + + if self.filt_type=="bandpass": + self.get_mag = lambda Fnu, nus: bandpass_AB_mag(Fnu, nus, self.nus, self.trans, self.ref_flux) + elif self.filt_type=="monochromatic": + self.get_mag = lambda Fnu, nus: monochromatic_AB_mag(Fnu, nus, self.nus, self.trans, self.ref_flux) + elif self.filt_type=="integrated": + self.get_mag = lambda Fnu, nus: integrated_AB_mag(Fnu, nus, self.nus, self.trans) + + + def _calculate_ref_flux(self,): + """method to determine the reference flux for the magnitude conversion.""" + if self.filt_type in ["monochromatic", "integrated"]: + self.ref_flux = 3631000. # mJy + elif self.filt_type=="bandpass": + integrand = self.trans / (constants.h_erg_s * self.nus) # https://en.wikipedia.org/wiki/AB_magnitude + integral = jnp.trapezoid(y = integrand, x = self.nus) + self.ref_flux = 3631000. * integral.item() # mJy + + def get_mags(self, fluxes: Float[Array, "n_samples n_nus n_times"], nus: Float[Array, "n_nus"]) -> Float[Array, "n_samples n_times"]: + + def get_single(flux): + return self.get_mag(flux, nus) + + mags = jax.vmap(get_single)(fluxes) + return mags \ No newline at end of file diff --git a/tests/data/test_raw_data.h5 b/tests/data/test_raw_data.h5 new file mode 100644 index 0000000..001996c Binary files /dev/null and b/tests/data/test_raw_data.h5 differ diff --git a/tests/models/flux.pkl b/tests/models/flux.pkl new file mode 100644 index 0000000..1381a6d Binary files /dev/null and b/tests/models/flux.pkl differ diff --git a/tests/models/flux_metadata.pkl b/tests/models/flux_metadata.pkl new file mode 100644 index 0000000..655f8f0 Binary files /dev/null and b/tests/models/flux_metadata.pkl differ diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..27759c2 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,23 @@ +import os + +from fiesta.inference.lightcurve_model import AfterglowFlux + + +############## +# Flux model # +############## + +working_dir = os.path.dirname(__file__) +model_dir = os.path.join(working_dir, "models") + +def test_models(): + + model = AfterglowFlux(name="flux", + directory=model_dir, + filters=["radio-6GHz", "bessellv", "X-ray-1keV"], + model_type="MLP") + + X = [3.141/30, 54., 0.05, -1., 2.5, -2., -4.] 
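+ # X lists one value per entry of model.parameter_names (seven parameters for this stored test model)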
+ mag = model.predict_abs_mag(dict(zip(model.parameter_names, X))) + +# TODO: Add more model types here \ No newline at end of file diff --git a/tests/test_training.py b/tests/test_training.py new file mode 100644 index 0000000..8e97f22 --- /dev/null +++ b/tests/test_training.py @@ -0,0 +1,64 @@ +import os +from pathlib import Path + +from fiesta.train.FluxTrainer import PCATrainer +from fiesta.train.neuralnets import NeuralnetConfig + + +############# +### SETUP ### +############# + +tmin = 1 # days +tmax = 100 + +numin = 3e9 # Hz +numax = 1e15 + +n_training = 200 +n_val = 20 +n_pca = 10 + +working_dir = os.path.dirname(__file__) +file = os.path.join(working_dir, "data/test_raw_data.h5") + + +config = NeuralnetConfig(output_size=n_pca, + nb_epochs=10, + hidden_layer_sizes = [10], + learning_rate =1e-3) + + +############### +### TRAINER ### +############### + + +data_manager_args = dict(file=file, + n_training=n_training, + n_val=n_val, + tmin=tmin, + tmax=tmax, + numin=numin, + numax=numax, + special_training=["01"]) + + + + +def test_train_and_save(): + + name = "tophat" + + trainer = PCATrainer(name, + working_dir, + data_manager_args = data_manager_args, + n_pca = n_pca, + save_preprocessed_data=False + ) + + trainer.fit(config=config) + trainer.save() + + for file in Path(working_dir).glob("*.pkl"): + file.unlink() # Deletes the files \ No newline at end of file
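Taken together, the two test files sketch the workflow this PR enables end to end: train a PCA-based flux surrogate on a raw data file, save it, and reload it as a light-curve model for prediction. A condensed sketch of that round trip (paths, sizes, and parameter values are illustrative, the special_training key from tests/test_training.py is omitted, and it is assumed that the trainer writes its .pkl files under the model name into the output directory, as in tests/models):

from fiesta.train.FluxTrainer import PCATrainer
from fiesta.train.neuralnets import NeuralnetConfig
from fiesta.inference.lightcurve_model import AfterglowFlux

outdir = "./model/"

# DataManager settings: time range in days, frequency range in Hz
data_manager_args = dict(file="./model/afterglowpy_raw_data.h5",
                         n_training=200, n_val=20,
                         tmin=1, tmax=100,
                         numin=3e9, numax=1e15)

config = NeuralnetConfig(output_size=10,  # one output per PCA coefficient
                         nb_epochs=10,
                         hidden_layer_sizes=[10],
                         learning_rate=1e-3)

trainer = PCATrainer("tophat", outdir,
                     data_manager_args=data_manager_args,
                     n_pca=10,
                     save_preprocessed_data=False)
trainer.fit(config=config)
trainer.save()

model = AfterglowFlux(name="tophat", directory=outdir,
                      filters=["radio-6GHz", "X-ray-1keV"],
                      model_type="MLP")
X = [3.141/30, 54., 0.05, -1., 2.5, -2., -4.]  # one value per entry of model.parameter_names
mag = model.predict_abs_mag(dict(zip(model.parameter_names, X)))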