diff --git a/setup.cfg b/setup.cfg index d1e24e6..7a1ccbb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -15,8 +15,8 @@ packages_dir= =src packages = find: install_requires = - jax>=0.4.24 - jaxlib>=0.4.24 + jax<=0.4.31 + jaxlib<=0.4.31 numpy<2.0.0 pandas<2.0.0 jaxtyping diff --git a/src/.gitignore b/src/.gitignore new file mode 100644 index 0000000..943c5f7 --- /dev/null +++ b/src/.gitignore @@ -0,0 +1 @@ +fiestaEM.egg-info/ diff --git a/src/fiesta.egg-info/PKG-INFO b/src/fiesta.egg-info/PKG-INFO deleted file mode 100644 index 05d5959..0000000 --- a/src/fiesta.egg-info/PKG-INFO +++ /dev/null @@ -1,54 +0,0 @@ -Metadata-Version: 2.1 -Name: fiesta -Version: 0.0.1 -Summary: Fast inference of electromagnetic signals with JAX -Home-page: https://github.com/thibeauwouters/fiesta -Author: Thibeau Wouters -Author-email: thibeauwouters@gmail.com -License: MIT -Keywords: sampling,inference,astrophysics,kilonovae,gamma-ray bursts -Requires-Python: >=3.10 -Description-Content-Type: text/markdown -License-File: LICENSE -Requires-Dist: jax>=0.4.24 -Requires-Dist: jaxlib>=0.4.24 -Requires-Dist: numpy<2.0.0 -Requires-Dist: pandas<2.0.0 -Requires-Dist: jaxtyping -Requires-Dist: beartype -Requires-Dist: tqdm -Requires-Dist: scipy<=1.14.0 -Requires-Dist: ml_collections -Requires-Dist: astropy -Requires-Dist: sncosmo -Requires-Dist: flowMC -Requires-Dist: joblib - -# fiesta 🎉 - -`fiesta`: **F**ast **I**nference of **E**lectromagnetic **S**ignals and **T**ransients with j**A**x - -![fiesta logo](docs/fiesta_logo.jpeg) - -**NOTE:** `fiesta` is currently under development -- stay tuned! - -## Installation - -pip installation is currently work in progress. Install from source by cloning this Github repository and running -``` -pip install -e . -``` - -NOTE: This is using an older and custom version of `flowMC`. Install by cloning the `flowMC` version at [this fork](https://github.com/ThibeauWouters/flowMC/tree/fiesta) (branch `fiesta`). - -## Training surrogate models - -To train your own surrogate models, have a look at some of the example scripts in the repository for inspiration, under `trained_models` - -- `train_Bu2019lm.py`: Example script showing how to train a surrogate model for the POSSIS `Bu2019lm` kilonova model. -- `train_afterglowpy_tophat.py`: Example script showing how to train a surrogate model for `afterglowpy`, using a tophat jet structure. - -## Examples - -- `run_AT2017gfo_Bu2019lm.py`: Example where we infer the parameters of the AT2017gfo kilonova with the `Bu2019lm` model. -- `run_GRB170817_tophat.py`: Example where we infer the parameters of the GRB170817 GRB with a surrogate model for `afterglowpy`'s tophat jet. **NOTE** This currently only uses one specific filter. The complete inference will be updated soon. diff --git a/src/fiesta.egg-info/SOURCES.txt b/src/fiesta.egg-info/SOURCES.txt deleted file mode 100644 index ce22b23..0000000 --- a/src/fiesta.egg-info/SOURCES.txt +++ /dev/null @@ -1,9 +0,0 @@ -LICENSE -README.md -pyproject.toml -setup.cfg -src/fiesta.egg-info/PKG-INFO -src/fiesta.egg-info/SOURCES.txt -src/fiesta.egg-info/dependency_links.txt -src/fiesta.egg-info/requires.txt -src/fiesta.egg-info/top_level.txt \ No newline at end of file diff --git a/src/fiesta.egg-info/dependency_links.txt b/src/fiesta.egg-info/dependency_links.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/fiesta.egg-info/dependency_links.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/fiesta.egg-info/requires.txt b/src/fiesta.egg-info/requires.txt deleted file mode 100644 index 283afba..0000000 --- a/src/fiesta.egg-info/requires.txt +++ /dev/null @@ -1,13 +0,0 @@ -jax>=0.4.24 -jaxlib>=0.4.24 -numpy<2.0.0 -pandas<2.0.0 -jaxtyping -beartype -tqdm -scipy<=1.14.0 -ml_collections -astropy -sncosmo -flowMC -joblib diff --git a/src/fiesta.egg-info/top_level.txt b/src/fiesta.egg-info/top_level.txt deleted file mode 100644 index 8b13789..0000000 --- a/src/fiesta.egg-info/top_level.txt +++ /dev/null @@ -1 +0,0 @@ - diff --git a/src/fiesta/constants.py b/src/fiestaEM/constants.py similarity index 100% rename from src/fiesta/constants.py rename to src/fiestaEM/constants.py diff --git a/src/fiesta/conversions.py b/src/fiestaEM/conversions.py similarity index 94% rename from src/fiesta/conversions.py rename to src/fiestaEM/conversions.py index 33a12c6..0b55824 100644 --- a/src/fiesta/conversions.py +++ b/src/fiestaEM/conversions.py @@ -1,4 +1,4 @@ -from fiesta.constants import pc_to_cm +from fiestaEM.constants import pc_to_cm import jax import jax.numpy as jnp from jaxtyping import Array diff --git a/src/fiesta/inference/fiesta.py b/src/fiestaEM/inference/fiesta.py similarity index 98% rename from src/fiesta/inference/fiesta.py rename to src/fiestaEM/inference/fiesta.py index 078dc81..94eea9f 100644 --- a/src/fiesta/inference/fiesta.py +++ b/src/fiestaEM/inference/fiesta.py @@ -7,10 +7,10 @@ import jax.numpy as jnp from jaxtyping import Float, Array, PRNGKeyArray -from fiesta.inference.lightcurve_model import LightcurveModel -from fiesta.inference.prior import Prior -from fiesta.inference.likelihood import EMLikelihood -from fiesta.conversions import mag_app_from_mag_abs +from fiestaEM.inference.lightcurve_model import LightcurveModel +from fiestaEM.inference.prior import Prior +from fiestaEM.inference.likelihood import EMLikelihood +from fiestaEM.conversions import mag_app_from_mag_abs from flowMC.sampler.Sampler import Sampler from flowMC.sampler.MALA import MALA diff --git a/src/fiesta/inference/injection.py b/src/fiestaEM/inference/injection.py similarity index 96% rename from src/fiesta/inference/injection.py rename to src/fiestaEM/inference/injection.py index f62dfac..40754b1 100644 --- a/src/fiesta/inference/injection.py +++ b/src/fiestaEM/inference/injection.py @@ -8,11 +8,11 @@ import jax.numpy as jnp from jaxtyping import Float, Array -from fiesta.inference.lightcurve_model import LightcurveModel -from fiesta.conversions import mag_app_from_mag_abs -from fiesta.utils import Filter -from fiesta.constants import days_to_seconds, c -from fiesta import conversions +from fiestaEM.inference.lightcurve_model import LightcurveModel +from fiestaEM.conversions import mag_app_from_mag_abs +from fiestaEM.utils import Filter +from fiestaEM.constants import days_to_seconds, c +from fiestaEM import conversions import afterglowpy as grb diff --git a/src/fiesta/inference/lightcurve_model.py b/src/fiestaEM/inference/lightcurve_model.py similarity index 97% rename from src/fiesta/inference/lightcurve_model.py rename to src/fiestaEM/inference/lightcurve_model.py index 4dfd7a0..a3481cf 100644 --- a/src/fiesta/inference/lightcurve_model.py +++ b/src/fiestaEM/inference/lightcurve_model.py @@ -11,10 +11,10 @@ from flax.training.train_state import TrainState import pickle -import fiesta.train.neuralnets as fiesta_nn -from fiesta.utils import MinMaxScalerJax, inverse_svd_transform -import fiesta.conversions as conversions -from fiesta import models_utilities +import fiestaEM.train.neuralnets as fiestaEM_nn +from fiestaEM.utils import MinMaxScalerJax, inverse_svd_transform +import fiestaEM.conversions as conversions +from fiestaEM import models_utilities ######################## ### ABSTRACT CLASSES ### @@ -184,7 +184,7 @@ def load_networks(self) -> None: self.models = {} for filter in self.filters: filename = os.path.join(self.directory, f"{filter}.pkl") - state, _ = fiesta_nn.load_model(filename) + state, _ = fiestaEM_nn.load_model(filename) self.models[filter] = state def load_parameter_names(self) -> None: diff --git a/src/fiesta/inference/likelihood.py b/src/fiestaEM/inference/likelihood.py similarity index 98% rename from src/fiesta/inference/likelihood.py rename to src/fiestaEM/inference/likelihood.py index 75023ec..db5d8b1 100644 --- a/src/fiesta/inference/likelihood.py +++ b/src/fiestaEM/inference/likelihood.py @@ -6,9 +6,9 @@ from jaxtyping import Float, Array import jax.numpy as jnp -from fiesta.inference.lightcurve_model import LightcurveModel -from fiesta.utils import truncated_gaussian -from fiesta.conversions import mag_app_from_mag_abs +from fiestaEM.inference.lightcurve_model import LightcurveModel +from fiestaEM.utils import truncated_gaussian +from fiestaEM.conversions import mag_app_from_mag_abs class EMLikelihood: diff --git a/src/fiesta/inference/prior.py b/src/fiestaEM/inference/prior.py similarity index 100% rename from src/fiesta/inference/prior.py rename to src/fiestaEM/inference/prior.py diff --git a/src/fiesta/models_utilities.py b/src/fiestaEM/models_utilities.py similarity index 100% rename from src/fiesta/models_utilities.py rename to src/fiestaEM/models_utilities.py diff --git a/src/fiesta/train/Benchmarker.py b/src/fiestaEM/train/Benchmarker.py similarity index 97% rename from src/fiesta/train/Benchmarker.py rename to src/fiestaEM/train/Benchmarker.py index d293b51..8e2b071 100644 --- a/src/fiesta/train/Benchmarker.py +++ b/src/fiestaEM/train/Benchmarker.py @@ -1,9 +1,9 @@ -from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel +from fiestaEM.inference.lightcurve_model import AfterglowpyLightcurvemodel import afterglowpy as grb -from fiesta.constants import days_to_seconds -from fiesta import conversions -from fiesta import utils -from fiesta.utils import Filter +from fiestaEM.constants import days_to_seconds +from fiestaEM import conversions +from fiestaEM import utils +from fiestaEM.utils import Filter from jaxtyping import Array, Float diff --git a/src/fiesta/train/SurrogateTrainer.py b/src/fiestaEM/train/SurrogateTrainer.py similarity index 97% rename from src/fiesta/train/SurrogateTrainer.py rename to src/fiestaEM/train/SurrogateTrainer.py index b703fa9..319ef24 100644 --- a/src/fiesta/train/SurrogateTrainer.py +++ b/src/fiestaEM/train/SurrogateTrainer.py @@ -6,13 +6,13 @@ import jax import jax.numpy as jnp from jaxtyping import Array, Float, Int -from fiesta.utils import MinMaxScalerJax -from fiesta import utils -from fiesta.utils import Filter -from fiesta import conversions -from fiesta.constants import days_to_seconds, c -from fiesta import models_utilities -import fiesta.train.neuralnets as fiesta_nn +from fiestaEM.utils import MinMaxScalerJax +from fiestaEM import utils +from fiestaEM.utils import Filter +from fiestaEM import conversions +from fiestaEM.constants import days_to_seconds, c +from fiestaEM import models_utilities +import fiestaEM.train.neuralnets as fiestaEM_nn import matplotlib.pyplot as plt import pickle @@ -51,7 +51,7 @@ class SurrogateTrainer: val_X_raw: Float[Array, "n_batch n_params"] val_y_raw: dict[str, Float[Array, "n_batch n_times"]] - trained_states: dict[str, fiesta_nn.TrainState] + trained_states: dict[str, fiestaEM_nn.TrainState] def __init__(self, name: str, @@ -107,7 +107,7 @@ def preprocess(self): print("Preprocessing data . . . done") def fit(self, - config: fiesta_nn.NeuralnetConfig = None, + config: fiestaEM_nn.NeuralnetConfig = None, key: jax.random.PRNGKey = jax.random.PRNGKey(0), verbose: bool = True): """ @@ -119,7 +119,7 @@ def fit(self, # Get default choices if no config is given if config is None: - config = fiesta_nn.NeuralnetConfig() + config = fiestaEM_nn.NeuralnetConfig() self.config = config trained_states = {} @@ -128,12 +128,12 @@ def fit(self, for filt in self.filters: # Create neural network and initialize the state - net = fiesta_nn.MLP(layer_sizes=config.layer_sizes) + net = fiestaEM_nn.MLP(layer_sizes=config.layer_sizes) key, subkey = jax.random.split(key) - state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) + state = fiestaEM_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config) # Perform training loop - state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) + state, train_losses, val_losses = fiestaEM_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose) # Plot and save the plot if so desired if self.plots_dir is not None: @@ -180,7 +180,7 @@ def save(self): for filt in self.filters: model = self.trained_states[filt.name] - fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl") + fiestaEM_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl") save[filt.name] = self.preprocessing_metadata[filt.name] with open(meta_filename, "wb") as meta_file: diff --git a/src/fiesta/train/neuralnets.py b/src/fiestaEM/train/neuralnets.py similarity index 100% rename from src/fiesta/train/neuralnets.py rename to src/fiestaEM/train/neuralnets.py diff --git a/src/fiesta/utils.py b/src/fiestaEM/utils.py similarity index 100% rename from src/fiesta/utils.py rename to src/fiestaEM/utils.py