From 3b81da427b2a6688a002dc977da2da6bcca70cfd Mon Sep 17 00:00:00 2001 From: jtsextonMITRE <45762017+jtsextonMITRE@users.noreply.github.com> Date: Mon, 17 Feb 2025 12:34:26 -0500 Subject: [PATCH] feat(restapi): add workflow for plugin task signature analysis This commit exposes the signature analysis module as a workflow in the Dioptra REST API. It updates the Python client to support the new workflow. It also includes a new test suite the provides test coverage for the workflow and client. --- src/dioptra/client/workflows.py | 20 + .../restapi/v1/workflows/controller.py | 43 +- src/dioptra/restapi/v1/workflows/schema.py | 81 +++ src/dioptra/restapi/v1/workflows/service.py | 52 +- .../signature_analysis/sample_test_alias.py | 22 + .../sample_test_complex_type.py | 22 + .../sample_test_function_type.py | 22 + .../sample_test_none_return.py | 22 + .../sample_test_optional.py | 22 + .../sample_test_pyplugs_alias.py | 22 + .../sample_test_real_world.py | 424 ++++++++++++++ .../sample_test_redefinition.py | 54 ++ .../sample_test_register_alias.py | 22 + .../sample_test_type_conflict.py | 22 + .../v1/workflows/test_signature_analysis.py | 522 ++++++++++++++++++ 15 files changed, 1368 insertions(+), 4 deletions(-) create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py create mode 100644 tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py create mode 100644 tests/unit/restapi/v1/workflows/test_signature_analysis.py diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py index 8dfa4f6c6..2db601624 100644 --- a/src/dioptra/client/workflows.py +++ b/src/dioptra/client/workflows.py @@ -22,6 +22,7 @@ T = TypeVar("T") JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload" +SIGNATURE_ANALYSIS: Final[str] = "pluginTaskSignatureAnalysis" class WorkflowsCollectionClient(CollectionClient[T]): @@ -86,3 +87,22 @@ def download_job_files( return self._session.download( self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params ) + + def analyze_plugin_task_signatures(self, python_code: str) -> T: + """ + Requests signature analysis for the functions in an annotated python file. + + Args: + python_code: The contents of the python file. + filename: The name of the file. + + Returns: + The response from the Dioptra API. + + """ + + return self._session.post( + self.url, + SIGNATURE_ANALYSIS, + json_={"pythonCode": python_code}, + ) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 428619cdc..55024531d 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -19,14 +19,19 @@ import structlog from flask import request, send_file -from flask_accepts import accepts +from flask_accepts import accepts, responds from flask_login import login_required from flask_restx import Namespace, Resource from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema -from .service import JobFilesDownloadService +from .schema import ( + FileTypes, + JobFilesDownloadQueryParametersSchema, + SignatureAnalysisOutputSchema, + SignatureAnalysisSchema, +) +from .service import JobFilesDownloadService, SignatureAnalysisService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -78,3 +83,35 @@ def get(self): mimetype=mimetype[parsed_query_params["file_type"]], download_name=download_name[parsed_query_params["file_type"]], ) + + +@api.route("/pluginTaskSignatureAnalysis") +class SignatureAnalysisEndpoint(Resource): + @inject + def __init__( + self, signature_analysis_service: SignatureAnalysisService, *args, **kwargs + ) -> None: + """Initialize the workflow resource. + + All arguments are provided via dependency injection. + + Args: + signature_analysis_service: A SignatureAnalysisService object. + """ + self._signature_analysis_service = signature_analysis_service + super().__init__(*args, **kwargs) + + @login_required + @accepts(schema=SignatureAnalysisSchema, api=api) + @responds(schema=SignatureAnalysisOutputSchema, api=api) + def post(self): + """Download a compressed file archive containing the files needed to execute a submitted job.""" # noqa: B950 + log = LOGGER.new( # noqa: F841 + request_id=str(uuid.uuid4()), + resource="SignatureAnalysis", + request_type="POST", + ) + parsed_obj = request.parsed_obj + return self._signature_analysis_service.post( + python_code=parsed_obj["python_code"], + ) diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py index 92ea28ec7..505d4cdb7 100644 --- a/src/dioptra/restapi/v1/workflows/schema.py +++ b/src/dioptra/restapi/v1/workflows/schema.py @@ -41,3 +41,84 @@ class JobFilesDownloadQueryParametersSchema(Schema): by_value=True, default=FileTypes.TAR_GZ.value, ) + + +class SignatureAnalysisSchema(Schema): + + pythonCode = fields.String( + attribute="python_code", + metadata=dict(description="The contents of the python file"), + ) + + +class SignatureAnalysisSignatureParamSchema(Schema): + name = fields.String( + attribute="name", metadata=dict(description="The name of the parameter") + ) + type = fields.String( + attribute="type", metadata=dict(description="The type of the parameter") + ) + + +class SignatureAnalysisSignatureInputSchema(SignatureAnalysisSignatureParamSchema): + required = fields.Boolean( + attribute="required", + metadata=dict(description="Whether this is a required parameter"), + ) + + +class SignatureAnalysisSignatureOutputSchema(SignatureAnalysisSignatureParamSchema): + pass + + +class SignatureAnalysisSuggestedTypes(Schema): + + # add proposed_type in next iteration + + name = fields.String( + attribute="name", + metadata=dict(description="A suggestion for the name of the type"), + ) + + description = fields.String( + attribute="description", + metadata=dict( + description="The annotation the suggestion is attempting to represent" + ), + ) + + +class SignatureAnalysisSignatureSchema(Schema): + name = fields.String( + attribute="name", metadata=dict(description="The name of the function") + ) + inputs = fields.Nested( + SignatureAnalysisSignatureInputSchema, + metadata=dict(description="A list of objects describing the input parameters."), + many=True, + ) + outputs = fields.Nested( + SignatureAnalysisSignatureOutputSchema, + metadata=dict( + description="A list of objects describing the output parameters." + ), + many=True, + ) + missing_types = fields.Nested( + SignatureAnalysisSuggestedTypes, + metadata=dict( + description="A list of missing types for non-primitives defined by the file" + ), + many=True, + ) + + +class SignatureAnalysisOutputSchema(Schema): + tasks = fields.Nested( + SignatureAnalysisSignatureSchema, + metadata=dict( + description="A list of signature analyses for the plugin tasks " + "provided in the input file" + ), + many=True, + ) diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py index d5769e274..074ba4106 100644 --- a/src/dioptra/restapi/v1/workflows/service.py +++ b/src/dioptra/restapi/v1/workflows/service.py @@ -15,11 +15,13 @@ # ACCESS THE FULL CC BY 4.0 LICENSE HERE: # https://creativecommons.org/licenses/by/4.0/legalcode """The server-side functions that perform workflows endpoint operations.""" -from typing import IO, Final +from typing import IO, Any, Final import structlog from structlog.stdlib import BoundLogger +from dioptra.restapi.v1.shared.signature_analysis import get_plugin_signatures + from .lib import views from .lib.package_job_files import package_job_files from .schema import FileTypes @@ -65,3 +67,51 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]: file_type=file_type, logger=log, ) + + +class SignatureAnalysisService(object): + """The service methods for performing signature analysis on a file.""" + + def post(self, python_code: str, **kwargs) -> dict[str, list[dict[str, Any]]]: + """Perform signature analysis on a file. + + Args: + filename: The name of the file. + python_code: The contents of the file. + + Returns: + A dictionary containing the signature analysis. + """ + log: BoundLogger = kwargs.get("log", LOGGER.new()) + log.debug( + "Performing signature analysis", + python_source=python_code, + ) + endpoint_analyses = [ + _create_endpoint_analysis_dict(signature) + for signature in get_plugin_signatures(python_source=python_code) + ] + return {"tasks": endpoint_analyses} + + +def _create_endpoint_analysis_dict( + signature: dict[str, Any], +) -> dict[str, Any]: + """Create an endpoint analysis dictionary from a signature analysis. + Args: + signature: The signature analysis. + Returns: + The endpoint analysis dictionary. + """ + return { + "name": signature["name"], + "inputs": signature["inputs"], + "outputs": signature["outputs"], + "missing_types": [ + { + "description": suggested_type["type_annotation"], + "name": suggested_type["suggestion"], + } + for suggested_type in signature["suggested_types"] + ], + } diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py new file mode 100644 index 000000000..904d2cf65 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs as foo + + +@foo.register +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py new file mode 100644 index 000000000..f2833120a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register() +def the_plugin(arg1: Optional[str]) -> Union[int, bool]: + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py new file mode 100644 index 000000000..bc3242674 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2)) -> foo(2): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py new file mode 100644 index 000000000..0ed95097e --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.pyplugs import register + + +@register +def my_plugin() -> None: + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py new file mode 100644 index 000000000..ec847c6ea --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra import pyplugs + + +@pyplugs.register() +def do_things(arg1: Optional[str], arg2: int = 123): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py new file mode 100644 index 000000000..73ab9039a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra import pyplugs as foo + + +@foo.register +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py new file mode 100644 index 000000000..79689c7ef --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py @@ -0,0 +1,424 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import scipy.stats +import structlog +from structlog.stdlib import BoundLogger +from tensorflow.keras.preprocessing.image import DirectoryIterator + +import mlflow +from dioptra import pyplugs + +from .artifacts_mlflow import ( + download_all_artifacts, + upload_data_frame_artifact, + upload_directory_as_tarball_artifact, +) +from .artifacts_restapi import ( + get_uri_for_model, + get_uris_for_artifacts, + get_uris_for_job, +) +from .artifacts_utils import extract_tarfile, make_directories +from .attacks_fgm import fgm +from .attacks_patch import create_adversarial_patch_dataset, create_adversarial_patches +from .backend_configs_tensorflow import init_tensorflow +from .data_tensorflow import ( + create_image_dataset, + df_to_predictions, + get_n_classes_from_directory_iterator, + predictions_to_df, +) +from .defenses_image_preprocessing import create_defended_dataset +from .estimators_keras_classifiers import init_classifier +from .estimators_methods import fit +from .metrics_distance import get_distance_metric_list +from .metrics_performance import evaluate_metrics_generic, get_performance_metric_list +from .mlflow import add_model_to_registry +from .random_rng import init_rng +from .random_sample import draw_random_integer +from .registry_art import load_wrapped_tensorflow_keras_classifier +from .registry_mlflow import load_tensorflow_keras_classifier +from .tensorflow import ( + evaluate_metrics_tensorflow, + get_model_callbacks, + get_optimizer, + get_performance_metrics, + predict_tensorflow, +) +from .tracking_mlflow import log_metrics, log_parameters, log_tensorflow_keras_estimator + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +@pyplugs.register +def load_dataset( + ep_seed: int = 10145783023, + training_dir: str = "/dioptra/data/Mnist/training", + testing_dir: str = "/dioptra/data/Mnist/testing", + subsets: List[str] = ["testing"], + image_size: Tuple[int, int, int] = [28, 28, 1], + rescale: float = 1.0 / 255, + validation_split: Optional[float] = 0.2, + batch_size: int = 32, + label_mode: str = "categorical", + shuffle: bool = False, +) -> DirectoryIterator: + seed, rng = init_rng(ep_seed) + global_seed = draw_random_integer(rng) + dataset_seed = draw_random_integer(rng) + init_tensorflow(global_seed) + log_parameters( + { + "entry_point_seed": ep_seed, + "tensorflow_global_seed": global_seed, + "dataset_seed": dataset_seed, + } + ) + training_dataset = ( + None + if "training" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="training", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + + validation_dataset = ( + None + if "validation" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="validation", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + testing_dataset = ( + None + if "testing" not in subsets + else create_image_dataset( + data_dir=testing_dir, + subset=None, + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + return training_dataset, validation_dataset, testing_dataset + + +@pyplugs.register +def create_model( + dataset: DirectoryIterator = None, + model_architecture: str = "le_net", + input_shape: Tuple[int, int, int] = [28, 28, 1], + loss: str = "categorical_crossentropy", + learning_rate: float = 0.001, + optimizer: str = "Adam", + metrics_list: List[Dict[str, Any]] = None, +): + n_classes = get_n_classes_from_directory_iterator(dataset) + optim = get_optimizer(optimizer, learning_rate) + perf_metrics = get_performance_metrics(metrics_list) + classifier = init_classifier( + model_architecture, optim, perf_metrics, input_shape, n_classes, loss + ) + return classifier + + +@pyplugs.register +def load_model( + model_name: str | None = None, + model_version: int | None = None, + imagenet_preprocessing: bool = False, + art: bool = False, + image_size: Any = None, + classifier_kwargs: Optional[Dict[str, Any]] = None, +): + uri = get_uri_for_model(model_name, model_version) + if art: + classifier = load_wrapped_tensorflow_keras_classifier( + uri, imagenet_preprocessing, image_size, classifier_kwargs + ) + else: + classifier = load_tensorflow_keras_classifier(uri) + return classifier + + +@pyplugs.register +def train( + estimator: Any, + x: Any = None, + y: Any = None, + callbacks_list: List[Dict[str, Any]] = None, + fit_kwargs: Optional[Dict[str, Any]] = None, +): + fit_kwargs = {} if fit_kwargs is None else fit_kwargs + callbacks = get_model_callbacks(callbacks_list) + fit_kwargs["callbacks"] = callbacks + fit(estimator=estimator, x=x, y=y, fit_kwargs=fit_kwargs) + return estimator + + +@pyplugs.register +def save_artifacts_and_models( + artifacts: List[Dict[str, Any]] = None, models: List[Dict[str, Any]] = None +): + artifacts = [] if artifacts is None else artifacts + models = [] if models is None else models + + for model in models: + log_tensorflow_keras_estimator(model["model"], "model") + add_model_to_registry(model["name"], "model") + for artifact in artifacts: + if artifact["type"] == "tarball": + upload_directory_as_tarball_artifact( + source_dir=artifact["adv_data_dir"], + tarball_filename=artifact["adv_tar_name"], + ) + if artifact["type"] == "dataframe": + upload_data_frame_artifact( + data_frame=artifact["data_frame"], + file_name=artifact["file_name"], + file_format=artifact["file_format"], + file_format_kwargs=artifact["file_format_kwargs"], + ) + + +@pyplugs.register +def load_artifacts_for_job( + job_id: str, files: List[str | Path] = None, extract_files: List[str | Path] = None +): + files = [] if files is None else files + extract_files = [] if extract_files is None else extract_files + files += extract_files # need to download them to be able to extract + + uris = get_uris_for_job(job_id) + paths = download_all_artifacts(uris, files) + for extract in paths: + for ef in extract_files: + if ef.endswith(str(ef)): + extract_tarfile(extract) + return paths + + +@pyplugs.register +def load_artifacts( + artifact_ids: List[int] = None, extract_files: List[str | Path] = None +): + extract_files = [] if extract_files is None else extract_files + artifact_ids = [] if artifact_ids is not None else artifact_ids + uris = get_uris_for_artifacts(artifact_ids) + paths = download_all_artifacts(uris, extract_files) + for extract in paths: + extract_tarfile(extract) + + +@pyplugs.register +def attack_fgm( + dataset: Any, + adv_data_dir: Union[str, Path], + classifier: Any, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + eps: float = 0.3, + eps_step: float = 0.1, + minimal: bool = False, + norm: Union[int, float, str] = np.inf, +): + """generate fgm examples""" + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + fgm_dataset = fgm( + data_flow=dataset, + adv_data_dir=adv_data_dir, + keras_classifier=classifier, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + eps=eps, + eps_step=eps_step, + minimal=minimal, + norm=norm, + ) + return fgm_dataset + + +@pyplugs.register() +def attack_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + model: Any, + patch_target: int, + num_patch: int, + num_patch_samples: int, + rotation_max: float, + scale_min: float, + scale_max: float, + learning_rate: float, + max_iter: int, + patch_shape: Tuple, +): + """generate patches""" + make_directories([adv_data_dir]) + create_adversarial_patches( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + keras_classifier=model, + patch_target=patch_target, + num_patch=num_patch, + num_patch_samples=num_patch_samples, + rotation_max=rotation_max, + scale_min=scale_min, + scale_max=scale_max, + learning_rate=learning_rate, + max_iter=max_iter, + patch_shape=patch_shape, + ) + + +@pyplugs.register() +def augment_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + patch_dir: Union[str, Path], + model: Any, + patch_shape: Tuple, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + patch_scale: float = 0.4, + rotation_max: float = 22.5, + scale_min: float = 0.1, + scale_max: float = 1.0, +): + """add patches to a dataset""" + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + create_adversarial_patch_dataset( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + patch_dir=patch_dir, + keras_classifier=model, + patch_shape=patch_shape, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + patch_scale=patch_scale, + rotation_max=rotation_max, + scale_min=scale_min, + scale_max=scale_max, + ) + + +@pyplugs.register +def model_metrics(classifier: Any, dataset: Any): + metrics = evaluate_metrics_tensorflow(classifier, dataset) + log_metrics(metrics) + return metrics + + +@pyplugs.register +def prediction_metrics( + y_true: np.ndarray, + y_pred: np.ndarray, + metrics_list: List[Dict[str, str]], + func_kwargs: Dict[str, Dict[str, Any]] = None, +): + func_kwargs = {} if func_kwargs is None else func_kwargs + callable_list = get_performance_metric_list(metrics_list) + metrics = evaluate_metrics_generic(y_true, y_pred, callable_list, func_kwargs) + log_metrics(metrics) + return pd.DataFrame(metrics, index=[1]) + + +@pyplugs.register +def augment_data( + dataset: Any, + def_data_dir: Union[str, Path], + image_size: Tuple[int, int, int], + distance_metrics: List[Dict[str, str]], + batch_size: int = 50, + def_type: str = "spatial_smoothing", + defense_kwargs: Optional[Dict[str, Any]] = None, +): + make_directories([def_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + defended_dataset = create_defended_dataset( + data_flow=dataset, + def_data_dir=def_data_dir, + image_size=image_size, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + def_type=def_type, + defense_kwargs=defense_kwargs, + ) + return defended_dataset + + +@pyplugs.register +def predict( + classifier: Any, + dataset: Any, + show_actual: bool = False, + show_target: bool = False, +): + predictions = predict_tensorflow(classifier, dataset) + df = predictions_to_df( + predictions, dataset, show_actual=show_actual, show_target=show_target + ) + return df + + +@pyplugs.register +def load_predictions( + paths: List[str], + filename: str, + format: str = "csv", + dataset: DirectoryIterator = None, + n_classes: int = -1, +): + loc = None + for m in paths: + if m.endswith(filename): + loc = m + if format == "csv": + df = pd.read_csv(loc) + elif format == "json": + df = pd.read_json(loc) + y_true, y_pred = df_to_predictions(df, dataset, n_classes) + return y_true, y_pred diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py new file mode 100644 index 000000000..8978be0a0 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py @@ -0,0 +1,54 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import aaa + +from dioptra.pyplugs import register + + +@register() +def test_plugin(): + pass + + +@aaa.register +def not_a_plugin(): + pass + + +class SomeClass: + pass + + +def some_other_func(): + pass + + +x = 1 + + +@register +def test_plugin2(): + pass + + +# re-definition of the "register" symbol +from bbb import ccc as register + + +@register +def also_not_a_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py new file mode 100644 index 000000000..b5ab0d362 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.pyplugs import register as foo + + +@foo +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py new file mode 100644 index 000000000..0282d7703 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): + pass diff --git a/tests/unit/restapi/v1/workflows/test_signature_analysis.py b/tests/unit/restapi/v1/workflows/test_signature_analysis.py new file mode 100644 index 000000000..e9e43b86a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/test_signature_analysis.py @@ -0,0 +1,522 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from http import HTTPStatus +from pathlib import Path +from typing import Any + +from flask_sqlalchemy import SQLAlchemy + +from dioptra.client.base import DioptraResponseProtocol +from dioptra.client.client import DioptraClient + +expected_outputs = {} + +expected_outputs["sample_test_real_world.py"] = [ + { + "name": "load_dataset", + "inputs": [ + {"name": "ep_seed", "type": "integer", "required": False}, + {"name": "training_dir", "type": "string", "required": False}, + {"name": "testing_dir", "type": "string", "required": False}, + {"name": "subsets", "type": "list_str", "required": False}, + {"name": "image_size", "type": "tuple_int_int_int", "required": False}, + {"name": "rescale", "type": "number", "required": False}, + {"name": "validation_split", "type": "optional_float", "required": False}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "label_mode", "type": "string", "required": False}, + {"name": "shuffle", "type": "boolean", "required": False}, + ], + "outputs": [{"name": "output", "type": "directoryiterator"}], + "missing_types": [ + {"name": "list_str", "description": "List[str]"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + {"name": "optional_float", "description": "Optional[float]"}, + {"name": "directoryiterator", "description": "DirectoryIterator"}, + ], + }, + { + "name": "create_model", + "inputs": [ + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "model_architecture", "type": "string", "required": False}, + {"name": "input_shape", "type": "tuple_int_int_int", "required": False}, + {"name": "loss", "type": "string", "required": False}, + {"name": "learning_rate", "type": "number", "required": False}, + {"name": "optimizer", "type": "string", "required": False}, + {"name": "metrics_list", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "directoryiterator", "description": "DirectoryIterator"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + }, + ], + }, + { + "name": "load_model", + "inputs": [ + {"name": "model_name", "type": "str_none", "required": False}, + {"name": "model_version", "type": "int_none", "required": False}, + {"name": "imagenet_preprocessing", "type": "boolean", "required": False}, + {"name": "art", "type": "boolean", "required": False}, + {"name": "image_size", "type": "any", "required": False}, + { + "name": "classifier_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "str_none", "description": "str | None"}, + {"name": "int_none", "description": "int | None"}, + { + "name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "train", + "inputs": [ + {"name": "estimator", "type": "any", "required": True}, + {"name": "x", "type": "any", "required": False}, + {"name": "y", "type": "any", "required": False}, + {"name": "callbacks_list", "type": "list_dict_str_any", "required": False}, + {"name": "fit_kwargs", "type": "optional_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + }, + { + "name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "save_artifacts_and_models", + "inputs": [ + {"name": "artifacts", "type": "list_dict_str_any", "required": False}, + {"name": "models", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + } + ], + }, + { + "name": "load_artifacts_for_job", + "inputs": [ + {"name": "job_id", "type": "string", "required": True}, + {"name": "files", "type": "list_str_path", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_str_path", "description": "List[str | Path]"} + ], + }, + { + "name": "load_artifacts", + "inputs": [ + {"name": "artifact_ids", "type": "list_int", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_int", "description": "List[int]"}, + {"name": "list_str_path", "description": "List[str | Path]"}, + ], + }, + { + "name": "attack_fgm", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "classifier", "type": "any", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "eps", "type": "number", "required": False}, + {"name": "eps_step", "type": "number", "required": False}, + {"name": "minimal", "type": "boolean", "required": False}, + {"name": "norm", "type": "union_int_float_str", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + "name": "union_int_float_str", + "description": "Union[int, float, str]", + }, + ], + }, + { + "name": "attack_patch", + "inputs": [ + {"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_target", "type": "integer", "required": True}, + {"name": "num_patch", "type": "integer", "required": True}, + {"name": "num_patch_samples", "type": "integer", "required": True}, + {"name": "rotation_max", "type": "number", "required": True}, + {"name": "scale_min", "type": "number", "required": True}, + {"name": "scale_max", "type": "number", "required": True}, + {"name": "learning_rate", "type": "number", "required": True}, + {"name": "max_iter", "type": "integer", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + {"name": "tuple", "description": "Tuple"}, + ], + }, + { + "name": "augment_patch", + "inputs": [ + {"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "patch_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "patch_scale", "type": "number", "required": False}, + {"name": "rotation_max", "type": "number", "required": False}, + {"name": "scale_min", "type": "number", "required": False}, + {"name": "scale_max", "type": "number", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + {"name": "tuple", "description": "Tuple"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + ], + }, + { + "name": "model_metrics", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "prediction_metrics", + "inputs": [ + {"name": "y_true", "type": "np_ndarray", "required": True}, + {"name": "y_pred", "type": "np_ndarray", "required": True}, + {"name": "metrics_list", "type": "list_dict_str_str", "required": True}, + {"name": "func_kwargs", "type": "dict_str_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "np_ndarray", "description": "np.ndarray"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + "name": "dict_str_dict_str_any", + "description": "Dict[str, Dict[str, Any]]", + }, + ], + }, + { + "name": "augment_data", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "def_data_dir", "type": "union_str_path", "required": True}, + {"name": "image_size", "type": "tuple_int_int_int", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "def_type", "type": "string", "required": False}, + { + "name": "defense_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + "name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "predict", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + {"name": "show_actual", "type": "boolean", "required": False}, + {"name": "show_target", "type": "boolean", "required": False}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "load_predictions", + "inputs": [ + {"name": "paths", "type": "list_str", "required": True}, + {"name": "filename", "type": "string", "required": True}, + {"name": "format", "type": "string", "required": False}, + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "n_classes", "type": "integer", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_str", "description": "List[str]"}, + {"name": "directoryiterator", "description": "DirectoryIterator"}, + ], + }, +] + +expected_outputs["sample_test_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_complex_type.py"] = [ + { + "name": "the_plugin", + "inputs": [ + { + "name": "arg1", + "type": "optional_str", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "union_int_bool"}], + "missing_types": [ + {"name": "optional_str", "description": "Optional[str]"}, + {"name": "union_int_bool", "description": "Union[int, bool]"}, + ], + } +] + +expected_outputs["sample_test_function_type.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type1", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "type1"}], + "missing_types": [ + {"name": "type1", "description": "foo(2)"}, + ], + } +] + +expected_outputs["sample_test_none_return.py"] = [ + {"name": "my_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_optional.py"] = [ + { + "name": "do_things", + "inputs": [ + { + "name": "arg1", + "type": "optional_str", + "required": True, + }, + { + "name": "arg2", + "type": "integer", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "optional_str", "description": "Optional[str]"}, + ], + } +] + +expected_outputs["sample_test_pyplugs_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_redefinition.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []}, + {"name": "test_plugin2", "inputs": [], "outputs": [], "missing_types": []}, +] + +expected_outputs["sample_test_register_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_type_conflict.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type2", + "required": True, + }, + { + "name": "arg2", + "type": "type1", + "required": True, + }, + ], + "outputs": [{"name": "output", "type": "type2"}], + "missing_types": [ + {"name": "type2", "description": "foo(2)"}, + {"name": "type1", "description": "Type1"}, + ], + } +] + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_signature_analysis_response_matches_expectations( + response: dict[str, Any], expected_contents: dict[str, Any] +) -> None: + """Assert that a job response contents is valid. + + Args: + response: The actual response from the API. + expected_contents: The expected response from the API. + + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response or if the response contents is not + valid. + """ + # Check expected keys + expected_keys = { + "name", + "missing_types", + "outputs", + "inputs", + } + assert set(response.keys()) == expected_keys + + # Check basic response types + assert isinstance(response["name"], str) + assert isinstance(response["outputs"], list) + assert isinstance(response["missing_types"], list) + assert isinstance(response["inputs"], list) + + def sort_by_name(lst, k="name"): + return sorted(lst, key=lambda x: x[k]) + + assert sort_by_name(response["outputs"]) == sort_by_name( + expected_contents["outputs"] + ) + assert sort_by_name(response["inputs"]) == sort_by_name(expected_contents["inputs"]) + assert sort_by_name(response["missing_types"], k="name") == sort_by_name( + expected_contents["missing_types"], k="name" + ) + + +def assert_signature_analysis_responses_matches_expectations( + responses: list[dict[str, Any]], expected_contents: list[dict[str, Any]] +) -> None: + assert len(responses) == len(expected_contents) + for response in responses: + assert_signature_analysis_response_matches_expectations( + response, [a for a in expected_contents if a["name"] == response["name"]][0] + ) + + +def assert_signature_analysis_file_load_and_contents( + dioptra_client: DioptraClient[DioptraResponseProtocol], + filename: str, +): + location = Path("tests/unit/restapi/v1/workflows/signature_analysis") / filename + + with location.open("r") as f: + contents = f.read() + + contents_analysis = dioptra_client.workflows.analyze_plugin_task_signatures( + python_code=contents, + ) + + assert contents_analysis.status_code == HTTPStatus.OK + + + print(contents_analysis.json()) + assert_signature_analysis_responses_matches_expectations( + contents_analysis.json()["tasks"], + expected_contents=expected_outputs[filename], + ) + + +# -- Tests ----------------------------------------------------------------------------- + + +def test_signature_analysis( + dioptra_client: DioptraClient[DioptraResponseProtocol], + db: SQLAlchemy, + auth_account: dict[str, Any], +) -> None: + """ + Test that signature analysis + Args: + client: The Flask test client. + Raises: + AssertionError: If the response status code is not 200 or if the API response + does not match the expected response. + """ + + for fn in expected_outputs: + assert_signature_analysis_file_load_and_contents( + dioptra_client=dioptra_client, filename=fn + )