Skip to content

Commit

Permalink
Introduce data loading utility for reading from local cache or downlo…
Browse files Browse the repository at this point in the history
…ading from external URL (#3282)

Summary:
Pull Request resolved: #3282

## Context

Our preprocessed and compressed derivatives of open-source benchmarking datasets (e.g., LCBench) are currently hosted in Manifold blob storage, which limits their accessibility in our open-source software (OSS). To address this, we need to remove the dependency on Manifold.

## Changes
This diff introduces a data download utility that loads Pandas DataFrames (stored in a compressed Parquet format) from local disk, downloading them from an external URL if they are not found locally. The key changes include:
- Introduced AbstractParquetDataLoader class, providing a way to load parquet data from a cache on local disk or download from an external URL.
- Implemented methods for:
  * Getting the cache path
  * Checking if the data is cached
  * Reading the data from the cache
  * Downloading from an external URL and caching the data
- Added abstract properties for getting the directory name and URL of the cached file, allowing easy specialization for other benchmark datasets.

With these changes, we can now make our LCBench surrogate benchmark problems accessible in OSS and move them from `ax.fb` to `ax`.

## WIP/TODO

1. Add new unit tests
2. Address OSS coverage requirements

Reviewed By: esantorella

Differential Revision: D68790695
  • Loading branch information
ltiao authored and facebook-github-bot committed Feb 3, 2025
1 parent 257af9c commit 981c50b
Show file tree
Hide file tree
Showing 33 changed files with 1,110 additions and 1 deletion.
143 changes: 143 additions & 0 deletions ax/benchmark/problems/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
from abc import ABC, abstractmethod
from pathlib import Path

import pandas as pd


class AbstractParquetDataLoader(ABC):
def __init__(
self,
benchmark_name: str,
dataset_name: str,
stem: str,
cache_dir: Path | None = None,
) -> None:
"""
Initialize the ParquetDataLoader.
This class provides a way to load Parquet data from an external URL,
caching it locally to avoid repeated downloads.
It downloads the file from the external URL and saves it to the cache
if it's not already cached, and reads from the cache otherwise.
Args:
dataset_name (str): The name of the dataset to load.
stem (str): The stem of the parquet file.
cache_dir (Path): The directory where cached data will be stored.
Defaults to '~/.cache/ax_benchmark_data'.
"""
self.cache_dir: Path = (
cache_dir
if cache_dir is not None
else Path("~/.cache").expanduser().joinpath("ax_benchmark_data")
)
self.benchmark_name = benchmark_name
self.dataset_name = dataset_name
self.stem = stem

@property
def filename(self) -> str:
"""
Get the filename of the cached file.
This method returns the filename of the cached file, which is the stem
followed by the extension '.parquet.gzip'.
Returns:
str: The filename of the cached file.
"""
return f"{self.stem}.parquet.gzip"

@property
def cache_path(self) -> Path:
"""
Get the path to the cached file.
This method returns the path where the cached file should be stored.
Returns:
Path: The path to the cached file.
"""
return self.cache_dir.joinpath(
self.benchmark_name,
self.dataset_name,
self.filename,
)

def is_cached(self) -> bool:
"""
Check if the data is already cached (whether the file simply exists).
Returns:
bool: True if the data is cached, False otherwise.
"""
return self.cache_path.exists()

def load(self, download: bool = True) -> pd.DataFrame:
"""
Read the parquet data from the cache or download it from the URL.
If the data is cached, this method reads the data from the cache.
If the data is not cached and download is True, this method downloads
the data from the URL, caches it, and then returns the data.
If the data is not cached and download is False, this method raises an OSError.
Args:
download (bool): Whether to download the data if it's not available
locally. If False, this method raises an OSError. Defaults to True.
Returns:
pd.DataFrame: The loaded parquet data.
"""
if self.is_cached():
with self.cache_path.open("rb") as infile:
return pd.read_parquet(infile, engine="pyarrow")
if download:
if self.url is None:
raise ValueError(
f"File {self.cache_path} does not exist, "
"`download` is True, but URL is not specified."
)
return self._fetch_and_cache()
raise ValueError(
f"File {self.cache_path} does not exist and `download` is False"
)

def _fetch_and_cache(self) -> pd.DataFrame:
"""
Download the data from the URL and cache it.
This method downloads the data from the URL, creates the cache directory
if needed, and saves the data to the cache.
Returns:
pd.DataFrame: The downloaded parquet data.
"""
# Download the data from the URL
data = pd.read_parquet(self.url, engine="pyarrow")
# Create the cache directory if needed
self.cache_path.parent.mkdir(parents=True, exist_ok=True)
with self.cache_path.open("wb") as outfile:
data.to_parquet(outfile, engine="pyarrow", compression="gzip")
return data

@property
@abstractmethod
def url(self) -> str | None:
"""
Get the URL of the parquet file.
This method should return the URL of the parquet file to download.
None is allowed to support cases where the user manually populates the
download cache beforehand.
Returns:
str | None: The URL of the parquet file or None.
"""
pass
6 changes: 6 additions & 0 deletions ax/benchmark/problems/surrogate/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
6 changes: 6 additions & 0 deletions ax/benchmark/problems/surrogate/lcbench/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict
206 changes: 206 additions & 0 deletions ax/benchmark/problems/surrogate/lcbench/data.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,206 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

# pyre-strict

from collections.abc import Collection
from dataclasses import dataclass, field, InitVar
from pathlib import Path

import pandas as pd

import torch
from ax.benchmark.problems.data import AbstractParquetDataLoader
from ax.benchmark.problems.surrogate.lcbench.utils import (
DEFAULT_METRIC_NAME,
get_lcbench_log_scale_parameter_names,
get_lcbench_parameter_names,
)

# Dataset names accepted by `load_lcbench_data`, which raises a ValueError for
# any name not listed here. The name is passed to `LCBenchDataLoader` as
# `dataset_name`, so it also appears as a component of the cache path and of
# the download URL.
DATASET_NAMES = [
    "APSFailure",
    "Amazon_employee_access",
    "Australian",
    "Fashion-MNIST",
    "KDDCup09_appetency",
    "MiniBooNE",
    "adult",
    "airlines",
    "albert",
    "bank-marketing",
    "blood-transfusion-service-center",
    "car",
    "christine",
    "cnae-9",
    "connect-4",
    "covertype",
    "credit-g",
    "dionis",
    "fabert",
    "helena",
    "higgs",
    "jannis",
    "jasmine",
    "jungle_chess_2pcs_raw_endgame_complete",
    "kc1",
    "kr-vs-kp",
    "mfeat-factors",
    "nomao",
    "numerai28.6",
    "phoneme",
    "segment",
    "shuttle",
    "sylvine",
    "vehicle",
    "volkert",
]


class LCBenchDataLoader(AbstractParquetDataLoader):
    """Parquet loader bound to the "LCBenchLite" benchmark."""

    def __init__(
        self,
        dataset_name: str,
        stem: str,
        cache_dir: Path | None = None,
    ) -> None:
        """
        Args:
            dataset_name: Name of the LCBench dataset.
            stem: Stem of the parquet file to load.
            cache_dir: Cache root directory; None defers to the base class
                default.
        """
        super().__init__(
            benchmark_name="LCBenchLite",
            dataset_name=dataset_name,
            stem=stem,
            cache_dir=cache_dir,
        )

    @property
    def url(self) -> str:
        """
        URL of the GZIP-compressed parquet file for this dataset and stem.

        The files were produced by splitting the LCBench JSON dump per dataset,
        then into config info, learning curve metrics, and final results, and
        storing each part as GZIP-compressed Parquet. They are served from the
        `main` branch of a GitHub repository named after the benchmark.
        """
        repo_root = "https://raw.githubusercontent.com/ltiao/"
        relative_path = f"{self.benchmark_name}/main/{self.dataset_name}/{self.filename}"
        return repo_root + relative_path


@dataclass(kw_only=True)
class LCBenchData:
parameter_df: pd.DataFrame
metric_series: pd.Series
timestamp_series: pd.Series

runtime_series: pd.Series = field(init=False)
# pyre-ignore [16]: Pyre doesn't understand InitVars.
runtime_fillna: InitVar[bool] = False
# pyre-ignore [16]: Pyre doesn't understand InitVars.
log_scale_parameter_names: InitVar[Collection[str] | None] = None
dtype: torch.dtype = torch.double
device: torch.device | None = None

def __post_init__(
self,
runtime_fillna: bool,
log_scale_parameter_names: Collection[str] | None,
) -> None:
self.timestamp_series.name = "timestamp"

self.runtime_series = self._get_runtime_series(fillna=runtime_fillna)
self.runtime_series.name = "runtimes"

parameter_names = get_lcbench_parameter_names()
if log_scale_parameter_names is None:
log_scale_parameter_names = get_lcbench_log_scale_parameter_names()

if len(log_scale_parameter_names) > 0:
if unrecognized_param_set := (
set(log_scale_parameter_names) - set(parameter_names)
):
raise ValueError(f"Unrecognized columns: {unrecognized_param_set}")
self.parameter_df[log_scale_parameter_names] = self.parameter_df[
log_scale_parameter_names
].transform("log")

self.parameter_df = self.parameter_df[parameter_names]

@staticmethod
def _unstack_by_epoch(series: pd.Series) -> pd.DataFrame:
# unstack by epoch and truncate 52 epochs [0, ..., 51]
# to 50 epochs [1, ..., 50]
return series.unstack(level="epoch").iloc[:, 1:-1]

def _get_runtime_series(self, fillna: bool) -> pd.Series:
# timestamp (in secs) at every epoch, grouped by trial
timestamps_grouped = self.timestamp_series.groupby(level="trial")

# runtime (in secs) of training each incremental epoch
runtime_series = timestamps_grouped.diff(periods=1) # first element is NaN
if fillna:
runtime_series.fillna(timestamps_grouped.head(n=1), inplace=True)

return runtime_series

def _to_tensor(
self,
x: pd.DataFrame | pd.Series,
) -> torch.Tensor:
return torch.from_numpy(x.values).to(dtype=self.dtype, device=self.device)

@property
def metric_df(self) -> pd.DataFrame:
return self._unstack_by_epoch(self.metric_series)

@property
def runtime_df(self) -> pd.DataFrame:
return self._unstack_by_epoch(self.runtime_series)

@property
def average_runtime_series(self) -> pd.Series:
# take average runtime over epochs (N6231489 shows runtime is
# mostly constant across epochs, as one'd expect)
return self.runtime_series.groupby(level="trial").mean()

@property
def parameters(self) -> torch.Tensor:
return self._to_tensor(self.parameter_df)

@property
def metrics(self) -> torch.Tensor:
return self._to_tensor(self.metric_df)

@property
def runtimes(self) -> torch.Tensor:
return self._to_tensor(self.runtime_df)

@property
def average_runtimes(self) -> torch.Tensor:
return self._to_tensor(self.average_runtime_series)


def load_lcbench_data(
    dataset_name: str,
    metric_name: str = DEFAULT_METRIC_NAME,
    log_scale_parameter_names: Collection[str] | None = None,
    dtype: torch.dtype = torch.double,
    device: torch.device | None = None,
) -> LCBenchData:
    """
    Load the "config" and "metrics" parquet files of an LCBench dataset and
    wrap them in an `LCBenchData`.

    Args:
        dataset_name: One of `DATASET_NAMES`.
        metric_name: Column of the metrics file to use as the metric series.
        log_scale_parameter_names: Forwarded to `LCBenchData`.
        dtype: Forwarded to `LCBenchData`.
        device: Forwarded to `LCBenchData`.

    Returns:
        The assembled `LCBenchData`; its timestamps come from the "time"
        column of the metrics file.

    Raises:
        ValueError: If `dataset_name` is not in `DATASET_NAMES`.
    """
    if dataset_name not in DATASET_NAMES:
        raise ValueError(
            f"Invalid dataset {dataset_name}. Valid datasets: {DATASET_NAMES}"
        )

    # The generator is consumed in order: "config" is loaded before "metrics".
    config_frame, metrics_frame = (
        LCBenchDataLoader(dataset_name, stem=file_stem).load()
        for file_stem in ("config", "metrics")
    )

    series_kwargs = {
        "metric_series": metrics_frame[metric_name],
        "timestamp_series": metrics_frame["time"],
    }
    return LCBenchData(
        parameter_df=config_frame,
        log_scale_parameter_names=log_scale_parameter_names,
        dtype=dtype,
        device=device,
        **series_kwargs,
    )
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Binary file not shown.
Loading

0 comments on commit 981c50b

Please sign in to comment.