Add Serializable Functionality and GuidanceReport Class for Enhanced Results Management (#22)

## Summary

Introduces the `Serializable` class to provide serialization capabilities for core Pydantic classes, along with the new `GuidanceReport` class to manage multiple benchmarking reports. These additions provide native support for loading from and saving to disk in JSON and YAML formats.

## Details

- **Serializable Class**: Adds the ability to serialize and deserialize objects to/from YAML and JSON formats, and to save/load them from files (see the usage sketch after this list).
    - Implements methods for `to_yaml`, `to_json`, `from_yaml`, `from_json`, `save_file`, and `load_file`.
    - Introduces the `SerializableFileType` enum to handle supported file types.
    - Includes validation and error handling for file operations.
- **GuidanceReport Class**: Manages guidance reports containing benchmarking details across multiple runs.
    - Inherits from `Serializable` to leverage serialization capabilities.
    - Contains a list of `TextGenerationBenchmarkReport` objects.
- **CLI Integration**: Updates the CLI to use `GuidanceReport` for saving benchmark reports.
    - Adds the `--output-path` option for specifying where to save the report.
- **Tests**: Adds comprehensive unit tests for the new functionality.
    - Tests for `Serializable` class methods.
    - Tests for the `GuidanceReport` class, including initialization, file operations, and serialization.
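
As a usage sketch (the `RunMetadata` model and the paths below are hypothetical, used only for illustration; the methods are the ones introduced in this PR), any Pydantic model gains disk round-tripping by subclassing `Serializable`:

```python
from typing import List

from guidellm.core import Serializable


# Hypothetical model for illustration; any Pydantic fields work the same way.
class RunMetadata(Serializable):
    tags: List[str] = []


meta = RunMetadata(tags=["baseline"])

yaml_str = meta.to_yaml()                    # serialize to a YAML string
restored = RunMetadata.from_yaml(yaml_str)   # rebuild the model from that string

# save_file infers the file name and YAML format from a directory-style path,
# e.g. "outputs/runmetadata.yaml"; load_file infers format from the extension.
saved_path = meta.save_file("outputs/")
reloaded = RunMetadata.load_file(saved_path)
```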

## Test Plan

- **Automation Testing**:
    - Added unit tests for the `Serializable` class covering YAML and JSON serialization/deserialization, file saving/loading, and error handling.
    - Added unit tests for the `GuidanceReport` class covering initialization and file operations.
markurtz authored Jul 29, 2024
1 parent 75dac35 commit 331273c
Showing 8 changed files with 384 additions and 16 deletions.
5 changes: 5 additions & 0 deletions src/guidellm/core/__init__.py
@@ -1,4 +1,5 @@
from .distribution import Distribution
from .report import GuidanceReport
from .request import TextGenerationRequest
from .result import (
    RequestConcurrencyMeasurement,
@@ -7,6 +8,7 @@
    TextGenerationError,
    TextGenerationResult,
)
from .serializable import Serializable, SerializableFileType

__all__ = [
    "Distribution",
@@ -16,4 +18,7 @@
    "TextGenerationBenchmark",
    "TextGenerationBenchmarkReport",
    "RequestConcurrencyMeasurement",
    "Serializable",
    "SerializableFileType",
    "GuidanceReport",
]
21 changes: 21 additions & 0 deletions src/guidellm/core/report.py
@@ -0,0 +1,21 @@
from typing import List

from pydantic import Field

from guidellm.core.serializable import Serializable
from guidellm.core.result import TextGenerationBenchmarkReport

__all__ = [
    "GuidanceReport",
]


class GuidanceReport(Serializable):
    """
    A class to manage the guidance reports that include the benchmarking details,
    potentially across multiple runs, for saving and loading from disk.
    """

    benchmarks: List[TextGenerationBenchmarkReport] = Field(
        default_factory=list, description="The list of benchmark reports."
    )
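
As a brief illustration of the intended use (the run variables below are hypothetical `TextGenerationBenchmarkReport` instances from separate benchmark executions):

```python
from guidellm.core import GuidanceReport

# run_one and run_two are hypothetical TextGenerationBenchmarkReport objects.
combined = GuidanceReport(benchmarks=[run_one, run_two])

# Round-trips through the inherited Serializable file helpers.
path = combined.save_file("reports/combined.yaml")
assert GuidanceReport.load_file(path) == combined
```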
112 changes: 111 additions & 1 deletion src/guidellm/core/serializable.py
@@ -1,8 +1,24 @@
-from typing import Any
+from typing import Any, Optional
 
+import os
 import yaml
 from loguru import logger
 from pydantic import BaseModel, ConfigDict
+from enum import Enum
+
+from guidellm.utils import is_file_name
+
+
+__all__ = ["Serializable", "SerializableFileType"]
+
+
+class SerializableFileType(Enum):
+    """
+    Enum class for file types supported by Serializable.
+    """
+
+    YAML = "yaml"
+    JSON = "json"
 
 
 class Serializable(BaseModel):
@@ -73,3 +89,97 @@ def from_json(cls, data: str):
        obj = cls.model_validate_json(data)

        return obj

    def save_file(self, path: str, type_: Optional[SerializableFileType] = None) -> str:
        """
        Save the model to a file in either YAML or JSON format.
        :param path: Path to the exact file or the containing directory.
            If it is a directory, the file name will be inferred from the class name.
        :param type_: Optional type to save ('yaml' or 'json').
            If not provided and the path has an extension,
            it will be inferred to save in that format.
            If not provided and the path does not have an extension,
            it will save in YAML format.
        :return: The path to the saved file.
        """
        logger.debug("Saving to file... {} with format: {}", path, type_)

        if not is_file_name(path):
            file_name = f"{self.__class__.__name__.lower()}"
            if type_:
                file_name += f".{type_.value.lower()}"
            else:
                file_name += ".yaml"
                type_ = SerializableFileType.YAML
            path = os.path.join(path, file_name)

        if not type_:
            extension = path.split(".")[-1].upper()

            if extension not in SerializableFileType.__members__:
                raise ValueError(
                    f"Unsupported file extension: {extension}. "
                    f"Expected one of {', '.join(SerializableFileType.__members__)} "
                    f"for {path}"
                )

            type_ = SerializableFileType[extension]

        if type_.name not in SerializableFileType.__members__:
            raise ValueError(
                f"Unsupported file format: {type_} "
                f"(expected 'yaml' or 'json') for {path}"
            )

        # Guard against makedirs("") when path is a bare file name in the CWD.
        if os.path.dirname(path):
            os.makedirs(os.path.dirname(path), exist_ok=True)

        with open(path, "w") as file:
            if type_ == SerializableFileType.YAML:
                file.write(self.to_yaml())
            elif type_ == SerializableFileType.JSON:
                file.write(self.to_json())
            else:
                raise ValueError(f"Unsupported file format: {type_}")

        logger.info("Successfully saved {} to {}", self.__class__.__name__, path)

        return path

    @classmethod
    def load_file(cls, path: str):
        """
        Load a model from a file in either YAML or JSON format.
        :param path: Path to the file.
        :return: An instance of the model.
        """
        logger.debug("Loading from file... {}", path)

        if not os.path.exists(path):
            raise FileNotFoundError(f"File not found: {path}")
        elif not os.path.isfile(path):
            raise ValueError(f"Path is not a file: {path}")

        extension = path.split(".")[-1].upper()

        if extension not in SerializableFileType.__members__:
            raise ValueError(
                f"Unsupported file extension: {extension}. "
                f"Expected one of {', '.join(SerializableFileType.__members__)} "
                f"for {path}"
            )

        type_ = SerializableFileType[extension]

        with open(path, "r") as file:
            data = file.read()

        if type_ == SerializableFileType.YAML:
            obj = cls.from_yaml(data)
        elif type_ == SerializableFileType.JSON:
            obj = cls.from_json(data)
        else:
            raise ValueError(f"Unsupported file format: {type_}")

        return obj
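
To make the path handling above concrete, a small sketch (the paths are hypothetical; the behavior follows the logic in this diff):

```python
from guidellm.core import GuidanceReport, SerializableFileType

report = GuidanceReport()

# Directory-style path: the file name and YAML format are inferred,
# producing "results/guidancereport.yaml".
report.save_file("results/")

# Explicit file name: the format is inferred from the ".json" extension.
report.save_file("results/report.json")

# Directory-style path with an explicit type: "results/guidancereport.json".
report.save_file("results/", type_=SerializableFileType.JSON)
```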
25 changes: 13 additions & 12 deletions src/guidellm/main.py
@@ -1,7 +1,7 @@
 import click
 
 from guidellm.backend import Backend
-from guidellm.core import TextGenerationBenchmarkReport
+from guidellm.core import GuidanceReport
 from guidellm.executor import (
     Executor,
     rate_type_to_load_gen_mode,
@@ -65,6 +65,12 @@
     default=None,
     help="Number of requests to send for each rate",
 )
+@click.option(
+    "--output-path",
+    type=str,
+    default="benchmark_report.json",
+    help="Path to save benchmark report to",
+)
 def main(
     target,
     host,
@@ -80,6 +86,7 @@ def main(
     rate,
     num_seconds,
     num_requests,
+    output_path,
 ):
     # Create backend
     Backend.create(
@@ -127,18 +134,12 @@ def main(
     report = executor.run()
 
     # Save or print results
-    save_report(report, "benchmark_report.json")
-    print_report(report)
-
-
-def save_report(report: TextGenerationBenchmarkReport, filename: str):
-    with open(filename, "w") as file:
-        file.write(report.to_json())
-
-
-def print_report(report: TextGenerationBenchmarkReport):
-    for benchmark in report.benchmarks:
-        print(f"Rate: {benchmark.completed_request_rate}, Results: {benchmark.results}")
+    guidance_report = GuidanceReport()
+    guidance_report.benchmarks.append(report)
+    guidance_report.save_file(output_path)
+
+    print("Guidance Report Complete:")
+    print(guidance_report)
 
 
 if __name__ == "__main__":
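
After a CLI run, the saved report can be reloaded programmatically. A minimal sketch, assuming the default `--output-path` of `benchmark_report.json`:

```python
from guidellm.core import GuidanceReport

# Load the report the CLI wrote; the JSON format is inferred from the extension.
report = GuidanceReport.load_file("benchmark_report.json")

for benchmark_report in report.benchmarks:
    for benchmark in benchmark_report.benchmarks:
        print(benchmark.mode, benchmark.rate)
```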
9 changes: 8 additions & 1 deletion src/guidellm/utils/__init__.py
@@ -3,5 +3,12 @@
     PREFERRED_DATA_SPLITS,
     STANDARD_SLEEP_INTERVAL,
 )
+from .functions import is_file_name, is_directory_name
 
-__all__ = ["PREFERRED_DATA_COLUMNS", "PREFERRED_DATA_SPLITS", "STANDARD_SLEEP_INTERVAL"]
+__all__ = [
+    "PREFERRED_DATA_COLUMNS",
+    "PREFERRED_DATA_SPLITS",
+    "STANDARD_SLEEP_INTERVAL",
+    "is_file_name",
+    "is_directory_name",
+]
33 changes: 33 additions & 0 deletions src/guidellm/utils/functions.py
@@ -0,0 +1,33 @@
import os


__all__ = [
    "is_file_name",
    "is_directory_name",
]


def is_file_name(path: str) -> bool:
    """
    Check if the path has an extension and is not a directory.
    :param path: The path to check.
    :type path: str
    :return: True if the path is a file naming convention.
    """

    _, ext = os.path.splitext(path)

    return bool(ext) and not path.endswith(os.path.sep)


def is_directory_name(path: str) -> bool:
    """
    Check if the path does not have an extension and is a directory.
    :param path: The path to check.
    :type path: str
    :return: True if the path is a directory naming convention.
    """
    _, ext = os.path.splitext(path)
    return not ext or path.endswith(os.path.sep)
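
These helpers classify paths purely by naming convention, without touching the filesystem. A few illustrative cases (the paths are hypothetical):

```python
from guidellm.utils import is_file_name, is_directory_name

assert is_file_name("results/report.yaml")   # has an extension -> file name
assert not is_file_name("results/")          # trailing separator, no extension
assert is_directory_name("results")          # no extension -> directory name
assert not is_directory_name("report.json")  # extension -> file, not directory
```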
86 changes: 86 additions & 0 deletions tests/unit/core/test_report.py
@@ -0,0 +1,86 @@
import pytest
import os
import tempfile
from guidellm.core import (
    TextGenerationBenchmark,
    TextGenerationBenchmarkReport,
    TextGenerationResult,
    TextGenerationRequest,
    TextGenerationError,
    Distribution,
    GuidanceReport,
)


@pytest.fixture
def sample_benchmark_report() -> TextGenerationBenchmarkReport:
    sample_request = TextGenerationRequest(prompt="sample prompt")
    sample_distribution = Distribution()
    sample_result = TextGenerationResult(
        request=sample_request,
        prompt="sample prompt",
        prompt_word_count=2,
        prompt_token_count=2,
        output="sample output",
        output_word_count=2,
        output_token_count=2,
        last_time=None,
        first_token_set=False,
        start_time=None,
        end_time=None,
        first_token_time=None,
        decode_times=sample_distribution,
    )
    sample_error = TextGenerationError(request=sample_request, message="sample error")
    sample_benchmark = TextGenerationBenchmark(
        mode="async",
        rate=1.0,
        results=[sample_result],
        errors=[sample_error],
        concurrencies=[],
    )
    return TextGenerationBenchmarkReport(
        benchmarks=[sample_benchmark], args=[{"arg1": "value1"}]
    )


def compare_guidance_reports(report1: GuidanceReport, report2: GuidanceReport) -> bool:
    return report1 == report2


@pytest.mark.smoke
def test_guidance_report_initialization():
    report = GuidanceReport()
    assert report.benchmarks == []


@pytest.mark.smoke
def test_guidance_report_initialization_with_params(sample_benchmark_report):
    report = GuidanceReport(benchmarks=[sample_benchmark_report])
    assert report.benchmarks == [sample_benchmark_report]


@pytest.mark.smoke
def test_guidance_report_file(sample_benchmark_report):
    report = GuidanceReport(benchmarks=[sample_benchmark_report])
    with tempfile.TemporaryDirectory() as temp_dir:
        file_path = os.path.join(temp_dir, "report.yaml")
        report.save_file(file_path)
        loaded_report = GuidanceReport.load_file(file_path)
        assert compare_guidance_reports(report, loaded_report)


@pytest.mark.regression
def test_guidance_report_json(sample_benchmark_report):
    report = GuidanceReport(benchmarks=[sample_benchmark_report])
    json_str = report.to_json()
    loaded_report = GuidanceReport.from_json(json_str)
    assert compare_guidance_reports(report, loaded_report)


@pytest.mark.regression
def test_guidance_report_yaml(sample_benchmark_report):
    report = GuidanceReport(benchmarks=[sample_benchmark_report])
    yaml_str = report.to_yaml()
    loaded_report = GuidanceReport.from_yaml(yaml_str)
    assert compare_guidance_reports(report, loaded_report)