Skip to content

Commit

Permalink
Add peak regression configs
Browse files Browse the repository at this point in the history
  • Loading branch information
LukasMahieu committed Jun 15, 2024
1 parent 2e0aeaa commit 5f2da8f
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 1 deletion.
34 changes: 33 additions & 1 deletion src/crested/tl/_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,14 @@

import tensorflow as tf

from crested.tl.losses import CosineMSELoss
from crested.tl.metrics import (
ConcordanceCorrelationCoefficient,
PearsonCorrelation,
PearsonCorrelationLog,
ZeroPenaltyMetric,
)


class BaseConfig(ABC):
"""Base configuration class for tasks."""
Expand Down Expand Up @@ -63,6 +71,30 @@ def metrics(self) -> list[tf.keras.metrics.Metric]:
]


class PeakRegressionConfig(BaseConfig):
"""Default configuration for peak regression task."""

@property
def loss(self) -> tf.keras.losses.Loss:
return CosineMSELoss()

@property
def optimizer(self) -> tf.keras.optimizers.Optimizer:
return tf.keras.optimizers.Adam(learning_rate=1e-3)

@property
def metrics(self) -> list[tf.keras.metrics.Metric]:
return [
tf.keras.metrics.MeanAbsoluteError(),
tf.keras.metrics.MeanSquaredError(),
tf.keras.metrics.CosineSimilarity(axis=1),
PearsonCorrelation(),
ConcordanceCorrelationCoefficient(),
PearsonCorrelationLog(),
ZeroPenaltyMetric(),
]


class TaskConfig(NamedTuple):
"""
Task configuration (optimizer, loss, and metrics) for use in tl.Crested.
Expand Down Expand Up @@ -127,7 +159,7 @@ def default_configs(
"""
task_classes = {
"topic_classification": TopicClassificationConfig,
# Add other tasks and their corresponding classes here
"peak_regression": PeakRegressionConfig,
}

if task not in task_classes:
Expand Down
6 changes: 6 additions & 0 deletions src/crested/tl/losses/_cosinemse.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
"""Default loss for peak regression task."""

from __future__ import annotations

import tensorflow as tf
from tensorflow.keras.losses import Loss, Reduction
from tensorflow.python.keras import backend as K


class CosineMSELoss(Loss):
"""Custom loss function that combines cosine similarity and mean squared error."""

def __init__(self, reduction=Reduction.SUM, name="CosineMSELoss"):
super().__init__(reduction=reduction, name=name)

Expand Down

0 comments on commit 5f2da8f

Please sign in to comment.