Skip to content

Commit

Permalink
Rename TransformBase to TransformerBase
Browse files Browse the repository at this point in the history
  • Loading branch information
schmoelder committed Jan 22, 2025
1 parent b415222 commit a4e2d46
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 70 deletions.
99 changes: 67 additions & 32 deletions CADETProcess/optimization/optimizationProblem.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
ParallelizationBackendBase, SequentialBackend
)
from CADETProcess.transform import (
NoTransform, AutoTransform, NormLinearTransform, NormLogTransform
NullTransformer, AutoTransformer, NormLinearTransformer, NormLogTransformer
)

from CADETProcess.metric import MetricBase
Expand Down Expand Up @@ -1944,7 +1944,7 @@ def lower_bounds_transformed(self):
upper_bounds
"""
return [var.transform.lb for var in self.variables]
return [var.transformer.lb for var in self.variables]

@property
def lower_bounds_independent(self):
Expand All @@ -1966,7 +1966,7 @@ def lower_bounds_independent_transformed(self):
upper_bounds
"""
return [var.transform.lb for var in self.independent_variables]
return [var.transformer.lb for var in self.independent_variables]

@property
def upper_bounds(self):
Expand All @@ -1988,7 +1988,7 @@ def upper_bounds_transformed(self):
upper_bounds
"""
return [var._transform.ub for var in self.variables]
return [var.transformer.ub for var in self.variables]

@property
def upper_bounds_independent(self):
Expand All @@ -2010,7 +2010,7 @@ def upper_bounds_independent_transformed(self):
upper_bounds
"""
return [var._transform.ub for var in self.independent_variables]
return [var.transformer.ub for var in self.independent_variables]

@untransforms
@gets_dependent_values
Expand Down Expand Up @@ -2211,8 +2211,8 @@ def A_transformed(self):
A_t = self.A.copy()
for a in A_t:
for j, v in enumerate(self.variables):
t = v.transform
if isinstance(t, NoTransform):
t = v.transformer
if isinstance(t, NullTransformer):
continue

if a[j] != 0 and not t.is_linear:
Expand Down Expand Up @@ -2291,8 +2291,8 @@ def b_transformed(self):
A_t = self.A.copy()
for i, a in enumerate(A_t):
for j, v in enumerate(self.variables):
t = v.transform
if isinstance(t, NoTransform):
t = v.transformer
if isinstance(t, NullTransformer):
continue

if a[j] != 0 and not t.is_linear:
Expand Down Expand Up @@ -2510,8 +2510,8 @@ def Aeq_transformed(self):
Aeq_t = self.Aeq.copy()
for aeq in Aeq_t:
for j, v in enumerate(self.variables):
t = v.transform
if isinstance(t, NoTransform):
t = v.transformer
if isinstance(t, NullTransformer):
continue

if aeq[j] != 0 and not t.is_linear:
Expand Down Expand Up @@ -2589,8 +2589,8 @@ def beq_transformed(self):
Aeq_t = self.Aeq.copy()
for i, aeq in enumerate(Aeq_t):
for j, v in enumerate(self.variables):
t = v.transform
if isinstance(t, NoTransform):
t = v.transformer
if isinstance(t, NullTransformer):
continue

if aeq[j] != 0 and not t.is_linear:
Expand Down Expand Up @@ -2714,7 +2714,7 @@ def transform(self, X_independent: npt.ArrayLike) -> np.ndarray:

for i, ind in enumerate(X_independent):
transform[i, :] = [
var.transform_fun(value)
var.transform(value)
for value, var in zip(ind, self.independent_variables)
]

Expand All @@ -2738,7 +2738,7 @@ def untransform(self, X_transformed: npt.ArrayLike) -> np.ndarray:

for i, ind in enumerate(X_transformed):
untransform[i, :] = [
var.untransform_fun(value, significant_digits=var.significant_digits)
var.untransform(value, significant_digits=var.significant_digits)
for value, var in zip(ind, self.independent_variables)
]

Expand Down Expand Up @@ -2836,11 +2836,11 @@ def compute_negative_log_likelihood(self, x):

for i, var in enumerate(variables):
if (
isinstance(var._transform, NormLogTransform)
isinstance(var.transformer, NormLogTransformer)
or
(
isinstance(var._transform, AutoTransform) and
var._transform.use_log
isinstance(var.transformer, AutoTransformer) and
var.transformer.use_log
)
):
log_space_indices.append(i)
Expand Down Expand Up @@ -3221,7 +3221,7 @@ def check_linear_constraints_transforms(self):
for constr in self.linear_constraints + self.linear_equality_constraints:
opt_vars = [self.variables_dict[key] for key in constr["opt_vars"]]
for var in opt_vars:
if not var.transform.is_linear:
if not var.transformer.is_linear:
flag = False
warnings.warn(
f"'{var.name}' uses non-linear transform and is used in "
Expand Down Expand Up @@ -3385,7 +3385,7 @@ class OptimizationVariable:
Lower bound of the variable.
ub : float
Upper bound of the variable.
transform : TransformBase
transformer : TransformerBase
Transformation function for parameter normalization.
indices : int, or slice
Indices for variables that modify an entry of a parameter array.
Expand Down Expand Up @@ -3430,20 +3430,20 @@ def __init__(
self.ub = ub

if transform is None:
transform = NoTransform(lb, ub)
transformer = NullTransformer(lb, ub)
else:
if np.isinf(lb) or np.isinf(ub):
raise CADETProcessError("Transform requires bound constraints.")
if transform == 'auto':
transform = AutoTransform(lb, ub)
transformer = AutoTransformer(lb, ub)
elif transform == 'linear':
transform = NormLinearTransform(lb, ub)
transformer = NormLinearTransformer(lb, ub)
elif transform == 'log':
transform = NormLogTransform(lb, ub)
transformer = NormLogTransformer(lb, ub)
else:
raise ValueError("Unknown transform")

self._transform = transform
self._transformer = transformer
self.significant_digits = significant_digits

self._dependencies = []
Expand Down Expand Up @@ -3667,14 +3667,49 @@ def n_indices(self):
return 0

@property
def transform(self):
return self._transform
def transformer(self) -> "TransformerBase":
"""TransformerBase: The variable transformer instance."""
return self._transformer

def transform_fun(self, x, *args, **kwargs):
return self._transform.transform(x, *args, **kwargs)
def transform(self, x: float, *args: Any, **kwargs: Any) -> float:
"""
Apply the transformation to the input.
Parameters
----------
x : float
The input data to be transformed.
*args : Any
Additional positional arguments passed to the transformer's `transform` method.
**kwargs : Any
Additional keyword arguments passed to the transformer's `transform` method.
Returns
-------
float
The transformed data.
"""
return self.transformer.transform(x, *args, **kwargs)

def untransform_fun(self, x, *args, **kwargs):
return self._transform.untransform(x, *args, **kwargs)
def untransform(self, x: float, *args: Any, **kwargs: Any) -> float:
"""
Apply the inverse transformation to the input.
Parameters
----------
x : float
The input data to be untransformed.
*args : Any
Additional positional arguments passed to the transformer's `untransform` method.
**kwargs : Any
Additional keyword arguments passed to the transformer's `untransform` method.
Returns
-------
float
The untransformed data.
"""
return self.transformer.untransform(x, *args, **kwargs)

def add_dependency(self, dependencies, transform):
"""Add dependency of Variable on other Variables.
Expand Down Expand Up @@ -3829,7 +3864,7 @@ def _set_value_in_evaluation_object(self, evaluation_object, value):
@property
def transformed_bounds(self):
"""list: Transformed bounds of the parameter."""
return [self.transform_fun(self.lb), self.transform_fun(self.ub)]
return [self.transform(self.lb), self.transform(self.ub)]

def __repr__(self):
if self.evaluation_objects is not None:
Expand Down
2 changes: 1 addition & 1 deletion CADETProcess/solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,7 +282,7 @@ def update(self):
self._dm_dt_interpolated = None

def update_transform(self):
self.transform = transform.NormLinearTransform(
self.transform = transform.NormLinearTransformer(
np.min(self.solution, axis=0),
np.max(self.solution, axis=0),
allow_extended_input=True,
Expand Down
54 changes: 27 additions & 27 deletions CADETProcess/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,11 @@
.. autosummary::
:toctree: generated/
TransformBase
NoTransform
NormLinearTransform
NormLogTransform
AutoTransform
TransformerBase
NullTransformer
NormLinearTransformer
NormLogTransformer
AutoTransformer
"""

Expand All @@ -29,7 +29,7 @@
from CADETProcess.numerics import round_to_significant_digits


class TransformBase(ABC):
class TransformerBase(ABC):
"""
Base class for parameter transformation.
Expand Down Expand Up @@ -65,11 +65,11 @@ class TransformBase(ABC):
Examples
--------
>>> class MyTransform(TransformBase):
>>> class MyTransformer(TransformerBase):
... def transform(self, x):
... return x ** 2
...
>>> t = MyTransform(lb_input=0, ub_input=10, lb=-100, ub=100)
>>> t = MyTransformer(lb_input=0, ub_input=10, lb=-100, ub=100)
>>> t.transform(3)
9
Expand All @@ -79,7 +79,7 @@ def __init__(
self,
lb_input=-np.inf, ub_input=np.inf,
allow_extended_input=False, allow_extended_output=False):
"""Initialize TransformBase
"""Initialize TransformerBase
Parameters
----------
Expand Down Expand Up @@ -275,14 +275,14 @@ def __str__(self):
return self.__class__.__name__


class NoTransform(TransformBase):
class NullTransformer(TransformerBase):
"""A class that implements no transformation.
Returns the input values without any transformation.
See Also
--------
TransformBase : The base class for parameter transformation.
TransformerBase : The base class for parameter transformation.
"""

@property
Expand Down Expand Up @@ -330,15 +330,15 @@ def _untransform(self, x):
return x


class NormLinearTransform(TransformBase):
class NormLinearTransformer(TransformerBase):
"""A class that implements a normalized linear transformation.
Transforms the input value to the range [0, 1] by normalizing it using
the lower and upper bounds of the input parameter space.
See Also
--------
TransformBase : The base class for parameter transformation.
TransformerBase : The base class for parameter transformation.
"""

Expand Down Expand Up @@ -387,15 +387,15 @@ def _untransform(self, x):
return (self.ub_input - self.lb_input) * x + self.lb_input


class NormLogTransform(TransformBase):
class NormLogTransformer(TransformerBase):
"""A class that implements a normalized logarithmic transformation.
Transforms the input value to the range [0, 1] using a logarithmic
transformation with the lower and upper bounds of the input parameter space.
See Also
--------
TransformBase : The base class for parameter transformation.
TransformerBase : The base class for parameter transformation.
"""

Expand Down Expand Up @@ -457,43 +457,43 @@ def _untransform(self, x):
self.lb_input * np.exp(x * np.log(self.ub_input/self.lb_input))


class AutoTransform(TransformBase):
class AutoTransformer(TransformerBase):
"""A class that implements an automatic parameter transformation.
Transforms the input value to the range [0, 1] using either
the :class:`NormLinearTransform` or the :class:`NormLogTransform`
the :class:`NormLinearTransformer` or the :class:`NormLogTransformer`
based on the input parameter space.
Attributes
----------
linear : :class:`NormLinearTransform`
linear : :class:`NormLinearTransformer`
Instance of the linear normalization transform.
log : :class:`NormLogTransform`
log : :class:`NormLogTransformer`
Instance of the logarithmic normalization transform.
See Also
--------
TransformBase
NormLinearTransform
NormLogTransform
TransformerBase
NormLinearTransformer
NormLogTransformer
"""

def __init__(self, *args, threshold=1000, **kwargs):
"""Initialize an AutoTransform object.
"""Initialize an AutoTransformer object.
Parameters
----------
*args : tuple
Arguments for the :class:`TransformBase` class.
Arguments for the :class:`TransformerBase` class.
threshold : int, optional
The maximum threshold to switch from linear to logarithmic
transformation. The default is 1000.
**kwargs : dict
Keyword arguments for the :class:`TransformBase` class.
Keyword arguments for the :class:`TransformerBase` class.
"""
self.linear = NormLinearTransform()
self.log = NormLogTransform()
self.linear = NormLinearTransformer()
self.log = NormLogTransformer()

super().__init__(*args, **kwargs)
self.threshold = threshold
Expand Down
Loading

0 comments on commit a4e2d46

Please sign in to comment.