Skip to content

Commit

Permalink
feat(restapi): add draft commit workflow
Browse files Browse the repository at this point in the history
This commit adds a new workflow for committing drafts. Committing a draft will create a new resource
snapshot and delete the draft in a single transaction.

A draft can be either a draft of a new resource or draft modifications to an existing resource. For
draft modifications, the commit will be rejected if the snapshot the draft was based off of is not
longer the most recent snapshot. In this case an error is raised that contains detailed information
about the draft, the base snapshot, and the current snapshot so that the user can reconcile their
desired changes. A draft commit can also be rejected if the resource fails to be created for any
reason (e.g. name collisions or associations to invalid resources).

It also adds the workflow to the Python client and implements new tests for the workflow.
  • Loading branch information
keithmanville committed Feb 19, 2025
1 parent 3b81da4 commit cb87a01
Show file tree
Hide file tree
Showing 20 changed files with 669 additions and 52 deletions.
9 changes: 8 additions & 1 deletion src/dioptra/client/entrypoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,13 @@
"queues",
"plugins",
}
MODIFY_DRAFT_FIELDS: Final[set[str]] = {
"name",
"description",
"taskGraph",
"parameters",
"queues",
}
FIELD_NAMES_TO_CAMEL_CASE: Final[dict[str, str]] = {
"task_graph": "taskGraph",
}
Expand Down Expand Up @@ -291,7 +298,7 @@ def __init__(self, session: DioptraSession[T]) -> None:
self._modify_resource_drafts = ModifyResourceDraftsSubCollectionClient[T](
session=session,
validate_fields_fn=make_draft_fields_validator(
draft_fields=DRAFT_FIELDS,
draft_fields=MODIFY_DRAFT_FIELDS,
resource_name=self.name,
),
root_collection=self,
Expand Down
20 changes: 20 additions & 0 deletions src/dioptra/client/workflows.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@

JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload"
SIGNATURE_ANALYSIS: Final[str] = "pluginTaskSignatureAnalysis"
DRAFT_COMMIT: Final[str] = "draftCommit"


class WorkflowsCollectionClient(CollectionClient[T]):
Expand Down Expand Up @@ -106,3 +107,22 @@ def analyze_plugin_task_signatures(self, python_code: str) -> T:
SIGNATURE_ANALYSIS,
json_={"pythonCode": python_code},
)

def commit_draft(
self,
draft_id: str | int,
) -> T:
"""
Commit a draft as a new resource snapshot.
The draft can be a draft of a new resource or a draft modifications to an
existing resource.
Args:
draft_id: The draft id, an intiger.
Returns:
A dictionary containing the contents of the new resource.
"""

return self._session.post(self.url, DRAFT_COMMIT, str(draft_id))
43 changes: 43 additions & 0 deletions src/dioptra/restapi/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@
from flask_restx import Api
from structlog.stdlib import BoundLogger

from dioptra.restapi.db import models
from dioptra.restapi.v1 import utils

LOGGER: BoundLogger = structlog.stdlib.get_logger()


Expand Down Expand Up @@ -206,6 +209,26 @@ def __init__(self, type: str, id: int):
self.resource_id = id


class DraftResourceModificationsCommitError(DioptraError):
"""The draft modifications to a resource could not be committed"""

def __init__(
self,
resource_type: str,
resource_id: int,
draft: models.DraftResource,
base_snapshot: models.ResourceSnapshot,
curr_snapshot: models.ResourceSnapshot,
):
super().__init__(
f"Draft modifications for a [{resource_type}] with id: {resource_id} "
"could not be commited."
)
self.draft = draft
self.base_snapshot = base_snapshot
self.curr_snapshot = curr_snapshot


class InvalidDraftBaseResourceSnapshotError(DioptraError):
"""The draft's base snapshot identifier is invalid."""

Expand Down Expand Up @@ -432,6 +455,26 @@ def handle_draft_already_exists(error: DraftAlreadyExistsError):
log.debug(error.to_message())
return error_result(error, http.HTTPStatus.BAD_REQUEST, {})

@api.errorhandler(DraftResourceModificationsCommitError)
def handle_draft_resource_modifications_commit_error(
error: DraftResourceModificationsCommitError,
):
log.debug(error.to_message())

return error_result(
error,
http.HTTPStatus.BAD_REQUEST,
{
"reason": f"The {error.draft.resource_type} has been modified since "
"this draft was created.",
"draft": error.draft.payload["resource_data"],
"base_snapshot_id": error.base_snapshot.resource_snapshot_id,
"curr_snapshot_id": error.curr_snapshot.resource_snapshot_id,
"base_snapshot": utils.build_resource(error.base_snapshot),
"curr_snapshot": utils.build_resource(error.curr_snapshot),
},
)

@api.errorhandler(InvalidDraftBaseResourceSnapshotError)
def handle_invalid_draft_base_resource_snapshot(
error: InvalidDraftBaseResourceSnapshotError,
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/restapi/v1/entrypoints/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,7 +408,7 @@ def delete(self, id: int, queueId):
EntrypointIdDraftResource = generate_resource_id_draft_endpoint(
api,
resource_name=RESOURCE_TYPE,
request_schema=EntrypointDraftSchema(exclude=["groupId"]),
request_schema=EntrypointDraftSchema(exclude=["groupId", "pluginIds"]),
)

EntrypointSnapshotsResource = generate_resource_snapshots_endpoint(
Expand Down
2 changes: 1 addition & 1 deletion src/dioptra/restapi/v1/plugins/controller.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,11 +317,11 @@ def post(self, id: int):
parsed_obj = request.parsed_obj # type: ignore # noqa: F841

plugin_file = self._plugin_id_file_service.create(
plugin_id=id,
filename=parsed_obj["filename"],
contents=parsed_obj["contents"],
description=parsed_obj["description"],
tasks=parsed_obj["tasks"],
plugin_id=id,
log=log,
)
return utils.build_plugin_file(plugin_file)
Expand Down
4 changes: 2 additions & 2 deletions src/dioptra/restapi/v1/plugins/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,11 +689,11 @@ def __init__(

def create(
self,
plugin_id: int,
filename: str,
contents: str,
description: str,
tasks: list[dict[str, Any]],
plugin_id: int,
commit: bool = True,
**kwargs,
) -> utils.PluginFileDict:
Expand All @@ -703,11 +703,11 @@ def create(
with the PluginFile. The creator will be the current user.
Args:
plugin_id: The unique id of the plugin containing the plugin file.
filename: The name of the plugin file.
contents: The contents of the plugin file.
description: The description of the plugin file.
tasks: The tasks associated with the plugin file.
plugin_id: The unique id of the plugin containing the plugin file.
commit: If True, commit the transaction. Defaults to True.
Returns:
Expand Down
9 changes: 0 additions & 9 deletions src/dioptra/restapi/v1/shared/drafts/service.py
Original file line number Diff line number Diff line change
Expand Up @@ -467,15 +467,6 @@ def modify(
if draft is None:
return None, num_other_drafts

# NOTE: This check disables the ability to change the base snapshot ID.
# It is scheduled to be removed as part of the draft commit workflow feature.
if draft.payload["resource_snapshot_id"] != payload["resource_snapshot_id"]:
raise InvalidDraftBaseResourceSnapshotError(
"The provided resource snapshot must match the base resource snapshot",
base_resource_snapshot_id=draft.payload["resource_snapshot_id"],
provided_resource_snapshot_id=payload["resource_snapshot_id"],
)

if draft.payload["resource_snapshot_id"] > payload["resource_snapshot_id"]:
raise InvalidDraftBaseResourceSnapshotError(
"The provided resource snapshot must be greater than or equal to "
Expand Down
185 changes: 185 additions & 0 deletions src/dioptra/restapi/v1/shared/resource_service.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
# This Software (Dioptra) is being made available as a public service by the
# National Institute of Standards and Technology (NIST), an Agency of the United
# States Department of Commerce. This software was developed in part by employees of
# NIST and in part by NIST contractors. Copyright in portions of this software that
# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
# to Title 17 United States Code Section 105, works of NIST employees are not
# subject to copyright protection in the United States. However, NIST may hold
# international copyright in software created by its employees and domestic
# copyright (or licensing rights) in portions of software that were assigned or
# licensed to NIST. To the extent that NIST holds copyright in this software, it is
# being made available under the Creative Commons Attribution 4.0 International
# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
# of the software developed or licensed by NIST.
#
# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
# https://creativecommons.org/licenses/by/4.0/legalcode
"""The server-side functions that perform resource operations."""
from __future__ import annotations

from typing import Any

import structlog
from injector import inject
from structlog.stdlib import BoundLogger

from dioptra.restapi.errors import DioptraError
from dioptra.restapi.v1.artifacts.service import ArtifactIdService, ArtifactService
from dioptra.restapi.v1.entrypoints.service import (
EntrypointIdService,
EntrypointService,
)
from dioptra.restapi.v1.experiments.service import (
ExperimentIdService,
ExperimentService,
)
from dioptra.restapi.v1.models.service import ModelIdService, ModelService
from dioptra.restapi.v1.plugin_parameter_types.service import (
PluginParameterTypeIdService,
PluginParameterTypeService,
)
from dioptra.restapi.v1.plugins.service import (
PluginIdFileIdService,
PluginIdFileService,
PluginIdService,
PluginService,
)
from dioptra.restapi.v1.queues.service import QueueIdService, QueueService

LOGGER: BoundLogger = structlog.stdlib.get_logger()


class ResourceService(object):
"""The service methods for creating resources."""

@inject
def __init__(
self,
artifact_service: ArtifactService,
entrypoint_service: EntrypointService,
experiment_service: ExperimentService,
model_service: ModelService,
plugin_service: PluginService,
plugin_id_file_service: PluginIdFileService,
plugin_parameter_type_service: PluginParameterTypeService,
queue_service: QueueService,
) -> None:
"""Initialize the queue service.
All arguments are provided via dependency injection.
Args:
queue_name_service: A QueueNameService object.
"""

self._services: dict[str, object] = {
"artifact": artifact_service,
"entry_point": entrypoint_service,
"experiment": experiment_service,
"ml_model": model_service,
"plugin": plugin_service,
"plugin_file": plugin_id_file_service,
"plugin_task_parameter_type": plugin_parameter_type_service,
"queue": queue_service,
}

def create(
self,
*resource_ids: int,
resource_type: str,
resource_data: dict,
group_id: int,
commit: bool = True,
**kwargs,
) -> Any:
"""Create a new queue.
Args:
resource_type: The type of resource to create, e.g. queue
resource_data: Arguments passed to create services as kwargs
group_id: The group that will own the queue.
commit: If True, commit the transaction. Defaults to True.
Returns:
The newly created queue object.
Raises:
EntityExistsError: If a queue with the given name already exists.
EntityDoesNotExistError: If the group with the provided ID does not exist.
"""
log: BoundLogger = kwargs.get("log", LOGGER.new())

if resource_type not in self._services:
raise DioptraError(f"Invalid resource type: {resource_type}")

return self._services[resource_type].create( # type: ignore
*resource_ids, group_id=group_id, **resource_data, commit=commit, log=log
)


class ResourceIdService(object):
"""The service methods for creating resources."""

@inject
def __init__(
self,
artifact_id_service: ArtifactIdService,
entrypoint_id_service: EntrypointIdService,
experiment_id_service: ExperimentIdService,
model_id_service: ModelIdService,
plugin_id_service: PluginIdService,
plugin_id_file_id_service: PluginIdFileIdService,
plugin_parameter_type_id_service: PluginParameterTypeIdService,
queue_id_service: QueueIdService,
) -> None:
"""Initialize the resource id service.
All arguments are provided via dependency injection.
Args:
queue_name_service: A QueueNameService object.
"""

self._services = {
"artifact": artifact_id_service,
"entry_point": entrypoint_id_service,
"experiment": experiment_id_service,
"ml_model": model_id_service,
"plugin": plugin_id_service,
"plugin_file": plugin_id_file_id_service,
"plugin_task_parameter_type": plugin_parameter_type_id_service,
"queue": queue_id_service,
}

def modify(
self,
*resource_ids: int,
resource_type: str,
resource_data: dict,
group_id: int,
commit: bool = True,
**kwargs,
) -> Any:
"""Create a new queue.
Args:
resource_type: The type of resource to create, e.g. queue
resource_data: Arguments passed to create services as kwargs
group_id: The group that will own the queue.
commit: If True, commit the transaction. Defaults to True.
Returns:
The newly created queue object.
Raises:
EntityExistsError: If a queue with the given name already exists.
EntityDoesNotExistError: If the group with the provided ID does not exist.
"""
log: BoundLogger = kwargs.get("log", LOGGER.new())

if resource_type not in self._services:
raise DioptraError(f"Invalid resource type: {resource_type}")

return self._services[resource_type].modify( # type: ignore
*resource_ids, group_id=group_id, **resource_data, commit=commit, log=log
)
Loading

0 comments on commit cb87a01

Please sign in to comment.