Skip to content

Commit

Permalink
Basic cleaning
Browse files Browse the repository at this point in the history
  • Loading branch information
ThibeauWouters committed Oct 17, 2024
1 parent 4ae6867 commit 8403e9f
Show file tree
Hide file tree
Showing 9 changed files with 24 additions and 31 deletions.
6 changes: 5 additions & 1 deletion src/fiesta/conversions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,4 +18,8 @@ def mJys_to_mag_np(mJys: np.array):
def mJys_to_mag_jnp(mJys: Array):
Jys = 1e-3 * mJys
mag = -48.6 + -1 * jnp.log10(Jys / 1e23) * 2.5
return mag
return mag

def mag_app_from_mag_abs(mag_abs: Array,
luminosity_distance: Float) -> Array:
return mag_abs + 5.0 * jnp.log10(luminosity_distance * 1e6 / 10.0)
2 changes: 1 addition & 1 deletion src/fiesta/inference/fiesta.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.inference.prior import Prior
from fiesta.inference.likelihood import EMLikelihood
from fiesta.utils import mag_app_from_mag_abs
from fiesta.conversions import mag_app_from_mag_abs

from flowMC.sampler.Sampler import Sampler
from flowMC.sampler.MALA import MALA
Expand Down
2 changes: 1 addition & 1 deletion src/fiesta/inference/injection.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from jaxtyping import Float, Array

from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.utils import mag_app_from_mag_abs
from fiesta.conversions import mag_app_from_mag_abs
from fiesta.utils import Filter
from fiesta.constants import days_to_seconds, c
from fiesta import conversions
Expand Down
3 changes: 2 additions & 1 deletion src/fiesta/inference/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
import jax.numpy as jnp

from fiesta.inference.lightcurve_model import LightcurveModel
from fiesta.utils import mag_app_from_mag_abs, truncated_gaussian
from fiesta.utils import truncated_gaussian
from fiesta.conversions import mag_app_from_mag_abs

class EMLikelihood:

Expand Down
9 changes: 0 additions & 9 deletions src/fiesta/models_utilities.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,6 @@
"""Utilities regarding (surrogate) models."""
import jax
import jax.numpy as jnp
from jax.scipy.stats import truncnorm
from jaxtyping import Array, Float
import numpy as np
import pandas as pd
import scipy.interpolate as interp
import copy
import re
from sncosmo.bandpasses import _BANDPASSES
from astropy.time import Time

####################
### BULLA MODELS ###
Expand Down
17 changes: 12 additions & 5 deletions src/fiesta/train/Benchmarker.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from fiesta.inference.lightcurve_model import AfterglowpyLightcurvemodel
import afterglowpy as grb
from fiesta.constants import days_to_seconds, c
from fiesta.constants import days_to_seconds
from fiesta import conversions
from fiesta import utils
from fiesta.utils import Filter

from jaxtyping import Array, Float, Int
from jaxtyping import Array, Float

import tqdm
import os
Expand All @@ -15,8 +16,17 @@

from scipy.integrate import trapezoid

# TODO: get a benchmarker class for all surrogate model
class Benchmarker:

name: str
model_dir: str
filters: list[Filter]
n_test_data: int
metric_name: str
jet_type: int
model: AfterglowpyLightcurvemodel

def __init__(self,
name: str,
model_dir: str,
Expand Down Expand Up @@ -52,9 +62,6 @@ def __init__(self,
self.metric = lambda y: np.sqrt(trapezoid(x= self.times[mask],y=y[mask]**2))
else:
self.metric = lambda y: np.max(np.abs(y[mask]))




def __repr__(self) -> str:
return f"Surrogate_Benchmarker(name={self.name}, model_dir={self.model_dir})"
Expand Down
1 change: 0 additions & 1 deletion src/fiesta/train/SurrogateTrainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ class SurrogateTrainer:
val_X_raw: Float[Array, "n_batch n_params"]
val_y_raw: dict[str, Float[Array, "n_batch n_times"]]


trained_states: dict[str, fiesta_nn.TrainState]

def __init__(self,
Expand Down
3 changes: 1 addition & 2 deletions src/fiesta/train/neuralnets.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,8 +215,7 @@ def serialize(state: TrainState,
params = flax.serialization.to_state_dict(state)["params"]

serialized_dict = {"params": params,
"config": config,
}
"config": config}

return serialized_dict

Expand Down
12 changes: 2 additions & 10 deletions src/fiesta/utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import jax
import jax.numpy as jnp
from jax.scipy.stats import truncnorm
from jaxtyping import Array, Float
Expand All @@ -8,7 +7,6 @@
import copy
import re
from astropy.time import Time
from fiesta.constants import pc_to_cm
import astropy
import scipy
from sncosmo.bandpasses import _BANDPASSES, _BANDPASS_INTERPOLATORS
Expand Down Expand Up @@ -49,12 +47,6 @@ def inverse_svd_transform(x: Array,
# TODO: check the shapes etc, transforms and those things
return jnp.dot(VA[:, :nsvd_coeff], x)

# @jax.jit
# TODO: change place!
def mag_app_from_mag_abs(mag_abs: Array,
luminosity_distance: Float) -> Array:
return mag_abs + 5.0 * jnp.log10(luminosity_distance * 1e6 / 10.0)


#######################
### BULLA UTILITIES ###
Expand Down Expand Up @@ -126,8 +118,6 @@ def read_single_bulla_file(filename: str) -> dict:

return lc_data



#########################
### GENERAL UTILITIES ###
#########################
Expand Down Expand Up @@ -246,6 +236,7 @@ def __init__(self,
bandpass = sncosmo.get_bandpass(self.name)
self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10)
elif (self.name, None) in _BANDPASS_INTERPOLATORS._primary_loaders:
# FIXME: val undefined
bandpass = sncosmo.get_bandpass(val["name"], 3)
self.nu = scipy.constants.c/(bandpass.wave_eff*1e-10)
elif self.name.endswith("GHz"):
Expand Down Expand Up @@ -386,6 +377,7 @@ def get_default_filts_lambdas(filters: list[str]=None):
filts = filts_slice
lambdas = np.array(lambdas_slice)

# FIXME: transmittance undefined
return filts, lambdas, transmittance


Expand Down

0 comments on commit 8403e9f

Please sign in to comment.