Resilient run algo (#27)
* feat: more resilient OOM handling

* fix: remove max_inner_num_threads

* test: expand coverage of app.py and config.py
dPys authored Dec 27, 2024
1 parent de8ac89 commit c41bcda
Showing 4 changed files with 267 additions and 79 deletions.
13 changes: 1 addition & 12 deletions nxbench/backends/registry.py
@@ -31,20 +31,9 @@ def teardown_networkx():
 # ---- Nx-Parallel backend ----
 def convert_parallel(original_graph: nx.Graph, num_threads: int):
     nxp = import_module("nx_parallel")
-    from multiprocessing import cpu_count
-
-    total_cores = cpu_count()
-
-    n_jobs = min(num_threads, total_cores)
 
     nx.config.backends.parallel.active = True
-    nx.config.backends.parallel.n_jobs = n_jobs
+    nx.config.backends.parallel.n_jobs = num_threads
     nx.config.backends.parallel.backend = "loky"
-    if hasattr(nx.config.backends.parallel, "inner_max_num_threads"):
-        nx.config.backends.parallel.inner_max_num_threads = max(
-            total_cores // n_jobs, 1
-        )
 
     return nxp.ParallelGraph(original_graph)
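
For orientation, a minimal usage sketch of the simplified converter, assuming a recent networkx with nx-parallel installed and nxbench importable; the graph size and thread count are illustrative, not taken from the commit.

# Illustrative only; not part of the commit.
import networkx as nx

from nxbench.backends.registry import convert_parallel

G = nx.erdos_renyi_graph(1000, 0.01, seed=42)

# After this change, n_jobs is taken directly from the requested thread count;
# joblib/loky decides how those jobs map onto physical cores.
pg = convert_parallel(G, num_threads=4)

# Algorithms called on the ParallelGraph dispatch to the nx-parallel backend.
scores = nx.betweenness_centrality(pg)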
39 changes: 25 additions & 14 deletions nxbench/benchmarking/benchmark.py
@@ -117,24 +117,34 @@ def run_algorithm(
             original_env[var_name] = os.environ.get(var_name)
             os.environ[var_name] = str(num_thread)
 
-        # start memory tracking
         with memory_tracker() as mem:
             start_time = time.perf_counter()
-            # pass the graph plus the processed pos_args and kwargs
-            result = algo_func(graph, *pos_args, **kwargs)
-            end_time = time.perf_counter()
-
-        execution_time = end_time - start_time
-        peak_memory = mem["peak"]
-        logger.debug(f"Algorithm '{algo_config.name}' executed successfully.")
-
-    except Exception as e:
-        logger.exception("Algorithm run failed")
-        execution_time = time.perf_counter() - start_time
-        peak_memory = mem.get("peak", 0)
-        result = None
-        error = str(e)
+            try:
+                result = algo_func(graph, *pos_args, **kwargs)
+            except NotImplementedError as nie:
+                logger.info(
+                    f"Skipping {algo_config.name} for backend '{backend}' "
+                    "because it's not implemented (NotImplementedError)."
+                )
+                return None, 0.0, 0, str(nie)
+            except MemoryError as me:
+                # gracefully handle OOM:
+                logger.exception("Algorithm ran out of memory.")
+                result = None
+                error = f"MemoryError: {me}"
+            except Exception as e:
+                logger.exception("Algorithm run failed unexpectedly.")
+                result = None
+                error = str(e)
+            finally:
+                end_time = time.perf_counter()
+                execution_time = end_time - start_time
+
+        peak_memory = mem.get("peak", 0)
+
     finally:
+        logger.debug(f"Algorithm '{algo_config.name}' executed successfully.")
         # restore environment variables
         for var_name in vars_to_set:
             if original_env[var_name] is None:
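
run_algorithm leans on a memory_tracker() context manager defined elsewhere in nxbench and not shown in this diff. A minimal stand-in consistent with how it is used above (a mapping that ends up holding a "peak" byte count), assuming a tracemalloc-based implementation; the real helper may differ.

# Hypothetical stand-in for nxbench's memory_tracker(); shown only to clarify
# what mem["peak"] / mem.get("peak", 0) refer to in the hunk above.
import tracemalloc
from contextlib import contextmanager


@contextmanager
def memory_tracker():
    stats = {}
    tracemalloc.start()
    try:
        yield stats
    finally:
        _, peak = tracemalloc.get_traced_memory()
        stats["peak"] = peak  # peak traced allocation, in bytes
        tracemalloc.stop()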
@@ -309,6 +319,7 @@ def create_benchmark_subflow(name_suffix: str, resource_type: str, num_thread: i
                 "resources": {resource_type: 1},
                 "threads_per_worker": num_thread,
                 "processes": False,
+                "memory_limit": "2GB",
             }
         ),
     )
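
The new "memory_limit" entry, like "threads_per_worker" and "processes", matches a standard dask.distributed worker setting, so these keys appear to be forwarded to the Dask cluster backing the subflow. A standalone sketch of a cluster configured with the same keys (values illustrative; the Prefect/nxbench task-runner wiring is omitted):

# Illustrative only: a local Dask cluster with the same worker settings as above.
from dask.distributed import Client, LocalCluster

cluster = LocalCluster(
    n_workers=1,
    threads_per_worker=8,   # corresponds to "threads_per_worker": num_thread
    processes=False,        # run workers as threads inside one process
    memory_limit="2GB",     # per-worker memory budget tracked by Dask
)
client = Client(cluster)

fut = client.submit(sum, range(10))
print(fut.result())  # 45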
131 changes: 128 additions & 3 deletions nxbench/benchmarking/tests/test_config.py
@@ -1,6 +1,7 @@
 import math
 import os
 import textwrap
+from functools import partial
 from unittest.mock import MagicMock, patch
 
 import networkx as nx
@@ -93,6 +94,84 @@ def test_invalid_validation_function(self, caplog):
             in caplog.text
         )
 
+    def test_get_callable_non_networkx_backend_returns_partial(self):
+        """
+        Ensure get_callable() returns a functools.partial
+        if the backend is not 'networkx'.
+        """
+        with patch("builtins.__import__") as mock_import:
+            mock_module = MagicMock()
+            mock_func = MagicMock()
+            mock_import.return_value = mock_module
+            mock_module.some_function = mock_func
+
+            algo = AlgorithmConfig(
+                name="test_algo",
+                func="my_module.some_function",
+            )
+            # ensures get_func_ref() is valid
+            func_ref = algo.get_func_ref()
+            assert func_ref is not None
+
+            partial_func = algo.get_callable(backend_name="igraph")
+            assert isinstance(
+                partial_func, partial
+            ), "Should return a partial for non-networkx backends"
+            assert (
+                partial_func.func == mock_func
+            ), "Partial should wrap the imported function"
+            assert partial_func.keywords["backend"] == "igraph"
+
+    def test_get_callable_raises_importerror_if_func_is_none(self):
+        """If get_func_ref() returns None, get_callable() must raise ImportError."""
+        # force a bad function import
+        algo = AlgorithmConfig(
+            name="broken_algo",
+            func="nonexistent.module.func",
+        )
+        with pytest.raises(
+            ImportError, match="could not be imported for algorithm 'broken_algo'"
+        ):
+            algo.get_callable(backend_name="networkx")
+
+    def test_get_validate_ref_none_when_validate_result_is_none(self):
+        """If validate_result is not provided, get_validate_ref() should immediately
+        return None.
+        """
+        algo = AlgorithmConfig(
+            name="test", func="some.module.function", validate_result=None
+        )
+        ref = algo.get_validate_ref()
+        assert (
+            ref is None
+        ), "Should return None immediately if no validate_result is specified"
+
+    @pytest.mark.parametrize(
+        ("requires_directed", "requires_undirected", "requires_weighted"),
+        [
+            (True, False, False),
+            (False, True, True),
+            (True, True, True),
+        ],
+    )
+    def test_requires_attributes_instantiation(
+        self, requires_directed, requires_undirected, requires_weighted
+    ):
+        """
+        Instantiate AlgorithmConfig with the various booleans
+        for coverage on those attributes.
+        """
+        algo = AlgorithmConfig(
+            name="test_attrs",
+            func="some.module.func",
+            requires_directed=requires_directed,
+            requires_undirected=requires_undirected,
+            requires_weighted=requires_weighted,
+        )
+        assert algo.requires_directed == requires_directed
+        assert algo.requires_undirected == requires_undirected
+        assert algo.requires_weighted == requires_weighted
+
+
 class TestDatasetConfig:
     def test_valid_initialization(self):
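
For readers without the nxbench source at hand, a sketch of the get_callable()/get_func_ref() behavior these new tests pin down. It is inferred from the assertions above, not copied from the project, and the class name marks it as such.

# Inferred from the tests, not the actual nxbench implementation.
from functools import partial


class AlgorithmConfigSketch:
    def __init__(self, name: str, func: str):
        self.name = name
        self.func = func  # dotted path, e.g. "my_module.some_function"

    def get_func_ref(self):
        module_path, _, attr = self.func.rpartition(".")
        try:
            module = __import__(module_path, fromlist=[attr])
        except ImportError:
            return None
        return getattr(module, attr, None)

    def get_callable(self, backend_name: str):
        func = self.get_func_ref()
        if func is None:
            raise ImportError(
                f"Function '{self.func}' could not be imported for algorithm "
                f"'{self.name}'"
            )
        if backend_name != "networkx":
            # non-networkx backends get the backend forwarded as a keyword argument
            return partial(func, backend=backend_name)
        return func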
@@ -139,7 +218,7 @@ def test_load_from_valid_yaml(self, tmp_path):
         assert config.datasets[0].name == "jazz"
         assert config.datasets[0].source == "networkrepository"
 
-        # Verify machine_info
+        # verify machine_info
         assert config.machine_info == {"cpu": "Intel i7", "ram": "16GB"}
 
     def test_load_from_nonexistent_yaml(self):
@@ -160,9 +239,8 @@ def test_load_from_invalid_yaml_structure(self, tmp_path, caplog):
 
         config = BenchmarkConfig.from_yaml(config_file)
 
-        # No valid algorithms loaded because 'pagerank' isn't in a list
+        # no valid algorithms loaded because 'pagerank' isn't in a list
         assert len(config.algorithms) == 0
-        # Warnings about invalid structure
         assert "should be a list" in caplog.text
 
     def test_to_yaml(self, tmp_path):
@@ -203,6 +281,53 @@ def test_to_yaml(self, tmp_path):
         assert "machine_info" in loaded_data
         assert loaded_data["machine_info"] == {"cpu": "Intel i7", "ram": "16GB"}
 
+    def test_load_from_invalid_datasets_type(self, tmp_path, caplog):
+        """
+        'datasets' should be a list, but if not, from_yaml should
+        log an error, set it to [], and continue.
+        """
+        yaml_content = """
+algorithms:
+  - name: pagerank
+    func: networkx.algorithms.link_analysis.pagerank_alg.pagerank
+datasets:
+  jazz:
+    source: networkrepository
+"""
+        config_file = tmp_path / "bad_datasets.yaml"
+        config_file.write_text(yaml_content)
+
+        config = BenchmarkConfig.from_yaml(config_file)
+        # expect no valid datasets loaded
+        assert (
+            len(config.datasets) == 0
+        ), "datasets should have been forced to empty list"
+        assert (
+            "should be a list in the config file" in caplog.text.lower()
+        ), "Expected an error log about 'datasets' not being a list"
+
+    def test_load_from_yaml_with_env_data(self, tmp_path):
+        """Ensure 'environ' data is captured in config.env_data"""
+        yaml_content = """
+algorithms:
+  - name: test_algo
+    func: some.module.function
+datasets:
+  - name: ds
+    source: dummy
+environ:
+  MY_ENV_VAR: "some_value"
+  OTHER_VAR: 123
+"""
+        config_file = tmp_path / "env_data.yaml"
+        config_file.write_text(yaml_content)
+
+        config = BenchmarkConfig.from_yaml(config_file)
+        assert config.env_data == {
+            "MY_ENV_VAR": "some_value",
+            "OTHER_VAR": 123,
+        }, "env_data should match what's in 'environ' key from YAML"
+
+
 class TestGlobalConfiguration:
     def test_configure_with_instance(self):
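
The two new loader tests characterize how BenchmarkConfig.from_yaml is expected to behave on odd input. A rough sketch of that behavior, inferred from the assertions rather than taken from the nxbench source (the helper name is hypothetical):

# Inferred behavior only: tolerate a non-list 'datasets' section and capture 'environ'.
import logging

import yaml

logger = logging.getLogger(__name__)


def load_config_sections(path):
    with open(path) as fh:
        data = yaml.safe_load(fh) or {}

    datasets = data.get("datasets", [])
    if not isinstance(datasets, list):
        # matches the error text the test looks for in caplog
        logger.error("'datasets' should be a list in the config file")
        datasets = []

    env_data = data.get("environ", {})
    return datasets, env_data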

