Skip to content

Commit

Permalink
fix: add support for draft field name converter logic in DioptraClient
Browse files Browse the repository at this point in the history
This commit extends the `NewResourceDraftsSubCollectionClient` and
`ModifyResourceDraftsSubCollectionClient` classes to accept a new `convert_field_names_fn` argument,
which takes a callable that implements the `ConvertFieldNamesToCamelCaseProtocol` interface. A
utility function `make_field_names_to_camel_case_converter` that builds a callable that uses a
Python dictionary for the name conversions and checks for field name collisions during conversion is
provided. A new conversion callable is added to the entrypoints collection client to handle
converting task_graph to taskGraph. The existing integration tests for the entrypoints drafts
functionality have been updated to use task_graph instead of taskGraph, and a new unit test has been
added to validate that a user cannot pass task_graph and taskGraph as arguments to the entrypoint
drafts clients at the same time.

Closes #727
  • Loading branch information
jkglasbrenner authored Jan 30, 2025
1 parent a26566b commit c8a4c8d
Show file tree
Hide file tree
Showing 4 changed files with 190 additions and 16 deletions.
4 changes: 4 additions & 0 deletions src/dioptra/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,10 @@ class DioptraClientError(Exception):
"""Base class for client errors"""


class FieldNameCollisionError(DioptraClientError):
"""Raised when two field names will collide after conversion to camel case."""


class FieldsValidationError(DioptraClientError):
"""Raised when one or more fields are invalid."""

Expand Down
70 changes: 65 additions & 5 deletions src/dioptra/client/drafts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
CollectionClient,
DioptraClientError,
DioptraSession,
FieldNameCollisionError,
FieldsValidationError,
SubCollectionClient,
SubCollectionUrlError,
Expand All @@ -33,11 +34,61 @@ class DraftFieldsValidationError(DioptraClientError):
"""Raised when one or more draft fields are invalid."""


class ConvertFieldNamesToCamelCaseProtocol(Protocol):
def __call__(self, json_: dict[str, Any]) -> dict[str, Any]:
... # fmt: skip


class ValidateDraftFieldsProtocol(Protocol):
def __call__(self, json_: dict[str, Any]) -> dict[str, Any]:
... # fmt: skip


def make_field_names_to_camel_case_converter(
name_mapping: dict[str, str],
) -> ConvertFieldNamesToCamelCaseProtocol:
"""
Create a function that uses a dictionary to convert field names passed to the client
to camel case.
Args:
name_mapping: A dictionary that maps the draft field names to their camel case
equivalents.
Returns:
A function for converting the draft fields names to camel case.
"""

def convert_field_names(json_: dict[str, Any]) -> dict[str, Any]:
"""Convert the names of the provided draft fields to camel case.
Args:
json_: The draft fields to convert.
Returns:
The draft fields with names converted to camel case.
"""
if not name_mapping:
return json_

converted_json_: dict[str, Any] = {}
conversions_seen: dict[str, str] = {}

for key, value in json_.items():
if (converted_key := name_mapping.get(key, key)) in conversions_seen:
raise FieldNameCollisionError(
f"Camel case conversion failed (reason: duplicates "
f"{conversions_seen[converted_key]} after conversion): {key}"
)

converted_json_[converted_key] = value
conversions_seen[converted_key] = key

return converted_json_

return convert_field_names


def make_draft_fields_validator(
draft_fields: set[str], resource_name: str
) -> ValidateDraftFieldsProtocol:
Expand Down Expand Up @@ -102,6 +153,7 @@ def __init__(
validate_fields_fn: ValidateDraftFieldsProtocol,
root_collection: CollectionClient[T],
parent_sub_collections: list[SubCollectionClient[T]] | None = None,
convert_field_names_fn: ConvertFieldNamesToCamelCaseProtocol | None = None,
) -> None:
"""Initialize the NewResourceDraftsSubCollectionClient instance.
Expand All @@ -122,6 +174,9 @@ def __init__(
self._parent_sub_collections: list[SubCollectionClient[T]] = (
parent_sub_collections or []
)
self._convert_field_names = (
convert_field_names_fn or make_field_names_to_camel_case_converter({})
)

def get(
self,
Expand Down Expand Up @@ -198,8 +253,7 @@ def create(
DraftFieldsValidationError: If one or more draft fields are invalid or
missing.
"""

if "group" in kwargs:
if "group" in (kwargs := self._convert_field_names(kwargs)):
raise FieldsValidationError(
"Invalid argument (reason: keyword is reserved): group"
)
Expand Down Expand Up @@ -228,7 +282,7 @@ def modify(self, *resource_ids: str | int, draft_id: int, **kwargs) -> T:
DraftFieldsValidationError: If one or more draft fields are invalid or
missing.
"""
if "draftId" in kwargs:
if "draftId" in (kwargs := self._convert_field_names(kwargs)):
raise FieldsValidationError(
"Invalid argument (reason: keyword is reserved): draftId"
)
Expand Down Expand Up @@ -322,6 +376,7 @@ def __init__(
validate_fields_fn: ValidateDraftFieldsProtocol,
root_collection: CollectionClient[T],
parent_sub_collections: list[SubCollectionClient[T]] | None = None,
convert_field_names_fn: ConvertFieldNamesToCamelCaseProtocol | None = None,
) -> None:
"""Initialize the ModifyResourceDraftsSubCollectionClient instance.
Expand All @@ -342,6 +397,9 @@ def __init__(
parent_sub_collections=parent_sub_collections,
)
self._validate_fields = validate_fields_fn
self._convert_field_names = (
convert_field_names_fn or make_field_names_to_camel_case_converter({})
)

def get_by_id(self, *resource_ids: str | int) -> T:
"""Get a resource modification draft.
Expand Down Expand Up @@ -370,7 +428,7 @@ def create(self, *resource_ids: str | int, **kwargs) -> T:
"""
return self._session.post(
self.build_sub_collection_url(*resource_ids),
json_=self._validate_fields(kwargs),
json_=self._validate_fields(self._convert_field_names(kwargs)),
)

def modify(
Expand All @@ -394,7 +452,9 @@ def modify(
self.build_sub_collection_url(*resource_ids),
json_={
"resourceSnapshot": int(resource_snapshot_id),
"resourceData": self._validate_fields(kwargs),
"resourceData": self._validate_fields(
self._convert_field_names(kwargs)
),
},
)

Expand Down
10 changes: 10 additions & 0 deletions src/dioptra/client/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
ModifyResourceDraftsSubCollectionClient,
NewResourceDraftsSubCollectionClient,
make_draft_fields_validator,
make_field_names_to_camel_case_converter,
)
from .snapshots import SnapshotsSubCollectionClient
from .tags import TagsSubCollectionClient
Expand All @@ -38,6 +39,9 @@
"queues",
"plugins",
}
FIELD_NAMES_TO_CAMEL_CASE: Final[dict[str, str]] = {
"task_graph": "taskGraph",
}

T = TypeVar("T")

Expand Down Expand Up @@ -280,6 +284,9 @@ def __init__(self, session: DioptraSession[T]) -> None:
resource_name=self.name,
),
root_collection=self,
convert_field_names_fn=make_field_names_to_camel_case_converter(
name_mapping=FIELD_NAMES_TO_CAMEL_CASE
),
)
self._modify_resource_drafts = ModifyResourceDraftsSubCollectionClient[T](
session=session,
Expand All @@ -288,6 +295,9 @@ def __init__(self, session: DioptraSession[T]) -> None:
resource_name=self.name,
),
root_collection=self,
convert_field_names_fn=make_field_names_to_camel_case_converter(
name_mapping=FIELD_NAMES_TO_CAMEL_CASE
),
)
self._snapshots = SnapshotsSubCollectionClient[T](
session=session, root_collection=self
Expand Down
122 changes: 111 additions & 11 deletions tests/unit/restapi/v1/test_entrypoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pytest
from flask_sqlalchemy import SQLAlchemy

from dioptra.client.base import DioptraResponseProtocol
from dioptra.client.base import DioptraResponseProtocol, FieldNameCollisionError
from dioptra.client.client import DioptraClient

from ..lib import helpers, routines
Expand Down Expand Up @@ -849,15 +849,15 @@ def test_manage_existing_entrypoint_draft(
draft = {
"name": name,
"description": description,
"taskGraph": task_graph,
"task_graph": task_graph,
"parameters": parameters,
"plugins": plugin_ids,
"queues": queue_ids,
}
draft_mod = {
"name": new_name,
"description": description,
"taskGraph": task_graph,
"task_graph": task_graph,
"parameters": parameters,
"plugins": plugin_ids,
"queues": queue_ids,
Expand All @@ -870,15 +870,29 @@ def test_manage_existing_entrypoint_draft(
"resource_id": entrypoint["id"],
"resource_snapshot_id": entrypoint["snapshot"],
"num_other_drafts": 0,
"payload": draft,
"payload": {
"name": name,
"description": description,
"taskGraph": task_graph,
"parameters": parameters,
"plugins": plugin_ids,
"queues": queue_ids,
},
}
draft_mod_expected = {
"user_id": auth_account["id"],
"group_id": entrypoint["group"]["id"],
"resource_id": entrypoint["id"],
"resource_snapshot_id": entrypoint["snapshot"],
"num_other_drafts": 0,
"payload": draft_mod,
"payload": {
"name": new_name,
"description": description,
"taskGraph": task_graph,
"parameters": parameters,
"plugins": plugin_ids,
"queues": queue_ids,
},
}

# Run routine: existing resource drafts tests
Expand Down Expand Up @@ -915,15 +929,15 @@ def test_manage_new_entrypoint_drafts(
"draft1": {
"name": "entrypoint1",
"description": "new entrypoint",
"taskGraph": "graph",
"task_graph": "graph",
"parameters": [],
"plugins": [],
"queues": [],
},
"draft2": {
"name": "entrypoint2",
"description": "entrypoint",
"taskGraph": "graph",
"task_graph": "graph",
"parameters": [],
"queues": [1, 3],
"plugins": [2],
Expand All @@ -932,7 +946,7 @@ def test_manage_new_entrypoint_drafts(
draft1_mod = {
"name": "draft1",
"description": "new description",
"taskGraph": "graph",
"task_graph": "graph",
"parameters": [],
"plugins": [],
"queues": [],
Expand All @@ -942,17 +956,38 @@ def test_manage_new_entrypoint_drafts(
draft1_expected = {
"user_id": auth_account["id"],
"group_id": group_id,
"payload": drafts["draft1"],
"payload": {
"name": "entrypoint1",
"description": "new entrypoint",
"taskGraph": "graph",
"parameters": [],
"plugins": [],
"queues": [],
},
}
draft2_expected = {
"user_id": auth_account["id"],
"group_id": group_id,
"payload": drafts["draft2"],
"payload": {
"name": "entrypoint2",
"description": "entrypoint",
"taskGraph": "graph",
"parameters": [],
"queues": [1, 3],
"plugins": [2],
},
}
draft1_mod_expected = {
"user_id": auth_account["id"],
"group_id": group_id,
"payload": draft1_mod,
"payload": {
"name": "draft1",
"description": "new description",
"taskGraph": "graph",
"parameters": [],
"plugins": [],
"queues": [],
},
}

# Run routine: existing resource drafts tests
Expand All @@ -967,6 +1002,71 @@ def test_manage_new_entrypoint_drafts(
)


def test_client_raises_error_on_field_name_collision(
dioptra_client: DioptraClient[DioptraResponseProtocol],
db: SQLAlchemy,
auth_account: dict[str, Any],
registered_entrypoints: dict[str, Any],
) -> None:
"""
Test that the client errors out if both task_graph and taskGraph are passed as
keyword arguments to either the create and modify draft resource sub-collection
clients.
Given an authenticated user, this test validates the following sequence of actions:
- The user prepares a payload for either the create or modify draft resource
sub-collection client. The payload contains both task_graph an taskGraph as keys.
- The user submits a create new resource request using the payload, which raises
a FieldNameCollisionError.
- The user submits a modify resource draft request using the payload, which raises
a FieldNameCollisionError.
"""
entrypoint = registered_entrypoints["entrypoint1"]
group_id = auth_account["groups"][0]["id"]
name = "draft"
description = "description"
task_graph = textwrap.dedent(
"""# my entrypoint graph
graph:
message:
my_entrypoint: $name
"""
)
parameters = [
{
"name": "my_entrypoint_param",
"defaultValue": "my_value",
"parameterType": "string",
}
]
plugin_ids = [plugin["id"] for plugin in entrypoint["plugins"]]
queue_ids = [queue["id"] for queue in entrypoint["queues"]]

draft = {
"name": name,
"description": description,
"task_graph": task_graph,
"taskGraph": task_graph,
"parameters": parameters,
"plugins": plugin_ids,
"queues": queue_ids,
}

with pytest.raises(FieldNameCollisionError):
dioptra_client.entrypoints.new_resource_drafts.create(
group_id=group_id,
**draft,
)

with pytest.raises(FieldNameCollisionError):
dioptra_client.entrypoints.modify_resource_drafts.modify(
entrypoint["id"],
resource_snapshot_id=entrypoint["snapshot"],
**draft,
)


def test_manage_entrypoint_snapshots(
dioptra_client: DioptraClient[DioptraResponseProtocol],
db: SQLAlchemy,
Expand Down

0 comments on commit c8a4c8d

Please sign in to comment.