Skip to content

Commit

Permalink
Eng 399 - Introducing Metric Nodes in Designer (#247)
Browse files Browse the repository at this point in the history
* implemented designer metric node

* function test for designer metric node

* Metric node populated on fetch

* designer node construction fix

* minor

* minor bug fix

* Fixing data asset reference in metric pipeline test

* minor bug fixes

* Remove undesired print

---------

Co-authored-by: Thiago Castro Ferreira <[email protected]>
  • Loading branch information
kadirpekel and thiago-aixplain authored Sep 10, 2024
1 parent 357e10d commit 731a150
Show file tree
Hide file tree
Showing 9 changed files with 155 additions and 25 deletions.
9 changes: 6 additions & 3 deletions aixplain/factories/pipeline_factory/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from aixplain.modules.pipeline.designer import (
Input,
Output,
AssetNode,
BareAsset,
BareMetric,
Decision,
Router,
Route,
Expand Down Expand Up @@ -36,14 +37,16 @@ def build_from_response(response: Dict, load_architecture: bool = False) -> Pipe
try:
# instantiating nodes
for node_json in response["nodes"]:
print(node_json)
if node_json["type"].lower() == "input":
node = Input(
data=node_json["data"] if "data" in node_json else None,
data_types=[DataType(dt) for dt in node_json["dataType"]],
)
elif node_json["type"].lower() == "asset":
node = AssetNode(asset_id=node_json["assetId"])
if node_json["functionType"] == "metric":
node = BareMetric(asset_id=node_json["assetId"])
else:
node = BareAsset(asset_id=node_json["assetId"])
elif node_json["type"].lower() == "segmentor":
raise NotImplementedError()
elif node_json["type"].lower() == "reconstructor":
Expand Down
6 changes: 6 additions & 0 deletions aixplain/modules/pipeline/designer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,9 @@
Router,
BaseReconstructor,
BaseSegmentor,
BaseMetric,
BareAsset,
BareMetric
)
from .pipeline import DesignerPipeline
from .base import (
Expand Down Expand Up @@ -36,6 +39,7 @@
__all__ = [
"DesignerPipeline",
"AssetNode",
"BareAsset",
"Decision",
"Script",
"Input",
Expand Down Expand Up @@ -63,4 +67,6 @@
"ParamProxy",
"TI",
"TO",
"BaseMetric",
"BareMetric"
]
9 changes: 6 additions & 3 deletions aixplain/modules/pipeline/designer/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,8 +188,8 @@ def validate(self):
if from_param.data_type and to_param.data_type:
if from_param.data_type != to_param.data_type:
raise ValueError(
f"Data type mismatch between {from_param.data_type} and {to_param.data_type}"
) # noqa
f"Data type mismatch between {from_param.data_type} and {to_param.data_type}" # noqa
)

def attach_to(self, pipeline: "DesignerPipeline"):
"""
Expand Down Expand Up @@ -344,6 +344,9 @@ def __init__(
if pipeline:
self.attach_to(pipeline)

def build_label(self):
    """
    Build the default label for this node.

    The label combines the node type value with the node number,
    e.g. ``ASSET(ID=3)``. Subclasses may override this method to use a
    different label format.

    :return: the label string
    """
    label = f"{self.type.value}(ID={self.number})"
    return label

def attach_to(self, pipeline: "DesignerPipeline"):
"""
Attach the node to the pipeline.
Expand All @@ -359,7 +362,7 @@ def attach_to(self, pipeline: "DesignerPipeline"):
if self.number is None:
self.number = len(pipeline.nodes)
if self.label is None:
self.label = f"{self.type.value}(ID={self.number})"
self.label = self.build_label()

assert not pipeline.get_node(self.number), "Node number already exists"
pipeline.nodes.append(self)
Expand Down
1 change: 1 addition & 0 deletions aixplain/modules/pipeline/designer/enums.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class FunctionType(str, Enum):
AI = "AI"
SEGMENTOR = "SEGMENTOR"
RECONSTRUCTOR = "RECONSTRUCTOR"
METRIC = "METRIC"


class ParamType:
Expand Down
72 changes: 62 additions & 10 deletions aixplain/modules/pipeline/designer/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,9 @@ def __init__(
supplier: str = None,
version: str = None,
pipeline: "DesignerPipeline" = None,
**kwargs
):
super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)
self.asset_id = asset_id
self.supplier = supplier
self.version = version
Expand Down Expand Up @@ -85,8 +86,8 @@ def populate_asset(self):
if self.function:
if self.asset.function.value != self.function:
raise ValueError(
f"Function {self.function} is not supported by asset {self.asset_id}"
) # noqa
f"Function {self.function} is not supported by asset {self.asset_id}" # noqa
)
else:
self.function = self.asset.function.value
self._auto_populate_params()
Expand Down Expand Up @@ -129,6 +130,18 @@ def serialize(self) -> dict:
return obj


class BareAssetInputs(Inputs):
    """
    Input parameter container for ``BareAsset`` nodes.

    Declares no parameters of its own; everything is inherited from
    ``Inputs``.
    """


class BareAssetOutputs(Outputs):
    """
    Output parameter container for ``BareAsset`` nodes.

    Declares no parameters of its own; everything is inherited from
    ``Outputs``.
    """


class BareAsset(AssetNode[BareAssetInputs, BareAssetOutputs]):
    """
    Concrete asset node parameterized with the empty
    ``BareAssetInputs`` / ``BareAssetOutputs`` containers.

    Adds no behaviour on top of ``AssetNode``.
    """


class InputInputs(Inputs):
pass

Expand Down Expand Up @@ -163,10 +176,11 @@ def __init__(
data: Optional[str] = None,
data_types: Optional[List[DataType]] = None,
pipeline: "DesignerPipeline" = None,
**kwargs
):
from aixplain.factories.file_factory import FileFactory

super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)
self.data_types = data_types or []
self.data = data

Expand Down Expand Up @@ -209,8 +223,9 @@ def __init__(
self,
data_types: Optional[List[DataType]] = None,
pipeline: "DesignerPipeline" = None,
**kwargs
):
super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)
self.data_types = data_types or []

def serialize(self) -> dict:
Expand All @@ -237,10 +252,11 @@ def __init__(
pipeline: "DesignerPipeline" = None,
script_path: Optional[str] = None,
fileId: Optional[str] = None,
**kwargs
):
from aixplain.factories.script_factory import ScriptFactory

super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)

assert script_path or fileId, "script_path or fileId is required"

Expand Down Expand Up @@ -272,6 +288,7 @@ def __init__(
path: List[Union[Node, int]],
operation: Operation,
type: RouteType,
**kwargs
):
"""
Post init method to convert the nodes to node numbers if they are
Expand Down Expand Up @@ -328,9 +345,10 @@ class Router(Node[RouterInputs, RouterOutputs], LinkableMixin):
outputs_class: Type[TO] = RouterOutputs

def __init__(
self, routes: List[Route], pipeline: "DesignerPipeline" = None
self, routes: List[Route], pipeline: "DesignerPipeline" = None,
**kwargs
):
super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)
self.routes = routes

def serialize(self) -> dict:
Expand Down Expand Up @@ -369,9 +387,10 @@ class Decision(Node[DecisionInputs, DecisionOutputs], LinkableMixin):
outputs_class: Type[TO] = DecisionOutputs

def __init__(
self, routes: List[Route], pipeline: "DesignerPipeline" = None
self, routes: List[Route], pipeline: "DesignerPipeline" = None,
**kwargs
):
super().__init__(pipeline=pipeline)
super().__init__(pipeline=pipeline, **kwargs)
self.routes = routes

def link(
Expand Down Expand Up @@ -462,3 +481,36 @@ class BareReconstructor(
functionType: FunctionType = FunctionType.RECONSTRUCTOR
inputs_class: Type[TI] = ReconstructorInputs
outputs_class: Type[TO] = ReconstructorOutputs


class BaseMetric(AssetNode[TI, TO]):
    """
    Base class for metric asset nodes.

    Marks the node with ``FunctionType.METRIC`` and overrides the label
    format used for the node.
    """

    functionType: FunctionType = FunctionType.METRIC

    def build_label(self):
        """
        Build the label for a metric node.

        :return: the label string, in the form ``METRIC(<number>)``
        """
        # NOTE(review): this format has no "ID=" prefix, unlike the
        # default node label format - confirm this is intentional.
        label = f"METRIC({self.number})"
        return label


class MetricInputs(Inputs):
    """
    Input parameters of a bare metric node.

    Exposes three input params: ``hypotheses``, ``references`` and
    ``sources``.
    """

    # Class-level declarations; the actual params are created per
    # instance in ``__init__``.
    hypotheses: InputParam = None
    references: InputParam = None
    sources: InputParam = None

    def __init__(self, node: Node):
        """
        Create the metric input params on the given node.

        :param node: the node these inputs belong to
        """
        super().__init__(node)
        self.hypotheses = self.create_param(code="hypotheses")
        self.references = self.create_param(code="references")
        self.sources = self.create_param(code="sources")


class MetricOutputs(Outputs):
    """
    Output parameters of a bare metric node.

    Exposes a single output param: ``data``.
    """

    # Class-level declaration; the actual param is created per instance
    # in ``__init__``.
    data: OutputParam = None

    def __init__(self, node: Node):
        """
        Create the metric output param on the given node.

        :param node: the node these outputs belong to
        """
        super().__init__(node)
        self.data = self.create_param(code="data")


class BareMetric(BaseMetric[MetricInputs, MetricOutputs]):
    """
    Concrete metric node parameterized with ``MetricInputs`` and
    ``MetricOutputs``.

    Adds no behaviour on top of ``BaseMetric``.
    """
12 changes: 12 additions & 0 deletions aixplain/modules/pipeline/designer/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
Route,
BareReconstructor,
BareSegmentor,
BareMetric
)
from .enums import NodeType, RouteType, Operation

Expand Down Expand Up @@ -326,3 +327,14 @@ def bare_segmentor(self, *args, **kwargs) -> BareSegmentor:
:return: the node
"""
return BareSegmentor(*args, pipeline=self, **kwargs)

def metric(self, *args, **kwargs) -> BareMetric:
    """
    Shortcut to create a metric node for the current pipeline.
    All positional and keyword arguments are forwarded to the
    ``BareMetric`` constructor, with ``pipeline`` set to this pipeline.

    :param args: positional arguments for the node constructor
    :param kwargs: keyword arguments for the node constructor
    :return: the node
    """
    node = BareMetric(*args, pipeline=self, **kwargs)
    return node
3 changes: 3 additions & 0 deletions aixplain/modules/pipeline/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
AssetNode,
BaseReconstructor,
BaseSegmentor,
BaseMetric
)
from .default import DefaultPipeline
from aixplain.modules import asset
Expand Down Expand Up @@ -160,6 +161,8 @@ def populate_specs(functions: list):
base_class = "BaseSegmentor"
elif is_reconstructor:
base_class = "BaseReconstructor"
elif "metric" in function_name.split("_"): # noqa: Advise a better distinguisher please
base_class = "BaseMetric"

spec = {
"id": function["id"],
Expand Down
17 changes: 9 additions & 8 deletions aixplain/modules/pipeline/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
AssetNode,
BaseReconstructor,
BaseSegmentor,
BaseMetric
)
from .default import DefaultPipeline
from aixplain.modules import asset
Expand Down Expand Up @@ -907,7 +908,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class ReferencelessAudioGenerationMetric(AssetNode[ReferencelessAudioGenerationMetricInputs, ReferencelessAudioGenerationMetricOutputs]):
class ReferencelessAudioGenerationMetric(BaseMetric[ReferencelessAudioGenerationMetricInputs, ReferencelessAudioGenerationMetricOutputs]):
"""
The Referenceless Audio Generation Metric is a tool designed to evaluate the
quality of generated audio content without the need for a reference or original
Expand Down Expand Up @@ -1080,7 +1081,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class AudioGenerationMetric(AssetNode[AudioGenerationMetricInputs, AudioGenerationMetricOutputs]):
class AudioGenerationMetric(BaseMetric[AudioGenerationMetricInputs, AudioGenerationMetricOutputs]):
"""
The Audio Generation Metric is a quantitative measure used to evaluate the
quality, accuracy, and overall performance of audio generated by artificial
Expand Down Expand Up @@ -1471,7 +1472,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class MetricAggregation(AssetNode[MetricAggregationInputs, MetricAggregationOutputs]):
class MetricAggregation(BaseMetric[MetricAggregationInputs, MetricAggregationOutputs]):
"""
Metric Aggregation is a function that computes and summarizes numerical data by
applying statistical operations, such as averaging, summing, or finding the
Expand Down Expand Up @@ -1790,7 +1791,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class ReferencelessTextGenerationMetric(AssetNode[ReferencelessTextGenerationMetricInputs, ReferencelessTextGenerationMetricOutputs]):
class ReferencelessTextGenerationMetric(BaseMetric[ReferencelessTextGenerationMetricInputs, ReferencelessTextGenerationMetricOutputs]):
"""
The Referenceless Text Generation Metric is a method for evaluating the quality
of generated text without requiring a reference text for comparison, often
Expand Down Expand Up @@ -1830,7 +1831,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class TextGenerationMetricDefault(AssetNode[TextGenerationMetricDefaultInputs, TextGenerationMetricDefaultOutputs]):
class TextGenerationMetricDefault(BaseMetric[TextGenerationMetricDefaultInputs, TextGenerationMetricDefaultOutputs]):
"""
The "Text Generation Metric Default" function provides a standard set of
evaluation metrics for assessing the quality and performance of text generation
Expand Down Expand Up @@ -2130,7 +2131,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class TextGenerationMetric(AssetNode[TextGenerationMetricInputs, TextGenerationMetricOutputs]):
class TextGenerationMetric(BaseMetric[TextGenerationMetricInputs, TextGenerationMetricOutputs]):
"""
A Text Generation Metric is a quantitative measure used to evaluate the quality
and effectiveness of text produced by natural language processing models, often
Expand Down Expand Up @@ -2981,7 +2982,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.TEXT)


class ReferencelessTextGenerationMetricDefault(AssetNode[ReferencelessTextGenerationMetricDefaultInputs, ReferencelessTextGenerationMetricDefaultOutputs]):
class ReferencelessTextGenerationMetricDefault(BaseMetric[ReferencelessTextGenerationMetricDefaultInputs, ReferencelessTextGenerationMetricDefaultOutputs]):
"""
The Referenceless Text Generation Metric Default is a function designed to
evaluate the quality of generated text without relying on reference texts for
Expand Down Expand Up @@ -3665,7 +3666,7 @@ def __init__(self, node=None):
self.data = self.create_param(code="data", data_type=DataType.NUMBER)


class ClassificationMetric(AssetNode[ClassificationMetricInputs, ClassificationMetricOutputs]):
class ClassificationMetric(BaseMetric[ClassificationMetricInputs, ClassificationMetricOutputs]):
"""
A Classification Metric is a quantitative measure used to evaluate the quality
and effectiveness of classification models.
Expand Down
Loading

0 comments on commit 731a150

Please sign in to comment.