diff --git a/src/dioptra/client/workflows.py b/src/dioptra/client/workflows.py
index 8dfa4f6c6..2db601624 100644
--- a/src/dioptra/client/workflows.py
+++ b/src/dioptra/client/workflows.py
@@ -22,6 +22,7 @@
 T = TypeVar("T")
 
 JOB_FILES_DOWNLOAD: Final[str] = "jobFilesDownload"
+SIGNATURE_ANALYSIS: Final[str] = "pluginTaskSignatureAnalysis"
 
 
 class WorkflowsCollectionClient(CollectionClient[T]):
@@ -86,3 +87,21 @@ def download_job_files(
         return self._session.download(
             self.url, JOB_FILES_DOWNLOAD, output_path=job_files_path, params=params
         )
+
+    def analyze_plugin_task_signatures(self, python_code: str) -> T:
+        """
+        Requests signature analysis for the functions in an annotated Python file.
+
+        Args:
+            python_code: The contents of the Python file.
+
+        Returns:
+            The response from the Dioptra API.
+
+        """
+
+        return self._session.post(
+            self.url,
+            SIGNATURE_ANALYSIS,
+            json_={"pythonCode": python_code},
+        )
diff --git a/src/dioptra/restapi/v1/shared/signature_analysis.py b/src/dioptra/restapi/v1/shared/signature_analysis.py
new file mode 100644
index 000000000..0ad7aec66
--- /dev/null
+++ b/src/dioptra/restapi/v1/shared/signature_analysis.py
@@ -0,0 +1,739 @@
+# This Software (Dioptra) is being made available as a public service by the
+# National Institute of Standards and Technology (NIST), an Agency of the United
+# States Department of Commerce. This software was developed in part by employees of
+# NIST and in part by NIST contractors. Copyright in portions of this software that
+# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
+# to Title 17 United States Code Section 105, works of NIST employees are not
+# subject to copyright protection in the United States. However, NIST may hold
+# international copyright in software created by its employees and domestic
+# copyright (or licensing rights) in portions of software that were assigned or
+# licensed to NIST. To the extent that NIST holds copyright in this software, it is
+# being made available under the Creative Commons Attribution 4.0 International
+# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
+# of the software developed or licensed by NIST.
+#
+# ACCESS THE FULL CC BY 4.0 LICENSE HERE:
+# https://creativecommons.org/licenses/by/4.0/legalcode
+"""
+Extract task plugin function signature information from Python source code.
+"""
+import ast as ast_module  # how many variables named "ast" might we have...
+import itertools
+import re
+import sys
+from pathlib import Path
+from typing import Any, Container, Iterator, Optional, Union
+
+from dioptra.task_engine import type_registry
+
+_PYTHON_TO_DIOPTRA_TYPE_NAME = {
+    "str": "string",
+    "int": "integer",
+    "float": "number",
+    "bool": "boolean",
+    "None": "null",
+}
+
+
+def _is_constant(ast: ast_module.AST, value: Any) -> bool:
+    """
+    Determine whether the given AST node represents a constant (literal) of
+    the given value.
+
+    Args:
+        ast: An AST node
+        value: A value to compare to
+
+    Returns:
+        True if the AST node is a constant of the given value; False if not
+    """
+    return isinstance(ast, ast_module.Constant) and ast.value == value
+
+
+def _is_simple_dotted_name(node: ast_module.AST) -> bool:
+    """
+    Determine whether the given AST node represents a simple name or dotted
+    name, like "foo", "foo.bar", "foo.bar.baz", etc.
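+    For example, "foo.bar" parses as Attribute(value=Name(id="foo"),
+    attr="bar"), so the check recurses through Attribute nodes until it
+    reaches a Name.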
+ + Args: + node: The AST node + + Returns: + True if the node represents a simple dotted name; False if not + """ + return isinstance(node, ast_module.Name) or ( + isinstance(node, ast_module.Attribute) and _is_simple_dotted_name(node.value) + ) + + +def _update_symbols(symbol_tree: dict[str, Any], name: str) -> dict[str, Any]: + """ + Update/modify the given symbol tree such that it includes the given + name. + + The symbol tree is conceptually roughly a symbol hierarchy. This is how + modules and other types of values are naturally arranged in Python. An + import statement (assuming it is correct, and in the absence of any way or + desire to check, we assume they are all correct) reflects this hierarchy, + and the hierarchy may be inferred from it. + + It is implemented as a nested dict of dicts. The dicts map a symbol name + to other dicts which may have other symbol names, which map to other dicts, + etc. One can look up a symbol to get a "value", but we don't actually have + access to any runtime values. A symbol's "value" in this tree will be + whatever dict it maps to (which may be empty). + + Importantly, aliasing present in import statements ("as" clauses) is + reflected in the symbol tree by referring to the same dict in multiple + places. This means the structure is not technically a tree, since nodes + can have in-degree greater than one. But it makes aliasing trivial to + deal with: you can use the "is" operator to check whether two symbols' + "values" are the same. + + Args: + symbol_tree: A symbol tree structure to update + name: The name to update the tree with + + Returns: + The resulting "value" of the symbol after the tree has been updated + """ + names = name.split(".") + + curr_mod = symbol_tree + for symbol_name in names: + curr_mod = curr_mod.setdefault(symbol_name, {}) + + return curr_mod + + +def _look_up_symbol( + symbol_tree: Optional[dict[str, Any]], name: str +) -> Optional[dict[str, Any]]: + """ + Look up a symbol in the given symbol tree and return its "value". The + symbol tree data structure is comprised of nested dicts, so the value + returned (if the symbol is found) is always a dict. + + Args: + symbol_tree: A symbol tree structure + name: The name to look up, as a string. E.g. "foo", "foo.bar", etc. + + Returns: + The value of the given symbol, or None if it was not found in the + symbol tree + """ + if not name: + # Just in case... + raise ValueError("Symbol name must not be null/empty") + + if not symbol_tree: + result = None + else: + dot_idx = name.find(".") + if dot_idx == -1: + result = symbol_tree.get(name) + else: + result = _look_up_symbol( + symbol_tree.get(name[:dot_idx]), name[dot_idx + 1 :] + ) + + return result + + +def _are_aliases(symbol_tree: dict[str, Any], name1: str, name2: str) -> bool: + """ + Determine whether two symbol names refer to the same value. + + Args: + symbol_tree: A symbol tree structure + name1: A symbol name + name2: A symbol name + + Returns: + True if both symbol names are defined and resolve to the same value; + False if not + """ + name1_value = _look_up_symbol(symbol_tree, name1) + name2_value = _look_up_symbol(symbol_tree, name2) + + return ( + name1_value is not None + and name2_value is not None + and name1_value is name2_value + ) + + +def _process_import(stmt: ast_module.AST, symbol_tree: dict[str, Any]) -> None: + """ + Update the given symbol tree according to the given import statement. This + can add new symbols to the tree, or change what existing symbols refer to. + + Args: + stmt: A stmt AST node. 
Node types other than Import and ImportFrom + are ignored. + symbol_tree: A symbol tree structure to update. + """ + if isinstance(stmt, ast_module.Import): + # For a normal import, update the hierarchy according to the + # imported name. If aliased, also introduce an alias symbol at + # the top level. + for alias in stmt.names: + value = _update_symbols(symbol_tree, alias.name) + + if alias.asname: + symbol_tree[alias.asname] = value + + elif isinstance(stmt, ast_module.ImportFrom): + # for mypy: how can a "from import ...", import from nothing? + # But module is apparently optional... + assert stmt.module + + # Can't hope to interpret relative imports by themselves, because + # we don't know what they're relative to. So just ignore those. + # E.g. "from ...foo import bar". + if stmt.level == 0: + # Update the symbol hierarchy with the module name + # (from "..."). This identifies a module to import from. + mod_value = _update_symbols(symbol_tree, stmt.module) + + # Each imported symbol is introduced at the sub-module level + # (from ... import "..."), since the statement implies that + # symbol exists there. If the symbol is not aliased, it is + # also introduced at the top level. If it is aliased, only the + # alias is introduced at the top level. + for alias in stmt.names: + value = mod_value.setdefault(alias.name, {}) + if alias.asname: + symbol_tree[alias.asname] = value + else: + symbol_tree[alias.name] = value + + +def _is_register_decorator(decorator_symbol: str, symbol_tree: dict[str, Any]) -> bool: + """ + Try to detect a pyplugs registration decorator symbol. In dioptra, the + "register" symbol is defined in the "dioptra.pyplugs" module. So one could + import the dioptra.pyplugs module and just access the "register" symbol + from there, or import the "register" symbol directly. E.g. + + import dioptra.pyplugs + + @dioptra.pyplugs.register + def foo(): + pass + + Or: + + from dioptra import pyplugs + + @pyplugs.register + def foo(): + pass + + Or: + + from dioptra.pyplugs import register + + @register + def foo(): + pass + + In the first two cases, our symbol tree would contain "dioptra.pyplugs" + but not "register" since the latter was never mentioned in an import + statement. In the last case, the whole "dioptra.pyplugs.register" symbol + would be present. We need to handle both cases. This should also be + transparent to aliasing, e.g. + + from dioptra import pyplugs as bar + + @bar.register + def foo(): + pass + + must also work. + + Args: + decorator_symbol: A decorator symbol used on a function, as a string, + e.g. "foo", "foo.bar", etc + symbol_tree: A data structure representing symbol hierarchy inferred + from import statements + + Returns: + True if the decorator symbol represents a task plugin registration + decorator; False if not + """ + + if _are_aliases(symbol_tree, "dioptra.pyplugs.register", decorator_symbol): + result = True + + elif decorator_symbol.endswith(".register"): + deco_prefix = decorator_symbol[:-9] + result = _are_aliases(symbol_tree, "dioptra.pyplugs", deco_prefix) + + else: + result = False + + return result + + +def _is_task_plugin( + func_def: ast_module.FunctionDef, symbol_tree: dict[str, Any] +) -> bool: + """ + Determine whether the given function definition is defining a task plugin. 
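+    A task plugin is a function decorated with dioptra.pyplugs.register,
+    where the decorator symbol may be aliased and may or may not be called.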
+
+    Args:
+        func_def: A function definition AST node
+        symbol_tree: A data structure representing symbol hierarchy inferred
+            from import statements
+
+    Returns:
+        True if the function definition is for a task plugin; False if not
+    """
+    for decorator_expr in func_def.decorator_list:
+
+        # We will only handle simple decorator expressions: simple dotted
+        # names, optionally with a function call.
+        if _is_simple_dotted_name(decorator_expr):
+            decorator_symbol = ast_module.unparse(decorator_expr)
+
+        elif isinstance(decorator_expr, ast_module.Call) and _is_simple_dotted_name(
+            decorator_expr.func
+        ):
+            decorator_symbol = ast_module.unparse(decorator_expr.func)
+
+        else:
+            decorator_symbol = None
+
+        if decorator_symbol and _is_register_decorator(decorator_symbol, symbol_tree):
+            result = True
+            break
+    else:
+        result = False
+
+    return result
+
+
+def _find_plugins(ast: ast_module.AST) -> Iterator[ast_module.FunctionDef]:
+    """
+    Find AST nodes corresponding to task plugin functions.
+
+    Args:
+        ast: An AST node. Plugin functions will only be found inside Module
+            nodes
+
+    Yields:
+        AST nodes corresponding to task plugin function definitions
+    """
+    if isinstance(ast, ast_module.Module):
+        symbol_tree: dict[str, Any] = {}
+        for stmt in ast.body:
+
+            if isinstance(stmt, (ast_module.Import, ast_module.ImportFrom)):
+                _process_import(stmt, symbol_tree)
+
+            elif isinstance(stmt, ast_module.FunctionDef) and _is_task_plugin(
+                stmt, symbol_tree
+            ):
+                yield stmt
+
+
+def _derive_type_name_from_annotation(annotation_ast: ast_module.AST) -> Optional[str]:
+    """
+    Try to derive a suitable Dioptra type name from a type annotation AST.
+    Annotations can be arbitrarily complex and even nonsensical (not all
+    kinds of errors are caught at parse time), so derivation may fail
+    depending on the AST.
+
+    Args:
+        annotation_ast: An AST used as an argument or return type annotation
+
+    Returns:
+        A type name if one could be derived, or None if one could not be
+        derived from the given annotation
+    """
+
+    # "None" isn't a type, but is used to mean the type of None
+    if _is_constant(annotation_ast, None):
+        type_name_suggestion = "null"
+
+    # A name, e.g. int
+    elif isinstance(annotation_ast, ast_module.Name):
+        type_name_suggestion = annotation_ast.id
+
+    # A string literal, e.g. "foo". Can be used in Python code to defer
+    # evaluation of an annotation.
+    elif isinstance(annotation_ast, ast_module.Constant) and isinstance(
+        annotation_ast.value, str
+    ):
+        type_name_suggestion = annotation_ast.value
+
+    # Frequently used annotation expressions, e.g. list[str] is a "Subscript",
+    # and str | int is a "BinOp".
+    elif isinstance(
+        annotation_ast, (ast_module.Subscript, ast_module.BinOp)
+    ) or _is_simple_dotted_name(annotation_ast):
+        type_name_suggestion = ast_module.unparse(annotation_ast)
+
+    else:
+        type_name_suggestion = None
+
+    # Normalize the suggestion, if we were able to derive one
+    if type_name_suggestion:
+        type_name_suggestion = type_name_suggestion.strip()
+        type_name_suggestion = type_name_suggestion.lower()
+        type_name_suggestion = type_name_suggestion.replace(" ", "")
+        # Replace non-alphanumerics with underscores
+        type_name_suggestion = re.sub(r"\W+", "_", type_name_suggestion)
+        # Condense multiple underscores to one
+        type_name_suggestion = re.sub("_+", "_", type_name_suggestion)
+        type_name_suggestion = type_name_suggestion.strip("_")
+
+        # Try to map to a Dioptra builtin type name.
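+        # E.g. "str" -> "string", "int" -> "integer", per
+        # _PYTHON_TO_DIOPTRA_TYPE_NAME above.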
+        type_name_suggestion = _PYTHON_TO_DIOPTRA_TYPE_NAME.get(
+            type_name_suggestion, type_name_suggestion
+        )
+
+        # After all this, if we wound up with an empty string, we failed.
+        # If the name doesn't begin with a letter (like all good identifiers
+        # should), we also failed.
+        if not type_name_suggestion or not type_name_suggestion[0].isalpha():
+            type_name_suggestion = None
+
+    return type_name_suggestion
+
+
+def _make_unique_type_name(existing_types: Container[str]) -> str:
+    """
+    Make a unique type name, i.e. one which doesn't exist in existing_types.
+    A type name derived from a user's annotation may happen to match our
+    generated-name syntax, so maintaining a counter elsewhere and
+    incrementing it each time we need a new name could still produce
+    collisions. So this is done conservatively (if inefficiently):
+    concatenate a base name with an integer counter starting at 1, and
+    increment until we obtain a name which has not previously been seen.
+
+    Args:
+        existing_types: A container of existing type names
+
+    Returns:
+        A new type name which is not in the container
+    """
+    counter = 1
+    type_name = f"type{counter}"
+    while type_name in existing_types:
+        counter += 1
+        type_name = f"type{counter}"
+
+    return type_name
+
+
+def _pos_args_defaults(
+    args: ast_module.arguments,
+) -> Iterator[tuple[ast_module.arg, Optional[ast_module.expr]]]:
+    """
+    Generate the positional argument AST nodes paired with their defined
+    default AST nodes (if any), contained within the given AST arguments
+    value. This requires a bit of care since positional args and their
+    defaults aren't stored in a way that can straightforwardly be zipped up.
+    This includes all positional-only and "regular" (non-keyword-only)
+    arguments, in the order they appear in the function signature.
+
+    Args:
+        args: An AST arguments value
+
+    Yields:
+        positional arg, arg default pairs. If an arg does not have a default
+        defined in the signature, it is generated as None.
+    """
+    num_pos_args = len(args.posonlyargs) + len(args.args)
+    idx_first_defaulted_arg = num_pos_args - len(args.defaults)
+
+    for arg_idx, arg in enumerate(itertools.chain(args.posonlyargs, args.args)):
+        if arg_idx >= idx_first_defaulted_arg:
+            arg_default = args.defaults[arg_idx - idx_first_defaulted_arg]
+        else:
+            arg_default = None
+
+        yield arg, arg_default
+
+
+def _func_args_defaults(
+    func: ast_module.FunctionDef,
+) -> Iterator[tuple[ast_module.arg, Optional[ast_module.expr]]]:
+    """
+    Generate all argument AST nodes paired with their defined default AST
+    nodes (if any). This includes positional-only, "regular", and
+    keyword-only arguments, in the order they appear in the function
+    signature.
+
+    Args:
+        func: A FunctionDef AST node representing a function definition
+
+    Yields:
+        arg, arg default pairs. If an arg does not have a default defined in
+        the signature, it is generated as None.
+    """
+    yield from _pos_args_defaults(func.args)
+    yield from zip(func.args.kwonlyargs, func.args.kw_defaults)
+
+
+def _func_args(func: ast_module.FunctionDef) -> Iterator[ast_module.arg]:
+    """
+    Generate all argument AST nodes. This does not include any of their
+    defaults. They are generated in the order they appear in the function
+    signature.
+
+    Args:
+        func: A FunctionDef AST node representing a function definition
+
+    Returns:
+        An iterator which produces all function argument AST nodes
+    """
+    # Must use same iteration order as _func_args_defaults()!
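+    # (positional-only args, then regular args, then keyword-only args)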
+ return itertools.chain(func.args.posonlyargs, func.args.args, func.args.kwonlyargs) + + +def _get_function_signature_via_derivation( + func: ast_module.FunctionDef, +) -> dict[str, Any]: + """ + Create a dict structure which reflects the signature of the given function, + including where possible, argument and return type names suitable for use + with the Dioptra type system. This function tries to derive type names + from argument/return type annotations. This derivation may or may not + produce a suitable type name. Where it is unable to derive a type name, + None is used in the data structure. The end result is a structure which + accounts for all arguments and the return type, although some type names + may be None. + + Args: + func: A FunctionDef AST node representing a function definition + + Returns: + A function signature data structure as a dict + """ + inputs = [] + outputs = [] + suggested_types = [] + used_type_names = set() + + for arg, arg_default in _func_args_defaults(func): + if arg.annotation: + type_name_suggestion = _derive_type_name_from_annotation(arg.annotation) + else: + type_name_suggestion = None + + inputs.append( + { + "name": arg.arg, + "type": type_name_suggestion, # might be None + "required": arg_default is None, + } + ) + + # Add suggestions for non-Dioptra-builtin types only, which we have not + # already created a suggestion for + if ( + type_name_suggestion + and type_name_suggestion not in type_registry.BUILTIN_TYPES + and type_name_suggestion not in used_type_names + ): + # For mypy: we would not have a type name suggestion here if we did + # not have an annotation. + assert arg.annotation + suggested_types.append( + { + "suggestion": type_name_suggestion, + "type_annotation": ast_module.unparse(arg.annotation), + } + ) + + used_type_names.add(type_name_suggestion) + + # Also address any return annotation other than None. If it is None, + # skip the output. None means the plugin produces no output. + if func.returns and not _is_constant(func.returns, None): + type_name_suggestion = _derive_type_name_from_annotation(func.returns) + + outputs.append( + {"name": "output", "type": type_name_suggestion} # might be None + ) + + if ( + type_name_suggestion + and type_name_suggestion not in type_registry.BUILTIN_TYPES + and type_name_suggestion not in used_type_names + ): + suggested_types.append( + { + "suggestion": type_name_suggestion, + "type_annotation": ast_module.unparse(func.returns), + } + ) + + used_type_names.add(type_name_suggestion) + + signature = { + "name": func.name, + "inputs": inputs, + "outputs": outputs, + "suggested_types": suggested_types, + } + + return signature + + +def _complete_function_signature_via_generation( + func: ast_module.FunctionDef, signature: dict[str, Any] +) -> None: + """ + Search through the given signature structure for missing (None) type names, + and use name generation to generate unique names. The signature structure + is updated such that all arguments and return type should have a type name. + + Args: + func: A FunctionDef AST node representing a function definition + signature: A function signature structure to update + """ + + # Gather used types; use this to ensure uniqueness of generated types. 
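+    # Names derived from annotations in pass #1 (including Dioptra builtins)
+    # take precedence; generated names must not collide with any of them.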
+    used_type_names = {
+        input_["type"] for input_ in signature["inputs"] if input_["type"]
+    }
+
+    used_type_names.update(
+        output["type"] for output in signature["outputs"] if output["type"]
+    )
+
+    # For annotations for which we could not derive a type name, we must
+    # nevertheless recognize annotation reuse, and reuse the same
+    # generated unique type name. I don't think ASTs have any support
+    # for equality checks, hashing, etc. The only way I can think of to
+    # compare one AST to another is via their unparsed Python code (as
+    # strings). So this mapping maps unparsed Python code to a generated
+    # unique name.
+    ann_to_unique: dict[str, str] = {}
+    unparsed_ann: Optional[str]
+
+    for input_, arg in zip(signature["inputs"], _func_args(func)):
+        if not input_["type"]:
+            if arg.annotation:
+                unparsed_ann = ast_module.unparse(arg.annotation)
+                type_name_suggestion = ann_to_unique.get(unparsed_ann)
+            else:
+                unparsed_ann = type_name_suggestion = None
+
+            if not type_name_suggestion:
+                type_name_suggestion = _make_unique_type_name(used_type_names)
+                if unparsed_ann:
+                    ann_to_unique[unparsed_ann] = type_name_suggestion
+
+            input_["type"] = type_name_suggestion
+
+            if unparsed_ann and type_name_suggestion not in used_type_names:
+                signature["suggested_types"].append(
+                    {
+                        "suggestion": type_name_suggestion,
+                        "type_annotation": unparsed_ann,
+                    }
+                )
+
+            used_type_names.add(type_name_suggestion)
+
+    # Generate a type name for output if necessary
+    if signature["outputs"]:
+        output = signature["outputs"][0]
+        if not output["type"]:
+            # For mypy: we would not have a defined output if the function did
+            # not have a return type annotation.
+            assert func.returns
+            unparsed_ann = ast_module.unparse(func.returns)
+            type_name_suggestion = ann_to_unique.get(unparsed_ann)
+            if not type_name_suggestion:
+                type_name_suggestion = _make_unique_type_name(used_type_names)
+                ann_to_unique[unparsed_ann] = type_name_suggestion
+
+            output["type"] = type_name_suggestion
+
+            if type_name_suggestion not in used_type_names:
+                signature["suggested_types"].append(
+                    {
+                        "suggestion": type_name_suggestion,
+                        "type_annotation": unparsed_ann,
+                    }
+                )
+
+            used_type_names.add(type_name_suggestion)
+
+
+def get_plugin_signatures(
+    python_source: str, filepath: Optional[Union[str, Path]] = None
+) -> Iterator[dict[str, Any]]:
+    """
+    Extract plugin signatures and build signature information structures from
+    all task plugins defined in the given source code.
+
+    Args:
+        python_source: Some Python source code; should be complete with
+            supporting import statements to assist in understanding what
+            symbols mean
+        filepath: A value representative of where the Python source came from.
+            This is an optional arg passed on to the underlying compile()
+            function, which documents it as: "The filename argument should
+            give the file from which the code was read; pass some recognizable
+            value if it wasn't read from a file ('<string>' is commonly
+            used)."
+
+    Yields:
+        Function signature information data structures, as dicts
+    """
+    if filepath:
+        ast = ast_module.parse(
+            python_source, filename=filepath, feature_version=sys.version_info[0:2]
+        )
+    else:
+        ast = ast_module.parse(python_source, feature_version=sys.version_info[0:2])
+
+    for plugin_func in _find_plugins(ast):
+
+        # We need to come up with a syntax for unique type names. But no
+        # matter what syntax we choose, a user's type annotations might collide
+        # with it.
So we can't easily do this in one pass where we generate a + # name whenever we fail to derive one from a type annotation. If a + # subsequent type name derived from a user type annotation collides + # with a unique name we already generated, the user's name must take + # precedence. + # + # A better way is to make two passes: the first pass derives type names + # from type annotations where possible, and determines what the + # user-annotation-derived type names are. The second pass uses unique + # name generation to generate all type names we could not derive in the + # first pass, where the generation can use the names derived in the + # first pass to ensure there are no naming collisions. + + # Pass #1 + signature = _get_function_signature_via_derivation(plugin_func) + + # Pass #2 + _complete_function_signature_via_generation(plugin_func, signature) + + yield signature + + +def get_plugin_signatures_from_file( + filepath: Union[str, Path], encoding: str = "utf-8" +) -> Iterator[dict[str, Any]]: + """ + Extract plugin signatures and build signature information structures from + all task plugins defined in the given Python source file. + + Args: + filepath: A path to a file with Python source code; should be complete + with supporting import statements to assist in understanding what + symbols mean + encoding: A text encoding used to read the given source file + + Returns: + An iterator of function signature information data structures, as dicts + """ + filepath = Path(filepath) + python_source = filepath.read_text(encoding=encoding) + + return get_plugin_signatures(python_source, filepath) diff --git a/src/dioptra/restapi/v1/workflows/controller.py b/src/dioptra/restapi/v1/workflows/controller.py index 428619cdc..55024531d 100644 --- a/src/dioptra/restapi/v1/workflows/controller.py +++ b/src/dioptra/restapi/v1/workflows/controller.py @@ -19,14 +19,19 @@ import structlog from flask import request, send_file -from flask_accepts import accepts +from flask_accepts import accepts, responds from flask_login import login_required from flask_restx import Namespace, Resource from injector import inject from structlog.stdlib import BoundLogger -from .schema import FileTypes, JobFilesDownloadQueryParametersSchema -from .service import JobFilesDownloadService +from .schema import ( + FileTypes, + JobFilesDownloadQueryParametersSchema, + SignatureAnalysisOutputSchema, + SignatureAnalysisSchema, +) +from .service import JobFilesDownloadService, SignatureAnalysisService LOGGER: BoundLogger = structlog.stdlib.get_logger() @@ -78,3 +83,35 @@ def get(self): mimetype=mimetype[parsed_query_params["file_type"]], download_name=download_name[parsed_query_params["file_type"]], ) + + +@api.route("/pluginTaskSignatureAnalysis") +class SignatureAnalysisEndpoint(Resource): + @inject + def __init__( + self, signature_analysis_service: SignatureAnalysisService, *args, **kwargs + ) -> None: + """Initialize the workflow resource. + + All arguments are provided via dependency injection. + + Args: + signature_analysis_service: A SignatureAnalysisService object. 
+        """
+        self._signature_analysis_service = signature_analysis_service
+        super().__init__(*args, **kwargs)
+
+    @login_required
+    @accepts(schema=SignatureAnalysisSchema, api=api)
+    @responds(schema=SignatureAnalysisOutputSchema, api=api)
+    def post(self):
+        """Perform signature analysis on the Python code submitted in the request."""
+        log = LOGGER.new(  # noqa: F841
+            request_id=str(uuid.uuid4()),
+            resource="SignatureAnalysis",
+            request_type="POST",
+        )
+        parsed_obj = request.parsed_obj
+        return self._signature_analysis_service.post(
+            python_code=parsed_obj["python_code"],
+        )
diff --git a/src/dioptra/restapi/v1/workflows/schema.py b/src/dioptra/restapi/v1/workflows/schema.py
index 92ea28ec7..505d4cdb7 100644
--- a/src/dioptra/restapi/v1/workflows/schema.py
+++ b/src/dioptra/restapi/v1/workflows/schema.py
@@ -41,3 +41,84 @@ class JobFilesDownloadQueryParametersSchema(Schema):
         by_value=True,
         default=FileTypes.TAR_GZ.value,
     )
+
+
+class SignatureAnalysisSchema(Schema):
+
+    pythonCode = fields.String(
+        attribute="python_code",
+        metadata=dict(description="The contents of the Python file"),
+    )
+
+
+class SignatureAnalysisSignatureParamSchema(Schema):
+    name = fields.String(
+        attribute="name", metadata=dict(description="The name of the parameter")
+    )
+    type = fields.String(
+        attribute="type", metadata=dict(description="The type of the parameter")
+    )
+
+
+class SignatureAnalysisSignatureInputSchema(SignatureAnalysisSignatureParamSchema):
+    required = fields.Boolean(
+        attribute="required",
+        metadata=dict(description="Whether this is a required parameter"),
+    )
+
+
+class SignatureAnalysisSignatureOutputSchema(SignatureAnalysisSignatureParamSchema):
+    pass
+
+
+class SignatureAnalysisSuggestedTypes(Schema):
+
+    # add proposed_type in next iteration
+
+    name = fields.String(
+        attribute="name",
+        metadata=dict(description="A suggestion for the name of the type"),
+    )
+
+    description = fields.String(
+        attribute="description",
+        metadata=dict(
+            description="The annotation the suggestion is attempting to represent"
+        ),
+    )
+
+
+class SignatureAnalysisSignatureSchema(Schema):
+    name = fields.String(
+        attribute="name", metadata=dict(description="The name of the function")
+    )
+    inputs = fields.Nested(
+        SignatureAnalysisSignatureInputSchema,
+        metadata=dict(description="A list of objects describing the input parameters."),
+        many=True,
+    )
+    outputs = fields.Nested(
+        SignatureAnalysisSignatureOutputSchema,
+        metadata=dict(
+            description="A list of objects describing the output parameters."
+        ),
+        many=True,
+    )
+    missing_types = fields.Nested(
+        SignatureAnalysisSuggestedTypes,
+        metadata=dict(
+            description="A list of suggested types for non-builtin annotations "
+            "used by the plugin tasks in the file"
+        ),
+        many=True,
+    )
+
+
+class SignatureAnalysisOutputSchema(Schema):
+    tasks = fields.Nested(
+        SignatureAnalysisSignatureSchema,
+        metadata=dict(
+            description="A list of signature analyses for the plugin tasks "
+            "provided in the input file"
+        ),
+        many=True,
+    )
diff --git a/src/dioptra/restapi/v1/workflows/service.py b/src/dioptra/restapi/v1/workflows/service.py
index d5769e274..b162cb948 100644
--- a/src/dioptra/restapi/v1/workflows/service.py
+++ b/src/dioptra/restapi/v1/workflows/service.py
@@ -15,11 +15,13 @@
 # ACCESS THE FULL CC BY 4.0 LICENSE HERE:
 # https://creativecommons.org/licenses/by/4.0/legalcode
 """The server-side functions that perform workflows endpoint operations."""
-from typing import IO, Final
+from typing import IO, Any, Final
 
 import structlog
 from structlog.stdlib import BoundLogger
 
+from dioptra.restapi.v1.shared.signature_analysis import get_plugin_signatures
+
 from .lib import views
 from .lib.package_job_files import package_job_files
 from .schema import FileTypes
@@ -65,3 +67,52 @@ def get(self, job_id: int, file_type: FileTypes, **kwargs) -> IO[bytes]:
             file_type=file_type,
             logger=log,
         )
+
+
+class SignatureAnalysisService(object):
+    """The service methods for performing signature analysis on a file."""
+
+    def post(self, python_code: str, **kwargs) -> dict[str, list[dict[str, Any]]]:
+        """Perform signature analysis on a file.
+
+        Args:
+            python_code: The contents of the file.
+
+        Returns:
+            A dictionary containing the signature analysis.
+        """
+        log: BoundLogger = kwargs.get("log", LOGGER.new())
+        log.debug(
+            "Performing signature analysis",
+            python_source=python_code,
+        )
+        endpoint_analyses = [
+            _create_endpoint_analysis_dict(signature)
+            for signature in get_plugin_signatures(python_source=python_code)
+        ]
+        return {"tasks": endpoint_analyses}
+
+
+def _create_endpoint_analysis_dict(
+    signature: dict[str, Any],
+) -> dict[str, Any]:
+    """Create an endpoint analysis dictionary from a signature analysis.
+
+    Args:
+        signature: The signature analysis.
+
+    Returns:
+        The endpoint analysis dictionary.
+    """
+    return {
+        "name": signature["name"],
+        "inputs": signature["inputs"],
+        "outputs": signature["outputs"],
+        "missing_types": [
+            {
+                "description": suggested_type["type_annotation"],
+                "name": suggested_type["suggestion"],
+            }
+            for suggested_type in signature["suggested_types"]
+        ],
+    }
diff --git a/tests/unit/restapi/test_signature_analysis.py b/tests/unit/restapi/test_signature_analysis.py
new file mode 100644
index 000000000..91bbd3810
--- /dev/null
+++ b/tests/unit/restapi/test_signature_analysis.py
@@ -0,0 +1,340 @@
+# This Software (Dioptra) is being made available as a public service by the
+# National Institute of Standards and Technology (NIST), an Agency of the United
+# States Department of Commerce. This software was developed in part by employees of
+# NIST and in part by NIST contractors. Copyright in portions of this software that
+# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
+# to Title 17 United States Code Section 105, works of NIST employees are not
+# subject to copyright protection in the United States.
However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.restapi.v1.shared.signature_analysis import get_plugin_signatures + + +def test_plugin_recognition_1(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_2(): + source = """\ +from dioptra import pyplugs + +@pyplugs.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_3(): + source = """\ +from dioptra.pyplugs import register + +@register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_1(): + source = """\ +import dioptra.pyplugs as foo + +@foo.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_2(): + source = """\ +from dioptra import pyplugs as foo + +@foo.register +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_3(): + source = """\ +from dioptra.pyplugs import register as foo + +@foo +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_call(): + source = """\ +from dioptra.pyplugs import register + +@register() +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_alias_call(): + source = """\ +from dioptra.pyplugs import register as foo + +@foo() +def test_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 1 + + +def test_plugin_recognition_none(): + source = """\ +import dioptra.pyplugs + +# missing the decorator +def not_a_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert not signatures + + +def test_plugin_recognition_complex(): + source = """\ +from dioptra.pyplugs import register +import aaa + +@register() +def test_plugin(): + pass + +@aaa.register +def not_a_plugin(): + pass + +class SomeClass: + pass + +def some_other_func(): + pass + +x = 1 + +@register +def test_plugin2(): + pass + +# re-definition of the "register" symbol +from bbb import ccc as register + +@register +def also_not_a_plugin(): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert len(signatures) == 2 + + +def test_dioptra_builtin_types(): + source = """\ +from dioptra.pyplugs import register + +@register +def test_plugin( + arg1: str, + arg2: int, + arg3: float, + arg4: bool, + arg5: None +): + pass +""" + + signatures = list(get_plugin_signatures(source)) + + assert signatures == [ + { + "name": "test_plugin", + "inputs": [ + {"name": "arg1", 
"required": True, "type": "string"}, + {"name": "arg2", "required": True, "type": "integer"}, + {"name": "arg3", "required": True, "type": "number"}, + {"name": "arg4", "required": True, "type": "boolean"}, + {"name": "arg5", "required": True, "type": "null"}, + ], + "outputs": [], + "suggested_types": [], + } + ] + + +def test_return_none(): + source = """\ +from dioptra.pyplugs import register + +# None is same as not having a return type annotation +@register +def my_plugin() -> None: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + {"name": "my_plugin", "inputs": [], "outputs": [], "suggested_types": []} + ] + + +def test_derive_type_simple(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register() +def the_plugin(arg1: SomeType) -> SomeType: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "the_plugin", + "inputs": [{"name": "arg1", "required": True, "type": "sometype"}], + "outputs": [{"name": "output", "type": "sometype"}], + "suggested_types": [ + {"suggestion": "sometype", "type_annotation": "SomeType"} + ], + } + ] + + +def test_derive_type_complex(): + source = """\ +import dioptra.pyplugs + +@dioptra.pyplugs.register() +def the_plugin(arg1: Optional[str]) -> Union[int, bool]: + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "the_plugin", + "inputs": [{"name": "arg1", "required": True, "type": "optional_str"}], + "outputs": [{"name": "output", "type": "union_int_bool"}], + "suggested_types": [ + {"suggestion": "optional_str", "type_annotation": "Optional[str]"}, + {"suggestion": "union_int_bool", "type_annotation": "Union[int, bool]"}, + ], + } + ] + + +def test_generate_type(): + source = """\ +import dioptra.pyplugs + +# annotation is a function call; we don't attempt a type derivation for +# that kind of annotation. +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2)) -> foo(2): + pass +""" + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "plugin_func", + "inputs": [{"name": "arg1", "required": True, "type": "type1"}], + "outputs": [{"name": "output", "type": "type1"}], + "suggested_types": [{"suggestion": "type1", "type_annotation": "foo(2)"}], + } + ] + + +def test_generate_type_conflict(): + source = """\ +import dioptra.pyplugs + +# annotation is a function call; we don't attempt a type derivation for +# that kind of annotation. Our first generated type would normally be "type1", +# but we can't use that either because the code author already used that! So +# our generated type will have to be "type2". 
+@dioptra.pyplugs.register +def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): + pass +""" + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "plugin_func", + "inputs": [ + {"name": "arg1", "required": True, "type": "type2"}, + {"name": "arg2", "required": True, "type": "type1"}, + ], + "outputs": [{"name": "output", "type": "type2"}], + "suggested_types": [ + {"suggestion": "type1", "type_annotation": "Type1"}, + {"suggestion": "type2", "type_annotation": "foo(2)"}, + ], + } + ] + + +def test_optional_arg(): + source = """\ +from dioptra import pyplugs + +@pyplugs.register() +def do_things(arg1: Optional[str], arg2: int = 123): + pass +""" + + signatures = list(get_plugin_signatures(source)) + assert signatures == [ + { + "name": "do_things", + "inputs": [ + {"name": "arg1", "required": True, "type": "optional_str"}, + {"name": "arg2", "required": False, "type": "integer"}, + ], + "outputs": [], + "suggested_types": [ + {"suggestion": "optional_str", "type_annotation": "Optional[str]"} + ], + } + ] diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py new file mode 100644 index 000000000..904d2cf65 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs as foo + + +@foo.register +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py new file mode 100644 index 000000000..f2833120a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_complex_type.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. 
However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register() +def the_plugin(arg1: Optional[str]) -> Union[int, bool]: + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py new file mode 100644 index 000000000..bc3242674 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_function_type.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2)) -> foo(2): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py new file mode 100644 index 000000000..0ed95097e --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_none_return.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). 
The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.pyplugs import register + + +@register +def my_plugin() -> None: + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py new file mode 100644 index 000000000..ec847c6ea --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_optional.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra import pyplugs + + +@pyplugs.register() +def do_things(arg1: Optional[str], arg2: int = 123): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py new file mode 100644 index 000000000..73ab9039a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_pyplugs_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra import pyplugs as foo + + +@foo.register +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py new file mode 100644 index 000000000..79689c7ef --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_real_world.py @@ -0,0 +1,424 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from __future__ import annotations + +from pathlib import Path +from typing import Any, Callable, Dict, List, Optional, Tuple, Union + +import numpy as np +import pandas as pd +import scipy.stats +import structlog +from structlog.stdlib import BoundLogger +from tensorflow.keras.preprocessing.image import DirectoryIterator + +import mlflow +from dioptra import pyplugs + +from .artifacts_mlflow import ( + download_all_artifacts, + upload_data_frame_artifact, + upload_directory_as_tarball_artifact, +) +from .artifacts_restapi import ( + get_uri_for_model, + get_uris_for_artifacts, + get_uris_for_job, +) +from .artifacts_utils import extract_tarfile, make_directories +from .attacks_fgm import fgm +from .attacks_patch import create_adversarial_patch_dataset, create_adversarial_patches +from .backend_configs_tensorflow import init_tensorflow +from .data_tensorflow import ( + create_image_dataset, + df_to_predictions, + get_n_classes_from_directory_iterator, + predictions_to_df, +) +from .defenses_image_preprocessing import create_defended_dataset +from .estimators_keras_classifiers import init_classifier +from .estimators_methods import fit +from .metrics_distance import get_distance_metric_list +from .metrics_performance import evaluate_metrics_generic, get_performance_metric_list +from .mlflow import add_model_to_registry +from .random_rng import init_rng +from .random_sample import draw_random_integer +from .registry_art import load_wrapped_tensorflow_keras_classifier +from .registry_mlflow import load_tensorflow_keras_classifier +from .tensorflow import ( + evaluate_metrics_tensorflow, + get_model_callbacks, + get_optimizer, + get_performance_metrics, + predict_tensorflow, +) +from .tracking_mlflow import log_metrics, log_parameters, log_tensorflow_keras_estimator + +LOGGER: BoundLogger = structlog.stdlib.get_logger() + + +@pyplugs.register +def load_dataset( + 
ep_seed: int = 10145783023, + training_dir: str = "/dioptra/data/Mnist/training", + testing_dir: str = "/dioptra/data/Mnist/testing", + subsets: List[str] = ["testing"], + image_size: Tuple[int, int, int] = [28, 28, 1], + rescale: float = 1.0 / 255, + validation_split: Optional[float] = 0.2, + batch_size: int = 32, + label_mode: str = "categorical", + shuffle: bool = False, +) -> DirectoryIterator: + seed, rng = init_rng(ep_seed) + global_seed = draw_random_integer(rng) + dataset_seed = draw_random_integer(rng) + init_tensorflow(global_seed) + log_parameters( + { + "entry_point_seed": ep_seed, + "tensorflow_global_seed": global_seed, + "dataset_seed": dataset_seed, + } + ) + training_dataset = ( + None + if "training" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="training", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + + validation_dataset = ( + None + if "validation" not in subsets + else create_image_dataset( + data_dir=training_dir, + subset="validation", + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + testing_dataset = ( + None + if "testing" not in subsets + else create_image_dataset( + data_dir=testing_dir, + subset=None, + image_size=image_size, + seed=dataset_seed, + rescale=rescale, + validation_split=validation_split, + batch_size=batch_size, + label_mode=label_mode, + shuffle=shuffle, + ) + ) + return training_dataset, validation_dataset, testing_dataset + + +@pyplugs.register +def create_model( + dataset: DirectoryIterator = None, + model_architecture: str = "le_net", + input_shape: Tuple[int, int, int] = [28, 28, 1], + loss: str = "categorical_crossentropy", + learning_rate: float = 0.001, + optimizer: str = "Adam", + metrics_list: List[Dict[str, Any]] = None, +): + n_classes = get_n_classes_from_directory_iterator(dataset) + optim = get_optimizer(optimizer, learning_rate) + perf_metrics = get_performance_metrics(metrics_list) + classifier = init_classifier( + model_architecture, optim, perf_metrics, input_shape, n_classes, loss + ) + return classifier + + +@pyplugs.register +def load_model( + model_name: str | None = None, + model_version: int | None = None, + imagenet_preprocessing: bool = False, + art: bool = False, + image_size: Any = None, + classifier_kwargs: Optional[Dict[str, Any]] = None, +): + uri = get_uri_for_model(model_name, model_version) + if art: + classifier = load_wrapped_tensorflow_keras_classifier( + uri, imagenet_preprocessing, image_size, classifier_kwargs + ) + else: + classifier = load_tensorflow_keras_classifier(uri) + return classifier + + +@pyplugs.register +def train( + estimator: Any, + x: Any = None, + y: Any = None, + callbacks_list: List[Dict[str, Any]] = None, + fit_kwargs: Optional[Dict[str, Any]] = None, +): + fit_kwargs = {} if fit_kwargs is None else fit_kwargs + callbacks = get_model_callbacks(callbacks_list) + fit_kwargs["callbacks"] = callbacks + fit(estimator=estimator, x=x, y=y, fit_kwargs=fit_kwargs) + return estimator + + +@pyplugs.register +def save_artifacts_and_models( + artifacts: List[Dict[str, Any]] = None, models: List[Dict[str, Any]] = None +): + artifacts = [] if artifacts is None else artifacts + models = [] if models is None else models + + for model in models: + log_tensorflow_keras_estimator(model["model"], "model") + 
add_model_to_registry(model["name"], "model") + for artifact in artifacts: + if artifact["type"] == "tarball": + upload_directory_as_tarball_artifact( + source_dir=artifact["adv_data_dir"], + tarball_filename=artifact["adv_tar_name"], + ) + if artifact["type"] == "dataframe": + upload_data_frame_artifact( + data_frame=artifact["data_frame"], + file_name=artifact["file_name"], + file_format=artifact["file_format"], + file_format_kwargs=artifact["file_format_kwargs"], + ) + + +@pyplugs.register +def load_artifacts_for_job( + job_id: str, files: List[str | Path] = None, extract_files: List[str | Path] = None +): + files = [] if files is None else files + extract_files = [] if extract_files is None else extract_files + files += extract_files # need to download them to be able to extract + + uris = get_uris_for_job(job_id) + paths = download_all_artifacts(uris, files) + for extract in paths: + for ef in extract_files: + if ef.endswith(str(ef)): + extract_tarfile(extract) + return paths + + +@pyplugs.register +def load_artifacts( + artifact_ids: List[int] = None, extract_files: List[str | Path] = None +): + extract_files = [] if extract_files is None else extract_files + artifact_ids = [] if artifact_ids is not None else artifact_ids + uris = get_uris_for_artifacts(artifact_ids) + paths = download_all_artifacts(uris, extract_files) + for extract in paths: + extract_tarfile(extract) + + +@pyplugs.register +def attack_fgm( + dataset: Any, + adv_data_dir: Union[str, Path], + classifier: Any, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + eps: float = 0.3, + eps_step: float = 0.1, + minimal: bool = False, + norm: Union[int, float, str] = np.inf, +): + """generate fgm examples""" + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + fgm_dataset = fgm( + data_flow=dataset, + adv_data_dir=adv_data_dir, + keras_classifier=classifier, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + eps=eps, + eps_step=eps_step, + minimal=minimal, + norm=norm, + ) + return fgm_dataset + + +@pyplugs.register() +def attack_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + model: Any, + patch_target: int, + num_patch: int, + num_patch_samples: int, + rotation_max: float, + scale_min: float, + scale_max: float, + learning_rate: float, + max_iter: int, + patch_shape: Tuple, +): + """generate patches""" + make_directories([adv_data_dir]) + create_adversarial_patches( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + keras_classifier=model, + patch_target=patch_target, + num_patch=num_patch, + num_patch_samples=num_patch_samples, + rotation_max=rotation_max, + scale_min=scale_min, + scale_max=scale_max, + learning_rate=learning_rate, + max_iter=max_iter, + patch_shape=patch_shape, + ) + + +@pyplugs.register() +def augment_patch( + data_flow: Any, + adv_data_dir: Union[str, Path], + patch_dir: Union[str, Path], + model: Any, + patch_shape: Tuple, + distance_metrics: List[Dict[str, str]], + batch_size: int = 32, + patch_scale: float = 0.4, + rotation_max: float = 22.5, + scale_min: float = 0.1, + scale_max: float = 1.0, +): + """add patches to a dataset""" + make_directories([adv_data_dir]) + distance_metrics_list = get_distance_metric_list(distance_metrics) + create_adversarial_patch_dataset( + data_flow=data_flow, + adv_data_dir=adv_data_dir, + patch_dir=patch_dir, + keras_classifier=model, + patch_shape=patch_shape, + distance_metrics_list=distance_metrics_list, + batch_size=batch_size, + patch_scale=patch_scale, + 
+
+
+@pyplugs.register()
+def attack_patch(
+    data_flow: Any,
+    adv_data_dir: Union[str, Path],
+    model: Any,
+    patch_target: int,
+    num_patch: int,
+    num_patch_samples: int,
+    rotation_max: float,
+    scale_min: float,
+    scale_max: float,
+    learning_rate: float,
+    max_iter: int,
+    patch_shape: Tuple,
+):
+    """Generate adversarial patches."""
+    make_directories([adv_data_dir])
+    create_adversarial_patches(
+        data_flow=data_flow,
+        adv_data_dir=adv_data_dir,
+        keras_classifier=model,
+        patch_target=patch_target,
+        num_patch=num_patch,
+        num_patch_samples=num_patch_samples,
+        rotation_max=rotation_max,
+        scale_min=scale_min,
+        scale_max=scale_max,
+        learning_rate=learning_rate,
+        max_iter=max_iter,
+        patch_shape=patch_shape,
+    )
+
+
+@pyplugs.register()
+def augment_patch(
+    data_flow: Any,
+    adv_data_dir: Union[str, Path],
+    patch_dir: Union[str, Path],
+    model: Any,
+    patch_shape: Tuple,
+    distance_metrics: List[Dict[str, str]],
+    batch_size: int = 32,
+    patch_scale: float = 0.4,
+    rotation_max: float = 22.5,
+    scale_min: float = 0.1,
+    scale_max: float = 1.0,
+):
+    """Apply adversarial patches to a dataset."""
+    make_directories([adv_data_dir])
+    distance_metrics_list = get_distance_metric_list(distance_metrics)
+    create_adversarial_patch_dataset(
+        data_flow=data_flow,
+        adv_data_dir=adv_data_dir,
+        patch_dir=patch_dir,
+        keras_classifier=model,
+        patch_shape=patch_shape,
+        distance_metrics_list=distance_metrics_list,
+        batch_size=batch_size,
+        patch_scale=patch_scale,
+        rotation_max=rotation_max,
+        scale_min=scale_min,
+        scale_max=scale_max,
+    )
+
+
+@pyplugs.register
+def model_metrics(classifier: Any, dataset: Any):
+    """Evaluate a classifier on a dataset and log the resulting metrics."""
+    metrics = evaluate_metrics_tensorflow(classifier, dataset)
+    log_metrics(metrics)
+    return metrics
+
+
+@pyplugs.register
+def prediction_metrics(
+    y_true: np.ndarray,
+    y_pred: np.ndarray,
+    metrics_list: List[Dict[str, str]],
+    func_kwargs: Dict[str, Dict[str, Any]] = None,
+):
+    """Compute performance metrics from predictions and log them."""
+    func_kwargs = {} if func_kwargs is None else func_kwargs
+    callable_list = get_performance_metric_list(metrics_list)
+    metrics = evaluate_metrics_generic(y_true, y_pred, callable_list, func_kwargs)
+    log_metrics(metrics)
+    return pd.DataFrame(metrics, index=[1])
+
+
+@pyplugs.register
+def augment_data(
+    dataset: Any,
+    def_data_dir: Union[str, Path],
+    image_size: Tuple[int, int, int],
+    distance_metrics: List[Dict[str, str]],
+    batch_size: int = 50,
+    def_type: str = "spatial_smoothing",
+    defense_kwargs: Optional[Dict[str, Any]] = None,
+):
+    """Create a defended (preprocessed) copy of a dataset."""
+    make_directories([def_data_dir])
+    distance_metrics_list = get_distance_metric_list(distance_metrics)
+    defended_dataset = create_defended_dataset(
+        data_flow=dataset,
+        def_data_dir=def_data_dir,
+        image_size=image_size,
+        distance_metrics_list=distance_metrics_list,
+        batch_size=batch_size,
+        def_type=def_type,
+        defense_kwargs=defense_kwargs,
+    )
+    return defended_dataset
+
+
+@pyplugs.register
+def predict(
+    classifier: Any,
+    dataset: Any,
+    show_actual: bool = False,
+    show_target: bool = False,
+):
+    """Generate predictions for a dataset and return them as a DataFrame."""
+    predictions = predict_tensorflow(classifier, dataset)
+    df = predictions_to_df(
+        predictions, dataset, show_actual=show_actual, show_target=show_target
+    )
+    return df
+
+
+@pyplugs.register
+def load_predictions(
+    paths: List[str],
+    filename: str,
+    format: str = "csv",
+    dataset: DirectoryIterator = None,
+    n_classes: int = -1,
+):
+    """Load previously saved predictions from a downloaded artifact."""
+    loc = None
+    for m in paths:
+        if m.endswith(filename):
+            loc = m
+            break
+    if loc is None:
+        raise FileNotFoundError(f"no downloaded path ends with {filename!r}")
+    if format == "csv":
+        df = pd.read_csv(loc)
+    elif format == "json":
+        df = pd.read_json(loc)
+    else:
+        raise ValueError(f"unsupported prediction format: {format!r}")
+    y_true, y_pred = df_to_predictions(df, dataset, n_classes)
+    return y_true, y_pred
diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py
new file mode 100644
index 000000000..8978be0a0
--- /dev/null
+++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_redefinition.py
@@ -0,0 +1,54 @@
+# This Software (Dioptra) is being made available as a public service by the
+# National Institute of Standards and Technology (NIST), an Agency of the United
+# States Department of Commerce. This software was developed in part by employees of
+# NIST and in part by NIST contractors. Copyright in portions of this software that
+# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant
+# to Title 17 United States Code Section 105, works of NIST employees are not
+# subject to copyright protection in the United States. However, NIST may hold
+# international copyright in software created by its employees and domestic
+# copyright (or licensing rights) in portions of software that were assigned or
+# licensed to NIST. To the extent that NIST holds copyright in this software, it is
+# being made available under the Creative Commons Attribution 4.0 International
+# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts
+# of the software developed or licensed by NIST.
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import aaa + +from dioptra.pyplugs import register + + +@register() +def test_plugin(): + pass + + +@aaa.register +def not_a_plugin(): + pass + + +class SomeClass: + pass + + +def some_other_func(): + pass + + +x = 1 + + +@register +def test_plugin2(): + pass + + +# re-definition of the "register" symbol +from bbb import ccc as register + + +@register +def also_not_a_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py new file mode 100644 index 000000000..b5ab0d362 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_register_alias.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from dioptra.pyplugs import register as foo + + +@foo +def test_plugin(): + pass diff --git a/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py new file mode 100644 index 000000000..0282d7703 --- /dev/null +++ b/tests/unit/restapi/v1/workflows/signature_analysis/sample_test_type_conflict.py @@ -0,0 +1,22 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. 
+# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +import dioptra.pyplugs + + +@dioptra.pyplugs.register +def plugin_func(arg1: foo(2), arg2: Type1) -> foo(2): + pass diff --git a/tests/unit/restapi/v1/workflows/test_signature_analysis.py b/tests/unit/restapi/v1/workflows/test_signature_analysis.py new file mode 100644 index 000000000..e9e43b86a --- /dev/null +++ b/tests/unit/restapi/v1/workflows/test_signature_analysis.py @@ -0,0 +1,522 @@ +# This Software (Dioptra) is being made available as a public service by the +# National Institute of Standards and Technology (NIST), an Agency of the United +# States Department of Commerce. This software was developed in part by employees of +# NIST and in part by NIST contractors. Copyright in portions of this software that +# were developed by NIST contractors has been licensed or assigned to NIST. Pursuant +# to Title 17 United States Code Section 105, works of NIST employees are not +# subject to copyright protection in the United States. However, NIST may hold +# international copyright in software created by its employees and domestic +# copyright (or licensing rights) in portions of software that were assigned or +# licensed to NIST. To the extent that NIST holds copyright in this software, it is +# being made available under the Creative Commons Attribution 4.0 International +# license (CC BY 4.0). The disclaimers of the CC BY 4.0 license apply to all parts +# of the software developed or licensed by NIST. +# +# ACCESS THE FULL CC BY 4.0 LICENSE HERE: +# https://creativecommons.org/licenses/by/4.0/legalcode +from http import HTTPStatus +from pathlib import Path +from typing import Any + +from flask_sqlalchemy import SQLAlchemy + +from dioptra.client.base import DioptraResponseProtocol +from dioptra.client.client import DioptraClient + +expected_outputs = {} + +expected_outputs["sample_test_real_world.py"] = [ + { + "name": "load_dataset", + "inputs": [ + {"name": "ep_seed", "type": "integer", "required": False}, + {"name": "training_dir", "type": "string", "required": False}, + {"name": "testing_dir", "type": "string", "required": False}, + {"name": "subsets", "type": "list_str", "required": False}, + {"name": "image_size", "type": "tuple_int_int_int", "required": False}, + {"name": "rescale", "type": "number", "required": False}, + {"name": "validation_split", "type": "optional_float", "required": False}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "label_mode", "type": "string", "required": False}, + {"name": "shuffle", "type": "boolean", "required": False}, + ], + "outputs": [{"name": "output", "type": "directoryiterator"}], + "missing_types": [ + {"name": "list_str", "description": "List[str]"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + {"name": "optional_float", "description": "Optional[float]"}, + {"name": "directoryiterator", "description": "DirectoryIterator"}, + ], + }, + { + "name": "create_model", + "inputs": [ + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "model_architecture", "type": "string", "required": False}, + {"name": "input_shape", "type": "tuple_int_int_int", "required": False}, + {"name": "loss", "type": "string", "required": False}, + {"name": "learning_rate", "type": "number", "required": False}, + {"name": "optimizer", "type": "string", "required": False}, + {"name": "metrics_list", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + 
"missing_types": [ + {"name": "directoryiterator", "description": "DirectoryIterator"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + }, + ], + }, + { + "name": "load_model", + "inputs": [ + {"name": "model_name", "type": "str_none", "required": False}, + {"name": "model_version", "type": "int_none", "required": False}, + {"name": "imagenet_preprocessing", "type": "boolean", "required": False}, + {"name": "art", "type": "boolean", "required": False}, + {"name": "image_size", "type": "any", "required": False}, + { + "name": "classifier_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "str_none", "description": "str | None"}, + {"name": "int_none", "description": "int | None"}, + { + "name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "train", + "inputs": [ + {"name": "estimator", "type": "any", "required": True}, + {"name": "x", "type": "any", "required": False}, + {"name": "y", "type": "any", "required": False}, + {"name": "callbacks_list", "type": "list_dict_str_any", "required": False}, + {"name": "fit_kwargs", "type": "optional_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + }, + { + "name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "save_artifacts_and_models", + "inputs": [ + {"name": "artifacts", "type": "list_dict_str_any", "required": False}, + {"name": "models", "type": "list_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + { + "name": "list_dict_str_any", + "description": "List[Dict[str, Any]]", + } + ], + }, + { + "name": "load_artifacts_for_job", + "inputs": [ + {"name": "job_id", "type": "string", "required": True}, + {"name": "files", "type": "list_str_path", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_str_path", "description": "List[str | Path]"} + ], + }, + { + "name": "load_artifacts", + "inputs": [ + {"name": "artifact_ids", "type": "list_int", "required": False}, + {"name": "extract_files", "type": "list_str_path", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_int", "description": "List[int]"}, + {"name": "list_str_path", "description": "List[str | Path]"}, + ], + }, + { + "name": "attack_fgm", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "classifier", "type": "any", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "eps", "type": "number", "required": False}, + {"name": "eps_step", "type": "number", "required": False}, + {"name": "minimal", "type": "boolean", "required": False}, + {"name": "norm", "type": "union_int_float_str", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + "name": "union_int_float_str", + "description": "Union[int, float, str]", + }, + ], + }, + { + "name": "attack_patch", + "inputs": [ + 
{"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_target", "type": "integer", "required": True}, + {"name": "num_patch", "type": "integer", "required": True}, + {"name": "num_patch_samples", "type": "integer", "required": True}, + {"name": "rotation_max", "type": "number", "required": True}, + {"name": "scale_min", "type": "number", "required": True}, + {"name": "scale_max", "type": "number", "required": True}, + {"name": "learning_rate", "type": "number", "required": True}, + {"name": "max_iter", "type": "integer", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + {"name": "tuple", "description": "Tuple"}, + ], + }, + { + "name": "augment_patch", + "inputs": [ + {"name": "data_flow", "type": "any", "required": True}, + {"name": "adv_data_dir", "type": "union_str_path", "required": True}, + {"name": "patch_dir", "type": "union_str_path", "required": True}, + {"name": "model", "type": "any", "required": True}, + {"name": "patch_shape", "type": "tuple", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "patch_scale", "type": "number", "required": False}, + {"name": "rotation_max", "type": "number", "required": False}, + {"name": "scale_min", "type": "number", "required": False}, + {"name": "scale_max", "type": "number", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + {"name": "tuple", "description": "Tuple"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + ], + }, + { + "name": "model_metrics", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "prediction_metrics", + "inputs": [ + {"name": "y_true", "type": "np_ndarray", "required": True}, + {"name": "y_pred", "type": "np_ndarray", "required": True}, + {"name": "metrics_list", "type": "list_dict_str_str", "required": True}, + {"name": "func_kwargs", "type": "dict_str_dict_str_any", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "np_ndarray", "description": "np.ndarray"}, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + "name": "dict_str_dict_str_any", + "description": "Dict[str, Dict[str, Any]]", + }, + ], + }, + { + "name": "augment_data", + "inputs": [ + {"name": "dataset", "type": "any", "required": True}, + {"name": "def_data_dir", "type": "union_str_path", "required": True}, + {"name": "image_size", "type": "tuple_int_int_int", "required": True}, + {"name": "distance_metrics", "type": "list_dict_str_str", "required": True}, + {"name": "batch_size", "type": "integer", "required": False}, + {"name": "def_type", "type": "string", "required": False}, + { + "name": "defense_kwargs", + "type": "optional_dict_str_any", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "union_str_path", "description": "Union[str, Path]"}, + { + "name": "tuple_int_int_int", + "description": "Tuple[int, int, int]", + }, + { + "name": "list_dict_str_str", + "description": "List[Dict[str, str]]", + }, + { + 
"name": "optional_dict_str_any", + "description": "Optional[Dict[str, Any]]", + }, + ], + }, + { + "name": "predict", + "inputs": [ + {"name": "classifier", "type": "any", "required": True}, + {"name": "dataset", "type": "any", "required": True}, + {"name": "show_actual", "type": "boolean", "required": False}, + {"name": "show_target", "type": "boolean", "required": False}, + ], + "outputs": [], + "missing_types": [], + }, + { + "name": "load_predictions", + "inputs": [ + {"name": "paths", "type": "list_str", "required": True}, + {"name": "filename", "type": "string", "required": True}, + {"name": "format", "type": "string", "required": False}, + {"name": "dataset", "type": "directoryiterator", "required": False}, + {"name": "n_classes", "type": "integer", "required": False}, + ], + "outputs": [], + "missing_types": [ + {"name": "list_str", "description": "List[str]"}, + {"name": "directoryiterator", "description": "DirectoryIterator"}, + ], + }, +] + +expected_outputs["sample_test_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_complex_type.py"] = [ + { + "name": "the_plugin", + "inputs": [ + { + "name": "arg1", + "type": "optional_str", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "union_int_bool"}], + "missing_types": [ + {"name": "optional_str", "description": "Optional[str]"}, + {"name": "union_int_bool", "description": "Union[int, bool]"}, + ], + } +] + +expected_outputs["sample_test_function_type.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type1", + "required": True, + } + ], + "outputs": [{"name": "output", "type": "type1"}], + "missing_types": [ + {"name": "type1", "description": "foo(2)"}, + ], + } +] + +expected_outputs["sample_test_none_return.py"] = [ + {"name": "my_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_optional.py"] = [ + { + "name": "do_things", + "inputs": [ + { + "name": "arg1", + "type": "optional_str", + "required": True, + }, + { + "name": "arg2", + "type": "integer", + "required": False, + }, + ], + "outputs": [], + "missing_types": [ + {"name": "optional_str", "description": "Optional[str]"}, + ], + } +] + +expected_outputs["sample_test_pyplugs_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_redefinition.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []}, + {"name": "test_plugin2", "inputs": [], "outputs": [], "missing_types": []}, +] + +expected_outputs["sample_test_register_alias.py"] = [ + {"name": "test_plugin", "inputs": [], "outputs": [], "missing_types": []} +] + +expected_outputs["sample_test_type_conflict.py"] = [ + { + "name": "plugin_func", + "inputs": [ + { + "name": "arg1", + "type": "type2", + "required": True, + }, + { + "name": "arg2", + "type": "type1", + "required": True, + }, + ], + "outputs": [{"name": "output", "type": "type2"}], + "missing_types": [ + {"name": "type2", "description": "foo(2)"}, + {"name": "type1", "description": "Type1"}, + ], + } +] + +# -- Assertions ------------------------------------------------------------------------ + + +def assert_signature_analysis_response_matches_expectations( + response: dict[str, Any], expected_contents: dict[str, Any] +) -> None: + """Assert that a job response contents is valid. + + Args: + response: The actual response from the API. 
+
+
+# -- Assertions ------------------------------------------------------------------------
+
+
+def assert_signature_analysis_response_matches_expectations(
+    response: dict[str, Any], expected_contents: dict[str, Any]
+) -> None:
+    """Assert that a signature analysis task response matches the expected contents.
+
+    Args:
+        response: The actual task response from the API.
+        expected_contents: The expected task response contents.
+
+    Raises:
+        AssertionError: If the response does not have the expected structure or if
+            it does not match the expected contents.
+    """
+    # Check expected keys
+    expected_keys = {
+        "name",
+        "missing_types",
+        "outputs",
+        "inputs",
+    }
+    assert set(response.keys()) == expected_keys
+
+    # Check basic response types
+    assert isinstance(response["name"], str)
+    assert isinstance(response["outputs"], list)
+    assert isinstance(response["missing_types"], list)
+    assert isinstance(response["inputs"], list)
+
+    # Compare lists order-insensitively by sorting both sides by name
+    def sort_by_name(lst, k="name"):
+        return sorted(lst, key=lambda x: x[k])
+
+    assert sort_by_name(response["outputs"]) == sort_by_name(
+        expected_contents["outputs"]
+    )
+    assert sort_by_name(response["inputs"]) == sort_by_name(expected_contents["inputs"])
+    assert sort_by_name(response["missing_types"]) == sort_by_name(
+        expected_contents["missing_types"]
+    )
+
+
+def assert_signature_analysis_responses_match_expectations(
+    responses: list[dict[str, Any]], expected_contents: list[dict[str, Any]]
+) -> None:
+    assert len(responses) == len(expected_contents)
+    for response in responses:
+        assert_signature_analysis_response_matches_expectations(
+            response, [a for a in expected_contents if a["name"] == response["name"]][0]
+        )
+
+
+def assert_signature_analysis_file_load_and_contents(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    filename: str,
+):
+    location = Path("tests/unit/restapi/v1/workflows/signature_analysis") / filename
+
+    with location.open("r") as f:
+        contents = f.read()
+
+    contents_analysis = dioptra_client.workflows.analyze_plugin_task_signatures(
+        python_code=contents,
+    )
+
+    assert contents_analysis.status_code == HTTPStatus.OK
+
+    assert_signature_analysis_responses_match_expectations(
+        contents_analysis.json()["tasks"],
+        expected_contents=expected_outputs[filename],
+    )
+
+
+# -- Tests -----------------------------------------------------------------------------
+
+
+def test_signature_analysis(
+    dioptra_client: DioptraClient[DioptraResponseProtocol],
+    db: SQLAlchemy,
+    auth_account: dict[str, Any],
+) -> None:
+    """Test that signature analysis produces the expected results for each sample file.
+
+    Args:
+        dioptra_client: The Dioptra client.
+        db: The Flask SQLAlchemy database session.
+        auth_account: The authenticated user account fixture.
+
+    Raises:
+        AssertionError: If a response status code is not 200 or if an API response
+            does not match the expected response.
+    """
+    for fn in expected_outputs:
+        assert_signature_analysis_file_load_and_contents(
+            dioptra_client=dioptra_client, filename=fn
+        )
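+
+
+# A minimal usage sketch (illustrative only; not executed by this test suite):
+# the endpoint can also be exercised on an in-memory source string. The plugin
+# source below is hypothetical and is not one of the sample files above.
+#
+#     source = (
+#         "from dioptra.pyplugs import register\n"
+#         "\n"
+#         "@register\n"
+#         "def double(x: int) -> int:\n"
+#         "    return 2 * x\n"
+#     )
+#     analysis = dioptra_client.workflows.analyze_plugin_task_signatures(
+#         python_code=source
+#     )
+#     assert analysis.status_code == HTTPStatus.OK
+#     # expected: analysis.json()["tasks"] holds one task named "double" with
+#     # input {"name": "x", "type": "integer", "required": True} and output
+#     # {"name": "output", "type": "integer"}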