diff --git a/src/guidellm/core/__init__.py b/src/guidellm/core/__init__.py index 79b93a3..bbb7803 100644 --- a/src/guidellm/core/__init__.py +++ b/src/guidellm/core/__init__.py @@ -1,4 +1,5 @@ from .distribution import Distribution +from .report import GuidanceReport from .request import TextGenerationRequest from .result import ( RequestConcurrencyMeasurement, @@ -7,6 +8,7 @@ TextGenerationError, TextGenerationResult, ) +from .serializable import Serializable, SerializableFileType __all__ = [ "Distribution", @@ -16,4 +18,7 @@ "TextGenerationBenchmark", "TextGenerationBenchmarkReport", "RequestConcurrencyMeasurement", + "Serializable", + "SerializableFileType", + "GuidanceReport", ] diff --git a/src/guidellm/core/report.py b/src/guidellm/core/report.py new file mode 100644 index 0000000..afd54d7 --- /dev/null +++ b/src/guidellm/core/report.py @@ -0,0 +1,21 @@ +from typing import List + +from pydantic import Field + +from guidellm.core.serializable import Serializable +from guidellm.core.result import TextGenerationBenchmarkReport + +__all__ = [ + "GuidanceReport", +] + + +class GuidanceReport(Serializable): + """ + A class to manage the guidance reports that include the benchmarking details, + potentially across multiple runs, for saving and loading from disk. + """ + + benchmarks: List[TextGenerationBenchmarkReport] = Field( + default_factory=list, description="The list of benchmark reports." 
def save_file(self, path: str, type_: Optional[SerializableFileType] = None) -> str:
    """
    Save the model to a file in either YAML or JSON format.

    :param path: Path to the exact file or the containing directory.
        If it is a directory, the file name will be inferred from the class name.
    :param type_: Optional type to save ('yaml' or 'json').
        If not provided and the path has an extension,
        it will be inferred to save in that format.
        If not provided and the path does not have an extension,
        it will save in YAML format.
    :return: The path to the saved file.
    :raises ValueError: If the format cannot be inferred from the file
        extension, or if the requested format is not supported.
    """
    logger.debug("Saving to file... {} with format: {}", path, type_)

    if not is_file_name(path):
        # Directory-style path: build "<classname>.<format>" inside it,
        # defaulting to YAML when no explicit format was requested.
        if type_ is None:
            type_ = SerializableFileType.YAML
        file_name = f"{self.__class__.__name__.lower()}.{type_.value.lower()}"
        path = os.path.join(path, file_name)
    elif type_ is None:
        # File-style path without an explicit format: infer it from the
        # extension. splitext only inspects the final path component, so dots
        # in directory names cannot be mistaken for an extension.
        extension = os.path.splitext(path)[1].lstrip(".").upper()

        if extension not in SerializableFileType.__members__:
            raise ValueError(
                f"Unsupported file extension: {extension}. "
                f"Expected one of {', '.join(SerializableFileType.__members__)} "
                f"for {path}"
            )

        type_ = SerializableFileType[extension]

    # Serialize before touching the file system so that a failure here does
    # not leave behind a freshly created directory or a truncated file.
    if type_ == SerializableFileType.YAML:
        content = self.to_yaml()
    elif type_ == SerializableFileType.JSON:
        content = self.to_json()
    else:
        raise ValueError(
            f"Unsupported file format: {type_} "
            f"(expected 'yaml' or 'json') for {path}"
        )

    # A bare file name such as "report.json" has an empty dirname, and
    # os.makedirs("") raises FileNotFoundError, so only create real parents.
    directory = os.path.dirname(path)
    if directory:
        os.makedirs(directory, exist_ok=True)

    # Explicit encoding keeps the output independent of the platform locale.
    with open(path, "w", encoding="utf-8") as file:
        file.write(content)

    logger.info("Successfully saved {} to {}", self.__class__.__name__, path)

    return path


@classmethod
def load_file(cls, path: str):
    """
    Load a model from a file in either YAML or JSON format.

    The format is chosen from the file extension.

    :param path: Path to the file.
    :return: An instance of the model.
    :raises FileNotFoundError: If the path does not exist.
    :raises ValueError: If the path is not a regular file or its extension
        is not a supported format.
    """
    logger.debug("Loading from file... {}", path)

    if not os.path.exists(path):
        raise FileNotFoundError(f"File not found: {path}")

    if not os.path.isfile(path):
        raise ValueError(f"Path is not a file: {path}")

    # splitext only inspects the final path component, so dots in directory
    # names cannot be mistaken for an extension.
    extension = os.path.splitext(path)[1].lstrip(".").upper()

    if extension not in SerializableFileType.__members__:
        raise ValueError(
            f"Unsupported file extension: {extension}. "
            f"Expected one of {', '.join(SerializableFileType.__members__)} "
            f"for {path}"
        )

    type_ = SerializableFileType[extension]

    # Read with the same explicit encoding that save_file writes with.
    with open(path, "r", encoding="utf-8") as file:
        data = file.read()

    if type_ == SerializableFileType.YAML:
        return cls.from_yaml(data)

    if type_ == SerializableFileType.JSON:
        return cls.from_json(data)

    raise ValueError(f"Unsupported file format: {type_}")
def is_file_name(path: str) -> bool:
    """
    Tell whether a path is written like a file name.

    A path counts as a file name when its final component carries an
    extension and the path does not end with the OS path separator.
    Only the text of the path is inspected.

    :param path: The path to check.
    :type path: str
    :return: True if the path follows the file naming convention.
    """
    suffix = os.path.splitext(path)[1]

    if not suffix:
        return False

    return not path.endswith(os.path.sep)


def is_directory_name(path: str) -> bool:
    """
    Tell whether a path is written like a directory name.

    A path counts as a directory name when it has no extension or when it
    ends with the OS path separator. Only the text of the path is inspected.

    :param path: The path to check.
    :type path: str
    :return: True if the path follows the directory naming convention.
    """
    suffix = os.path.splitext(path)[1]

    if not suffix:
        return True

    return path.endswith(os.path.sep)
@pytest.mark.smoke
def test_serializable_file_json():
    """Round-trip through an explicit .json path with an explicit JSON type."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "example.json")
        saved_path = example.save_file(file_path, SerializableFileType.JSON)
        assert os.path.exists(saved_path)
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_file_yaml():
    """Round-trip through an explicit .yaml path with an explicit YAML type."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "example.yaml")
        saved_path = example.save_file(file_path, SerializableFileType.YAML)
        assert os.path.exists(saved_path)
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_file_without_extension():
    """Saving to a directory with no type falls back to a YAML file."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        saved_path = example.save_file(temp_dir)
        assert os.path.exists(saved_path)
        assert saved_path.endswith(".yaml")
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_file_with_directory_json():
    """Saving to a directory with an explicit JSON type yields a .json file."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        saved_path = example.save_file(temp_dir, SerializableFileType.JSON)
        assert os.path.exists(saved_path)
        assert saved_path.endswith(".json")
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_file_with_directory_yaml():
    """Saving to a directory with an explicit YAML type yields a .yaml file."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        saved_path = example.save_file(temp_dir, SerializableFileType.YAML)
        assert os.path.exists(saved_path)
        assert saved_path.endswith(".yaml")
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_save_file_invalid_extension():
    """Saving without a type to an unsupported extension is rejected."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        invalid_file_path = os.path.join(temp_dir, "example.txt")
        with pytest.raises(ValueError, match="Unsupported file extension.*"):
            example.save_file(invalid_file_path)


@pytest.mark.smoke
def test_serializable_load_file_invalid_extension():
    """Loading an existing file with an unsupported extension is rejected."""
    with tempfile.TemporaryDirectory() as temp_dir:
        invalid_file_path = os.path.join(temp_dir, "example.txt")
        with open(invalid_file_path, "w") as file:
            file.write("invalid content")
        with pytest.raises(ValueError, match="Unsupported file extension: TXT"):
            ExampleModel.load_file(invalid_file_path)


@pytest.mark.smoke
def test_serializable_file_no_type_provided():
    """An extension-less path with no type is saved as YAML."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        # "example" has no extension, so save_file treats it as a directory
        # and returns the path of the file it creates inside it.
        file_path = os.path.join(temp_dir, "example")
        saved_path = example.save_file(file_path)
        assert os.path.exists(saved_path)
        assert saved_path.endswith(".yaml")
        loaded_example = ExampleModel.load_file(saved_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30


@pytest.mark.smoke
def test_serializable_file_infer_extension():
    """The format is inferred from the file extension when no type is given."""
    example = ExampleModel(name="John Doe", age=30)
    with tempfile.TemporaryDirectory() as temp_dir:
        # No explicit type here: the ".json" suffix alone must select JSON.
        file_path = os.path.join(temp_dir, "example.json")
        inferred_path = example.save_file(file_path)
        assert inferred_path == file_path
        assert os.path.exists(inferred_path)
        # Compare against the JSON serialization to prove the default YAML
        # format was not used.
        with open(inferred_path, "r") as file:
            assert file.read() == example.to_json()
        loaded_example = ExampleModel.load_file(inferred_path)
        assert loaded_example.name == "John Doe"
        assert loaded_example.age == 30