Skip to content

Commit

Permalink
Be more strict about jax version, and change name from fiesta to fies…
Browse files Browse the repository at this point in the history
…taEM in src code
  • Loading branch information
ThibeauWouters committed Dec 19, 2024
1 parent a49e0db commit 65ce549
Show file tree
Hide file tree
Showing 19 changed files with 40 additions and 117 deletions.
4 changes: 2 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ packages_dir=
=src
packages = find:
install_requires =
jax>=0.4.24
jaxlib>=0.4.24
jax<=0.4.31
jaxlib<=0.4.31
numpy<2.0.0
pandas<2.0.0
jaxtyping
Expand Down
1 change: 1 addition & 0 deletions src/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
fiestaEM.egg-info/
54 changes: 0 additions & 54 deletions src/fiesta.egg-info/PKG-INFO

This file was deleted.

9 changes: 0 additions & 9 deletions src/fiesta.egg-info/SOURCES.txt

This file was deleted.

1 change: 0 additions & 1 deletion src/fiesta.egg-info/dependency_links.txt

This file was deleted.

13 changes: 0 additions & 13 deletions src/fiesta.egg-info/requires.txt

This file was deleted.

1 change: 0 additions & 1 deletion src/fiesta.egg-info/top_level.txt

This file was deleted.

File renamed without changes.
2 changes: 1 addition & 1 deletion src/fiesta/conversions.py → src/fiestaEM/conversions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from fiesta.constants import pc_to_cm
from fiestaEM.constants import pc_to_cm
import jax
import jax.numpy as jnp
from jaxtyping import Array
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,10 @@
import jax.numpy as jnp
from jaxtyping import Float, Array, PRNGKeyArray

from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.inference.prior import Prior
from fiesta.inference.likelihood import EMLikelihood
from fiesta.conversions import mag_app_from_mag_abs
from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.inference.prior import Prior
from fiestaEM.inference.likelihood import EMLikelihood
from fiestaEM.conversions import mag_app_from_mag_abs

from flowMC.sampler.Sampler import Sampler
from flowMC.sampler.MALA import MALA
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,11 +8,11 @@
import jax.numpy as jnp
from jaxtyping import Float, Array

from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.conversions import mag_app_from_mag_abs
from fiesta.utils import Filter
from fiesta.constants import days_to_seconds, c
from fiesta import conversions
from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.conversions import mag_app_from_mag_abs
from fiestaEM.utils import Filter
from fiestaEM.constants import days_to_seconds, c
from fiestaEM import conversions

import afterglowpy as grb

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@
from flax.training.train_state import TrainState
import pickle

import fiesta.train.neuralnets as fiesta_nn
from fiesta.utils import MinMaxScalerJax, inverse_svd_transform
import fiesta.conversions as conversions
from fiesta import models_utilities
import fiestaEM.train.neuralnets as fiestaEM_nn
from fiestaEM.utils import MinMaxScalerJax, inverse_svd_transform
import fiestaEM.conversions as conversions
from fiestaEM import models_utilities

########################
### ABSTRACT CLASSES ###
Expand Down Expand Up @@ -184,7 +184,7 @@ def load_networks(self) -> None:
self.models = {}
for filter in self.filters:
filename = os.path.join(self.directory, f"{filter}.pkl")
state, _ = fiesta_nn.load_model(filename)
state, _ = fiestaEM_nn.load_model(filename)
self.models[filter] = state

def load_parameter_names(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,9 @@
from jaxtyping import Float, Array
import jax.numpy as jnp

from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.utils import truncated_gaussian
from fiesta.conversions import mag_app_from_mag_abs
from fiestaEM.inference.lightcurve_model import LightcurveModel
from fiestaEM.utils import truncated_gaussian
from fiestaEM.conversions import mag_app_from_mag_abs

class EMLikelihood:

Expand Down
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
from fiestaEM.inference.lightcurve_model import AfterglowpyLightcurvemodel
import afterglowpy as grb
from fiesta.constants import days_to_seconds
from fiesta import conversions
from fiesta import utils
from fiesta.utils import Filter
from fiestaEM.constants import days_to_seconds
from fiestaEM import conversions
from fiestaEM import utils
from fiestaEM.utils import Filter

from jaxtyping import Array, Float

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,13 @@
import jax
import jax.numpy as jnp
from jaxtyping import Array, Float, Int
from fiesta.utils import MinMaxScalerJax
from fiesta import utils
from fiesta.utils import Filter
from fiesta import conversions
from fiesta.constants import days_to_seconds, c
from fiesta import models_utilities
import fiesta.train.neuralnets as fiesta_nn
from fiestaEM.utils import MinMaxScalerJax
from fiestaEM import utils
from fiestaEM.utils import Filter
from fiestaEM import conversions
from fiestaEM.constants import days_to_seconds, c
from fiestaEM import models_utilities
import fiestaEM.train.neuralnets as fiestaEM_nn

import matplotlib.pyplot as plt
import pickle
Expand Down Expand Up @@ -51,7 +51,7 @@ class SurrogateTrainer:
val_X_raw: Float[Array, "n_batch n_params"]
val_y_raw: dict[str, Float[Array, "n_batch n_times"]]

trained_states: dict[str, fiesta_nn.TrainState]
trained_states: dict[str, fiestaEM_nn.TrainState]

def __init__(self,
name: str,
Expand Down Expand Up @@ -107,7 +107,7 @@ def preprocess(self):
print("Preprocessing data . . . done")

def fit(self,
config: fiesta_nn.NeuralnetConfig = None,
config: fiestaEM_nn.NeuralnetConfig = None,
key: jax.random.PRNGKey = jax.random.PRNGKey(0),
verbose: bool = True):
"""
Expand All @@ -119,7 +119,7 @@ def fit(self,

# Get default choices if no config is given
if config is None:
config = fiesta_nn.NeuralnetConfig()
config = fiestaEM_nn.NeuralnetConfig()
self.config = config

trained_states = {}
Expand All @@ -128,12 +128,12 @@ def fit(self,
for filt in self.filters:

# Create neural network and initialize the state
net = fiesta_nn.MLP(layer_sizes=config.layer_sizes)
net = fiestaEM_nn.MLP(layer_sizes=config.layer_sizes)
key, subkey = jax.random.split(key)
state = fiesta_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config)
state = fiestaEM_nn.create_train_state(net, jnp.ones(input_ndim), subkey, config)

# Perform training loop
state, train_losses, val_losses = fiesta_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose)
state, train_losses, val_losses = fiestaEM_nn.train_loop(state, config, self.train_X, self.train_y[filt.name], self.val_X, self.val_y[filt.name], verbose=verbose)

# Plot and save the plot if so desired
if self.plots_dir is not None:
Expand Down Expand Up @@ -180,7 +180,7 @@ def save(self):

for filt in self.filters:
model = self.trained_states[filt.name]
fiesta_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl")
fiestaEM_nn.save_model(model, self.config, out_name=self.outdir + f"{filt.name}.pkl")
save[filt.name] = self.preprocessing_metadata[filt.name]

with open(meta_filename, "wb") as meta_file:
Expand Down
File renamed without changes.
File renamed without changes.

0 comments on commit 65ce549

Please sign in to comment.