diff --git a/examples/prefect/runner.py b/examples/prefect/runner.py index df351487..baf2808c 100644 --- a/examples/prefect/runner.py +++ b/examples/prefect/runner.py @@ -1,3 +1,5 @@ +from typing import Any + import numpy as np import training from prefect import flow, get_run_logger, task @@ -11,7 +13,7 @@ class PrefectReporter(reporter.BenchmarkReporter): def __init__(self): self.logger = get_run_logger() - def write(self, record: types.BenchmarkRecord) -> None: + def write(self, record: types.BenchmarkRecord, **kwargs: Any) -> None: self.logger.info(record) diff --git a/src/nnbench/reporter/__init__.py b/src/nnbench/reporter/__init__.py index f4322fe8..f9568f5a 100644 --- a/src/nnbench/reporter/__init__.py +++ b/src/nnbench/reporter/__init__.py @@ -7,9 +7,12 @@ import types from .base import BenchmarkReporter +from .file import FileReporter # internal, mutable -_reporter_registry: dict[str, type[BenchmarkReporter]] = {} +_reporter_registry: dict[str, type[BenchmarkReporter]] = { + "file": FileReporter, +} # external, immutable reporter_registry: types.MappingProxyType[str, type[BenchmarkReporter]] = types.MappingProxyType( diff --git a/src/nnbench/reporter/base.py b/src/nnbench/reporter/base.py index 6bb97fc7..dc328392 100644 --- a/src/nnbench/reporter/base.py +++ b/src/nnbench/reporter/base.py @@ -1,7 +1,7 @@ from __future__ import annotations import re -from typing import Any, Callable, Sequence +from typing import Any, Callable, List, Sequence from tabulate import tabulate @@ -88,13 +88,13 @@ def display( print(tabulate(filtered, headers="keys", tablefmt=self.tablefmt)) - def read(self) -> BenchmarkRecord: + def read(self, **kwargs: Any) -> BenchmarkRecord | List[BenchmarkRecord]: raise NotImplementedError def read_batched(self) -> list[BenchmarkRecord]: raise NotImplementedError - def write(self, record: BenchmarkRecord) -> None: + def write(self, record: BenchmarkRecord, **kwargs: Any) -> None: raise NotImplementedError def write_batched(self, records: Sequence[BenchmarkRecord]) -> None: diff --git a/src/nnbench/reporter/file.py b/src/nnbench/reporter/file.py new file mode 100644 index 00000000..152d1f52 --- /dev/null +++ b/src/nnbench/reporter/file.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +import json +import os +from typing import IO, Any, Callable, List + +from nnbench.reporter.base import BenchmarkReporter +from nnbench.types import BenchmarkRecord + +ser = Callable[[IO, List[BenchmarkRecord], Any], None] +de = Callable[[IO, dict[str, Any]], List[BenchmarkRecord]] + +# A registry of supported file loaders +_file_loaders: dict[str, tuple[ser, de]] = {} + + +# Register file loaders +def register_file_io(serializer: Callable, deserializer: Callable, file_type: str) -> None: + """ + Registers a serializer and deserializer for a file type. + + Args: + ----- + `serializer (Callable):` Defines how records are written to a file. + `deserializer (Callable):` Defines how file contents are converted to `BenchmarkRecord`. + `file_type (str):` File type extension (e.g., ".json", ".yaml"). + """ + _file_loaders[file_type] = (serializer, deserializer) + + +def _get_file_loader(file_type: str) -> tuple[ser, de]: + """Helps retrieve registered file loaders of the given file_type with error handling""" + file_loaders = _file_loaders.get(file_type) + if not file_loaders: + raise ValueError(f"File loaders for `{file_type}` files does not exist") + return file_loaders + + +# json file loader: +def json_load(fp: IO, options: Any = None) -> List[BenchmarkRecord] | None: + file_content = fp.read() + if file_content: + objs = [ + BenchmarkRecord(context=obj["context"], benchmarks=obj["benchmarks"]) + for obj in json.loads(file_content) + ] + return objs + return None + + +def json_save(fp: IO, records: List[BenchmarkRecord], options: Any = None) -> None: + fp.write(json.dumps(records)) + + +# yaml file loader: +def yaml_load(fp: IO, options: Any = None) -> List[BenchmarkRecord] | None: + try: + import yaml + except ImportError: + raise ModuleNotFoundError("`pyyaml` is not installed") + + file_content = fp.read() + if file_content: + objs = [ + BenchmarkRecord(context=obj["context"], benchmarks=obj["benchmarks"]) + for obj in yaml.safe_load(file_content) + ] + return objs + return None + + +def yaml_save(fp: IO, records: List[BenchmarkRecord], options: dict[str, Any] = None) -> None: + try: + import yaml + except ImportError: + raise ModuleNotFoundError("`pyyaml` is not installed") + + # To avoid `yaml.safe_dump()` error when trying to write numpy array + for element in records[-1]["benchmarks"]: + element["value"] = float(element["value"]) + yaml.safe_dump(records, fp, **(options or {})) + + +# Register json and yaml file loaders +register_file_io(json_save, json_load, file_type="json") +register_file_io(yaml_save, yaml_load, file_type="yaml") + + +class FileReporter(BenchmarkReporter): + """ + Reports benchmark results to files in a given directory. + + This class implements a `BenchmarkReporter` subclass that persists benchmark + records to files within a specified directory. It supports both reading and + writing records, using file extensions to automatically determine the appropriate + serialization format. + + Args: + ----- + directory (str): The directory where benchmark files will be stored. + + Raises: + ------- + BaseException: If the directory is not initialized. + """ + + def __init__(self, directory: str): + self.directory = directory + if not os.path.exists(directory): + self.initialize() + + def initialize(self) -> None: + os.makedirs(self.directory, exist_ok=True) + + def read(self, **kwargs: Any) -> List[BenchmarkRecord]: + if not self.directory: + raise BaseException("No directory is initialized") + file_name = str(kwargs["file_name"]) + file_path = os.path.join(self.directory, file_name) + file_type = file_name.split(".")[1] + with open(file_path) as file: + return _get_file_loader(file_type)[1](file, {}) + + def write(self, record: BenchmarkRecord, **kwargs: dict[str, Any]) -> None: + if not self.directory: + raise BaseException("No directory is initialized") + file_name = str(kwargs["file_name"]) + file_path = os.path.join(self.directory, file_name) + # Create the file, if not already existing + if not os.path.exists(file_path): + with open(file_path, "w") as file: + file.write("") + prev_records = self.read(file_name=file_name) + prev_records = prev_records if prev_records else [] + prev_records.append(record) # + file_type = file_name.split(".")[1] + with open(file_path, "w") as file: + _get_file_loader(file_type)[0](file, prev_records, {})