Skip to content

Commit

Permalink
Context manager support for Services (#554)
Browse files Browse the repository at this point in the history
Work towards making the setup/teardown work for the SshService managed
by a Context instead of global.

To do this, we introspect all of the methods registered for a Service
mixin (which we also renamed from `self._services` to
`self._service_methods`), and pull out their unique Service instance
references from the `__self__` attribute in the bound method that was
registered, then store them as `self._services`.
This is to handle the case where a single method name was overridden
multiple times.
At this point, we have enough information to invoke a sort of
`__enter__` context for each of those Service instances.
However, since we use `__enter__` to invoke all of them, we instead need
to implement a separate `_enter_context()` method that does the actual
work, and can then be overridden by the subclasses.
 
Unfortunately, this technique make tamper with the order the Services
were layered on one another, so any logic that required relying on that
order for context setup might be problematic. I don't know if that will
actually be an issue, so for now, I'm ignoring the issue and just
documenting it.

This PR does not include explicit new tests for this, though does have
to fixup a few mocks in order to make sure that we're registered bound
methods and not just functions.

#510 will add additional tests after this is merged.
  • Loading branch information
bpkroth authored Oct 25, 2023
1 parent 7e83836 commit 156dd4b
Show file tree
Hide file tree
Showing 3 changed files with 134 additions and 6 deletions.
15 changes: 15 additions & 0 deletions mlos_bench/mlos_bench/environments/base_environment.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,7 @@ def __init__(self,
self.name = name
self.config = config
self._service = service
self._service_context: Optional[Service] = None
self._is_ready = False
self._in_context = False
self._const_args: Dict[str, TunableValue] = config.get("const_args", {})
Expand Down Expand Up @@ -216,6 +217,8 @@ def __enter__(self) -> 'Environment':
"""
_LOG.debug("Environment START :: %s", self)
assert not self._in_context
if self._service:
self._service_context = self._service.__enter__()
self._in_context = True
return self

Expand All @@ -225,13 +228,25 @@ def __exit__(self, ex_type: Optional[Type[BaseException]],
"""
Exit the context of the benchmarking environment.
"""
ex_throw = None
if ex_val is None:
_LOG.debug("Environment END :: %s", self)
else:
assert ex_type and ex_val
_LOG.warning("Environment END :: %s", self, exc_info=(ex_type, ex_val, ex_tb))
assert self._in_context
if self._service_context:
try:
self._service_context.__exit__(ex_type, ex_val, ex_tb)
# pylint: disable=broad-exception-caught
except Exception as ex:
_LOG.error("Exception while exiting Service context '%s': %s", self._service, ex)
ex_throw = ex
finally:
self._service_context = None
self._in_context = False
if ex_throw:
raise ex_throw
return False # Do not suppress exceptions

def __str__(self) -> str:
Expand Down
103 changes: 97 additions & 6 deletions mlos_bench/mlos_bench/services/base_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@
import json
import logging

from typing import Any, Callable, Dict, List, Optional, Union
from types import TracebackType
from typing import Any, Callable, Dict, List, Optional, Set, Type, Union
from typing_extensions import Literal

from mlos_bench.config.schemas import ConfigSchema
from mlos_bench.services.types.config_loader_type import SupportsConfigLoading
Expand Down Expand Up @@ -79,13 +81,33 @@ def __init__(self,
self.config = config or {}
self._validate_json_config(self.config)
self._parent = parent
self._services: Dict[str, Callable] = {}
self._service_methods: Dict[str, Callable] = {}

if parent:
self.register(parent.export())
if methods:
self.register(methods)

# In order to get a list of all child contexts, we need to look at only
# the bound methods that were not overridden by another mixin.
# Then we inspect the internally bound __self__ variable to discover
# which Service instance that method belongs too.
# To do this we also

self._services: Set[Service] = {
# Enumerate the Services that are bound to this instance in the
# order they were added.
# Unfortunately, by creating a set, we may destroy the ability to
# preserve the context enter/exit order, but hopefully it doesn't
# matter.
svc_method.__self__ for _, svc_method in self._service_methods.items()
# Note: some methods are actually stand alone functions, so we need
# to filter them out.
if hasattr(svc_method, '__self__') and isinstance(svc_method.__self__, Service)
}
self._service_contexts: List[Service] = []
self._in_context = False

self._config_loader_service: SupportsConfigLoading
if parent and isinstance(parent, SupportsConfigLoading):
self._config_loader_service = parent
Expand Down Expand Up @@ -117,6 +139,75 @@ def merge_methods(ext_methods: Union[Dict[str, Callable], List[Callable], None],
local_methods.update(ext_methods)
return local_methods

def __enter__(self) -> "Service":
"""
Enter the Service mix-in context.
Calls the _enter_context() method of all the Services registered under this one.
"""
if self._in_context:
# Multiple environments can share the same Service, so we need to
# add a check and make this a re-entrant Service context.
assert self._service_contexts
assert all(svc._in_context for svc in self._services)
return self
self._service_contexts = [svc._enter_context() for svc in self._services]
self._in_context = True
return self

def __exit__(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exit the Service mix-in context.
Calls the _exit_context() method of all the Services registered under this one.
"""
if not self._in_context:
# Multiple environments can share the same Service, so we need to
# add a check and make this a re-entrant Service context.
assert not self._service_contexts
assert all(not svc._in_context for svc in self._services)
return False
ex_throw = None
for svc in reversed(self._service_contexts):
try:
svc._exit_context(ex_type, ex_val, ex_tb)
# pylint: disable=broad-exception-caught
except Exception as ex:
_LOG.error("Exception while exiting Service context '%s': %s", svc, ex)
ex_throw = ex
self._service_contexts = []
if ex_throw:
raise ex_throw
self._in_context = False
return False

def _enter_context(self) -> "Service":
"""
Enters the context for this particular Service instance.
Called by the base __enter__ method of the Service class so it can be
used with mix-ins and overridden by subclasses.
"""
assert not self._in_context
self._in_context = True
return self

def _exit_context(self, ex_type: Optional[Type[BaseException]],
ex_val: Optional[BaseException],
ex_tb: Optional[TracebackType]) -> Literal[False]:
"""
Exits the context for this particular Service instance.
Called by the base __enter__ method of the Service class so it can be
used with mix-ins and overridden by subclasses.
"""
# pylint: disable=unused-argument
assert self._in_context
self._in_context = False
return False

def _validate_json_config(self, config: dict) -> None:
"""
Reconstructs a basic json config that this class might have been
Expand All @@ -143,7 +234,7 @@ def pprint(self) -> str:
"""
return f"{self} ::\n" + "\n".join(
f' "{key}": {getattr(val, "__self__", "stand-alone")}'
for (key, val) in self._services.items()
for (key, val) in self._service_methods.items()
)

@property
Expand All @@ -170,8 +261,8 @@ def register(self, services: Union[Dict[str, Callable], List[Callable]]) -> None
if not isinstance(services, dict):
services = {svc.__name__: svc for svc in services}

self._services.update(services)
self.__dict__.update(self._services)
self._service_methods.update(services)
self.__dict__.update(self._service_methods)

if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Added methods to: %s", self.pprint())
Expand All @@ -188,4 +279,4 @@ def export(self) -> Dict[str, Callable]:
if _LOG.isEnabledFor(logging.DEBUG):
_LOG.debug("Export methods from: %s", self.pprint())

return self._services
return self._service_methods
22 changes: 22 additions & 0 deletions mlos_bench/mlos_bench/tests/environments/local/local_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,28 @@ def test_local_env(tunable_groups: TunableGroups) -> None:
)


def test_local_env_service_context(tunable_groups: TunableGroups) -> None:
"""
Basic check that context support for Service mixins are handled when environment contexts are entered.
"""
local_env = create_local_env(tunable_groups, {
"run": ["echo NA"]
})
# pylint: disable=protected-access
assert local_env._service
assert not local_env._service._in_context
assert not local_env._service._service_contexts
with local_env as env_context:
assert env_context._in_context
assert local_env._service._in_context
assert local_env._service._service_contexts # type: ignore[unreachable] # (false positive)
assert all(svc._in_context for svc in local_env._service._service_contexts)
assert all(svc._in_context for svc in local_env._service._services)
assert not local_env._service._in_context # type: ignore[unreachable] # (false positive)
assert not local_env._service._service_contexts
assert not any(svc._in_context for svc in local_env._service._services)


def test_local_env_results_no_header(tunable_groups: TunableGroups) -> None:
"""
Fail if the results are not in the expected format.
Expand Down

0 comments on commit 156dd4b

Please sign in to comment.