I3Module and I3Segments with tf v2 changes
mhuen committed Feb 24, 2025
1 parent a278be6 commit ab06dd1
Showing 4 changed files with 78 additions and 75 deletions.
17 changes: 5 additions & 12 deletions dnn_reco/export_model.py

```diff
@@ -237,17 +237,8 @@ def export_data_settings(data_settings, output_folder, config):
     config : dict
         Configuration of the NN model.
     """
-    try:
-        with open(data_settings, "r") as stream:
-            data_config = yaml_loader.load(stream)
-    except Exception as e:
-        print(e)
-        print("Falling back to modified SafeLoader")
-        with open(data_settings, "r") as stream:
-            # yaml.SafeLoader.add_constructor(
-            #     "tag:yaml.org,2002:python/unicode", lambda _, node: node.value
-            # )
-            data_config = dict(yaml_loader.load(stream))
+    with open(data_settings, "r") as stream:
+        data_config = yaml_loader.load(stream)
 
     for k in [
         "pulse_time_quantiles",
@@ -273,6 +264,9 @@ def export_data_settings(data_settings, output_folder, config):
         print("Could not extract data settings. Aborting export!")
         raise e
 
+    if "is_str_dom_format" not in data_settings:
+        data_settings["is_str_dom_format"] = False
+
     print("\n=========================")
     print("= Found Data Settings: =")
     print("=========================")
@@ -309,7 +303,6 @@ def ic3_processing_scripts(data_config, config):
         for segment_cfg in step_config["tray_segments"]:
             if segment_cfg["ModuleClass"] == "ic3_data.segments.CreateDNNData":
 
-                print("segment_cfg", segment_cfg)
                 # check correct output names
                 if "OutputKey" in segment_cfg["ModuleKwargs"]:
                     base = segment_cfg["ModuleKwargs"]["OutputKey"]
```
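Note on the new default: exported data settings written before the `is_str_dom_format` key existed are now treated as the historical non-string-DOM format. A minimal sketch of the idea, with hypothetical surrounding values:

```python
# Hypothetical sketch: default a missing key when loading an older export.
data_settings = {"pulse_time_quantiles": [0.2, 0.5]}  # pre-dates the new key
data_settings.setdefault("is_str_dom_format", False)  # historical default
print(data_settings["is_str_dom_format"])  # -> False
```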
7 changes: 3 additions & 4 deletions dnn_reco/ic3/modules.py

```diff
@@ -110,9 +110,7 @@ def Configure(self):
         self.config = setup_manager.get_config()
 
         # ToDo: Adjust necessary values in config
-        self.config["model_checkpoint_path"] = os.path.join(
-            self._model_path, "model"
-        )
+        self.config["model_checkpoint_path"] = self._model_path
         self.config["model_kwargs"]["is_training"] = False
         self.config["trafo_model_path"] = os.path.join(
             self._model_path, "trafo_model.npy"
@@ -223,14 +221,15 @@ def Configure(self):
             config=self.config,
             data_handler=self.data_handler,
             data_transformer=self.data_transformer,
+            verbose=False,
             **model_kwargs
         )
 
         # compile model: initialize and finalize graph
         self.model.compile()
 
         # restore model weights
-        self.model.restore()
+        self.model.restore(is_training=False)
 
         # Get trained labels, e.g. labels with weights greater than zero
         self._mask_labels = (
```
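The checkpoint-path change above reflects the TF1-to-TF2 API shift: `tf.compat.v1.train.Saver` wrote to a file prefix such as `<model_path>/model`, whereas `tf.train.CheckpointManager` is pointed at the checkpoint directory itself. A standalone sketch of the TF2 side (plain TensorFlow, not dnn_reco code; the directory is illustrative):

```python
import tensorflow as tf

# A TF2 checkpoint tracks objects; the manager owns a directory, not a file prefix.
ckpt = tf.train.Checkpoint(step=tf.Variable(0))
manager = tf.train.CheckpointManager(ckpt, directory="/tmp/demo_model", max_to_keep=1)

manager.save()
print(manager.latest_checkpoint)  # e.g. /tmp/demo_model/ckpt-1
```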
125 changes: 68 additions & 57 deletions dnn_reco/model.py

```diff
@@ -105,11 +105,11 @@ def __init__(
                 os.path.join(self.config["log_path"], "val")
             )
 
-            # create label weights and non zero mask
-            self._create_label_weights()
+        # create label weights and non zero mask
+        self._create_label_weights()
 
-            # create variables necessary for tukey loss
-            self._create_tukey_vars()
+        # create variables necessary for tukey loss
+        self._create_tukey_vars()
 
     def _setup_directories(self):
         """Creates necessary directories"""
@@ -207,11 +207,12 @@ def _create_label_weights(self):
 
         self.shared_objects["label_weights"] = label_weights
 
-        misc.print_warning(
-            "Total Benchmark should be: {:3.3f}".format(
-                sum(self.shared_objects["label_weight_config"])
+        if self.is_training:
+            misc.print_warning(
+                "Total Benchmark should be: {:3.3f}".format(
+                    sum(self.shared_objects["label_weight_config"])
+                )
             )
-        )
 
     def _update_tukey_vars(self, new_values, tukey_decay=0.001):
         """Update tukey variables"""
```
```diff
@@ -779,7 +780,7 @@ def compile(self):
         msg += f"\tTotal: {num_total_vars}"
         self._logger.info(msg)
 
-    def restore(self):
+    def restore(self, is_training=True):
         """Restore model weights from checkpoints"""
         latest_checkpoint = self._checkpoint_manager.latest_checkpoint
         if latest_checkpoint is None:
@@ -790,7 +791,63 @@ def restore(self):
             self._logger.info(
                 f"[Model] Loading checkpoint: {latest_checkpoint}"
             )
-            self._checkpoint.restore(latest_checkpoint).assert_consumed()
+            status = self._checkpoint.restore(latest_checkpoint)
+            if is_training:
+                status.assert_consumed()
+            else:
+                status.expect_partial()
 
+    def predict(
+        self, x_ic78, x_deepcore, transformed=False, is_training=False
+    ):
+        """Reconstruct events.
+
+        Parameters
+        ----------
+        x_ic78 : float, list or numpy.ndarray
+            The input data for the main IceCube array.
+        x_deepcore : float, list or numpy.ndarray
+            The input data for the DeepCore array.
+        transformed : bool, optional
+            If true, the normalized and transformed values are returned.
+        is_training : bool, optional
+            True if model is in training mode, false if in inference mode.
+
+        Returns
+        -------
+        np.ndarray, np.ndarray
+            The prediction and estimated uncertainties
+        """
+        data_batch_dict = {
+            "x_ic78": x_ic78,
+            "x_deepcore": x_deepcore,
+            "x_ic78_trafo": self.data_transformer.transform(
+                x_ic78, data_type="ic78"
+            ),
+            "x_deepcore_trafo": self.data_transformer.transform(
+                x_deepcore, data_type="deepcore"
+            ),
+        }
+        result_tensors = self(data_batch_dict, is_training=is_training)
+
+        if transformed:
+            return_values = (
+                result_tensors["y_pred_trafo"].numpy(),
+                result_tensors["y_unc_pred_trafo"].numpy(),
+            )
+        else:
+            # transform back
+            y_pred = self.data_transformer.inverse_transform(
+                result_tensors["y_pred_trafo"], data_type="label"
+            ).numpy()
+            y_unc = self.data_transformer.inverse_transform(
+                result_tensors["y_unc_pred_trafo"],
+                data_type="label",
+                bias_correction=False,
+            ).numpy()
+            return_values = (y_pred, y_unc)
+
+        return return_values
+
     def predict_batched(
         self, x_ic78, x_deepcore, max_size, transformed=False, *args, **kwargs
```
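The new `is_training` flag maps onto TensorFlow 2's checkpoint-status API: `assert_consumed()` fails unless every value in the checkpoint file was matched to a tracked object, while `expect_partial()` silences the warnings for state that is deliberately left unloaded (typically optimizer slots at inference time). A self-contained illustration in plain TensorFlow 2, not dnn_reco's classes (paths and variable names are made up):

```python
import tensorflow as tf

# Training saves model weights *and* extra training state (optimizer slots).
weights = tf.Variable(tf.ones((2, 2)), name="weights")
adam_m = tf.Variable(tf.zeros((2, 2)), name="adam_m")  # stand-in for a slot
train_ckpt = tf.train.Checkpoint(weights=weights, adam_m=adam_m)
manager = tf.train.CheckpointManager(train_ckpt, "/tmp/demo_ckpt", max_to_keep=1)
manager.save()

# Inference tracks only the model weights; the saved slot is never consumed.
infer_ckpt = tf.train.Checkpoint(weights=tf.Variable(tf.zeros((2, 2))))
status = infer_ckpt.restore(manager.latest_checkpoint)

# status.assert_consumed()  # training mode: would raise, "adam_m" unmatched
status.expect_partial()     # inference mode: silence the partial-restore warning
```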
```diff
@@ -841,52 +898,6 @@ def predict_batched(
 
         return y_pred, y_unc
 
-    def predict(
-        self, x_ic78, x_deepcore, transformed=False, is_training=False
-    ):
-        """Reconstruct events.
-
-        Parameters
-        ----------
-        x_ic78 : float, list or numpy.ndarray
-            The input data for the main IceCube array.
-        x_deepcore : float, list or numpy.ndarray
-            The input data for the DeepCore array.
-        transformed : bool, optional
-            If true, the normalized and transformed values are returned.
-        is_training : bool, optional
-            True if model is in training mode, false if in inference mode.
-
-        Returns
-        -------
-        np.ndarray, np.ndarray
-            The prediction and estimated uncertainties
-        """
-        data_batch_dict = {
-            "x_ic78": x_ic78,
-            "x_deepcore": x_deepcore,
-        }
-        result_tensors = self(data_batch_dict, is_training=is_training)
-
-        if transformed:
-            return_values = (
-                result_tensors["y_pred_trafo"].numpy(),
-                result_tensors["y_unc_pred_trafo"].numpy(),
-            )
-        else:
-            # transform back
-            y_pred = self.data_transformer.inverse_transform(
-                result_tensors["y_pred_trafo"], data_type="label"
-            ).numpy()
-            y_unc = self.data_transformer.inverse_transform(
-                result_tensors["y_unc_pred_trafo"],
-                data_type="label",
-                bias_correction=False,
-            ).numpy()
-            return_values = (y_pred, y_unc)
-
-        return return_values
-
     def fit(
         self,
         num_training_iterations,
```
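For reference, the chunking pattern that `predict_batched` (kept above) wraps around `predict` looks roughly like this standalone sketch, with NumPy only and a stub in place of the model (shapes are made up):

```python
import numpy as np

def predict(x_ic78, x_deepcore):
    """Stub standing in for Model.predict: returns (y_pred, y_unc)."""
    n = len(x_ic78)
    return np.zeros((n, 3)), np.ones((n, 3))

def predict_batched(x_ic78, x_deepcore, max_size):
    """Run predict() on chunks of at most max_size events, then concatenate."""
    y_pred_list, y_unc_list = [], []
    for start in range(0, len(x_ic78), max_size):
        stop = start + max_size
        y_pred, y_unc = predict(x_ic78[start:stop], x_deepcore[start:stop])
        y_pred_list.append(y_pred)
        y_unc_list.append(y_unc)
    return np.concatenate(y_pred_list), np.concatenate(y_unc_list)

y_pred, y_unc = predict_batched(np.zeros((10, 4)), np.zeros((10, 2)), max_size=4)
print(y_pred.shape, y_unc.shape)  # (10, 3) (10, 3)
```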
```diff
@@ -1161,7 +1172,7 @@ def _save_training_config(self, iteration):
         ValueError
             Description
         """
-        if iteration == 0:
+        if iteration <= self.config["save_frequency"]:
             if not self.config["model_restore_model"]:
                 # Delete old training config files and create a new and empty
                 # training_steps.txt, since we are training a new model
```
4 changes: 2 additions & 2 deletions dnn_reco/modules/models/general_IC86_cnn.py

```diff
@@ -325,11 +325,11 @@ def _apply_convolutions(
         tf.Tensor
             The flattened and combined output of the convolutional layers.
         """
-        x_ic78_trafo = tf.convert_to_tensor(
+        x_ic78_trafo = tf.cast(
             data_batch_dict["x_ic78_trafo"],
             dtype=self.dtype,
         )
-        x_deepcore_trafo = tf.convert_to_tensor(
+        x_deepcore_trafo = tf.cast(
             data_batch_dict["x_deepcore_trafo"],
             dtype=self.dtype,
         )
```
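The swap to `tf.cast` matters because `predict()` now feeds already-built tensors (the transformed inputs) into `data_batch_dict`: `tf.convert_to_tensor` refuses to reinterpret an existing tensor under a different dtype, while `tf.cast` performs the conversion. A standalone demonstration:

```python
import tensorflow as tf

x = tf.ones((2, 3), dtype=tf.float64)  # already a tensor, "wrong" dtype

try:
    tf.convert_to_tensor(x, dtype=tf.float32)  # cannot re-dtype a tensor
except ValueError as err:
    print("convert_to_tensor failed:", err)

y = tf.cast(x, dtype=tf.float32)  # explicit dtype conversion works
print(y.dtype)  # <dtype: 'float32'>
```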
