Skip to content

Commit

Permalink
feat: watchdog observer for config auto loading (#119)
Browse files Browse the repository at this point in the history
Signed-off-by: Nandita Koppisetty <[email protected]>
  • Loading branch information
nkoppisetty authored Apr 12, 2023
1 parent f8f2403 commit 2d49c4f
Show file tree
Hide file tree
Showing 27 changed files with 421 additions and 207 deletions.
4 changes: 2 additions & 2 deletions numaprom/__init__.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import logging
import os

from numaprom._config import UnifiedConf, MetricConf, ServiceConf, NumapromConf
from numaprom._config import UnifiedConf, MetricConf, AppConf, NumapromConf


def get_logger(name):
Expand All @@ -25,4 +25,4 @@ def get_logger(name):
return logger


__all__ = ["UnifiedConf", "MetricConf", "ServiceConf", "NumapromConf", "get_logger"]
__all__ = ["UnifiedConf", "MetricConf", "AppConf", "NumapromConf", "get_logger"]
6 changes: 3 additions & 3 deletions numaprom/_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,13 +26,13 @@ class MetricConf:


@dataclass
class ServiceConf:
service: str = "default"
class AppConf:
app: str = "default"
namespace: str = "default"
metric_configs: List[MetricConf] = field(default_factory=lambda: [MetricConf()])
unified_configs: List[UnifiedConf] = field(default_factory=list)


@dataclass
class NumapromConf:
configs: List[ServiceConf]
configs: List[AppConf]
3 changes: 3 additions & 0 deletions numaprom/_constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
NUMAPROM_DIR = os.path.dirname(__file__)
ROOT_DIR = os.path.split(NUMAPROM_DIR)[0]
TESTS_DIR = os.path.join(ROOT_DIR, "tests")
TESTS_RESOURCES = os.path.join(TESTS_DIR, "resources")
DATA_DIR = os.path.join(NUMAPROM_DIR, "data")
CONFIG_DIR = os.path.join(NUMAPROM_DIR, "configs")
DEFAULT_CONFIG_DIR = os.path.join(NUMAPROM_DIR, "default-configs")
Expand All @@ -17,3 +18,5 @@
INFERENCE_VTX_KEY = "inference"
THRESHOLD_VTX_KEY = "threshold"
POSTPROC_VTX_KEY = "postproc"

CONFIG_PATHS = ["./numaprom/configs", "./numaprom/default-configs"]
70 changes: 2 additions & 68 deletions numaprom/tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,18 +14,15 @@
from botocore.session import get_session
from mlflow.entities.model_registry import ModelVersion
from mlflow.exceptions import RestException
from numalogic.config import NumalogicConf, PostprocessFactory
from numalogic.config import PostprocessFactory
from numalogic.models.threshold import SigmoidThreshold
from numalogic.registry import MLflowRegistry, ArtifactData
from omegaconf import OmegaConf
from pynumaflow.function import Messages, Message

from numaprom import get_logger, MetricConf, ServiceConf, NumapromConf, UnifiedConf
from numaprom import get_logger, MetricConf
from numaprom._constants import (
DEFAULT_TRACKING_URI,
DEFAULT_PROMETHEUS_SERVER,
CONFIG_DIR,
DEFAULT_CONFIG_DIR,
)
from numaprom.entities import TrainerPayload, StreamPayload
from numaprom.clients.prometheus import Prometheus
Expand Down Expand Up @@ -155,69 +152,6 @@ def save_model(
return version


def get_all_configs():
schema: NumapromConf = OmegaConf.structured(NumapromConf)

conf = OmegaConf.load(os.path.join(CONFIG_DIR, "config.yaml"))
given_configs = OmegaConf.merge(schema, conf).configs

conf = OmegaConf.load(os.path.join(DEFAULT_CONFIG_DIR, "config.yaml"))
default_configs = OmegaConf.merge(schema, conf).configs

conf = OmegaConf.load(os.path.join(DEFAULT_CONFIG_DIR, "numalogic_config.yaml"))
schema: NumalogicConf = OmegaConf.structured(NumalogicConf)
default_numalogic = OmegaConf.merge(schema, conf)

return given_configs, default_configs, default_numalogic


def get_service_config(metric: str, namespace: str):
given_configs, default_configs, default_numalogic = get_all_configs()

# search and load from given configs
service_config = list(filter(lambda conf: (conf.namespace == namespace), given_configs))

# if not search and load from default configs
if not service_config:
for _conf in default_configs:
if metric in _conf.unified_configs[0].unified_metrics:
service_config = [_conf]
break

# if not in default configs, initialize Namespace conf with default values
if not service_config:
service_config = OmegaConf.structured(ServiceConf)
else:
service_config = service_config[0]

# loading and setting default numalogic config
for metric_config in service_config.metric_configs:
if OmegaConf.is_missing(metric_config, "numalogic_conf"):
metric_config.numalogic_conf = default_numalogic

return service_config


def get_metric_config(metric: str, namespace: str) -> Optional[MetricConf]:
service_config = get_service_config(metric, namespace)
metric_config = list(
filter(lambda conf: (conf.metric == metric), service_config.metric_configs)
)
if not metric_config:
return service_config.metric_configs[0]
return metric_config[0]


def get_unified_config(metric: str, namespace: str) -> Optional[UnifiedConf]:
service_config = get_service_config(metric, namespace)
unified_config = list(
filter(lambda conf: (metric in conf.unified_metrics), service_config.unified_configs)
)
if not unified_config:
return None
return unified_config[0]


def fetch_data(
payload: TrainerPayload, metric_config: MetricConf, labels: dict, return_labels=None
) -> pd.DataFrame:
Expand Down
6 changes: 2 additions & 4 deletions numaprom/udf/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@
from numaprom.entities import Status, StreamPayload, Header
from numaprom.tools import (
load_model,
get_metric_config,
msg_forward,
)
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -73,9 +73,7 @@ def inference(_: str, datum: Datum) -> bytes:
return orjson.dumps(payload, option=orjson.OPT_SERIALIZE_NUMPY)

# Load config
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
numalogic_conf = metric_config.numalogic_conf

# Load inference model
Expand Down
11 changes: 3 additions & 8 deletions numaprom/udf/postprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,9 @@
from numaprom.clients.redis import get_redis_client
from numaprom.tools import (
msgs_forward,
get_unified_config,
get_metric_config,
WindowScorer,
)
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -130,9 +129,7 @@ def __construct_unified_payload(


def _publish(final_score: float, payload: StreamPayload) -> List[bytes]:
unified_config = get_unified_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
unified_config = ConfigManager().get_unified_config(payload.composite_keys)

publisher_json = __construct_publisher_payload(payload, final_score).as_json()
_LOGGER.info("%s - Payload sent to publisher: %s", payload.uuid, publisher_json)
Expand Down Expand Up @@ -181,9 +178,7 @@ def postprocess(_: str, datum: Datum) -> List[bytes]:
payload = StreamPayload(**orjson.loads(_in_msg))

# Load config
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
metric_config = ConfigManager().get_metric_config(payload.composite_keys)

_LOGGER.debug("%s - Received Payload: %r ", payload.uuid, payload)

Expand Down
7 changes: 3 additions & 4 deletions numaprom/udf/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,8 @@

from numaprom import get_logger
from numaprom.entities import Status, StreamPayload, Header
from numaprom.tools import msg_forward, load_model, get_metric_config
from numaprom.tools import msg_forward, load_model
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand All @@ -19,9 +20,7 @@ def preprocess(_: str, datum: Datum) -> bytes:
_LOGGER.info("%s - Received Payload: %r ", payload.uuid, payload)

# Load config
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
preprocess_cfgs = metric_config.numalogic_conf.preprocess

# Load preprocess artifact
Expand Down
7 changes: 3 additions & 4 deletions numaprom/udf/threshold.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,8 @@
conditional_forward,
calculate_static_thresh,
load_model,
get_metric_config,
)
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -44,9 +44,8 @@ def threshold(_: str, datum: Datum) -> list[tuple[str, bytes]]:
)

# Load config
metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
cm = ConfigManager()
metric_config = cm.get_metric_config(payload.composite_keys)
thresh_cfg = metric_config.numalogic_conf.threshold

# Check if payload needs static inference
Expand Down
7 changes: 5 additions & 2 deletions numaprom/udf/window.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@
from numaprom import get_logger
from numaprom.entities import StreamPayload, Status, Header
from numaprom.clients.redis import get_redis_client
from numaprom.tools import msg_forward, create_composite_keys, get_metric_config
from numaprom.tools import msg_forward, create_composite_keys
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -68,7 +69,9 @@ def window(_: str, datum: Datum) -> Optional[bytes]:
_start_time = time.perf_counter()
msg = orjson.loads(datum.value)

metric_config = get_metric_config(metric=msg["name"], namespace=msg["labels"]["namespace"])
metric_config = ConfigManager().get_metric_config(
{"name": msg["name"], "namespace": msg["labels"]["namespace"]}
)
win_size = metric_config.numalogic_conf.model.conf["seq_len"]
buff_size = int(os.getenv("BUFF_SIZE", 10 * win_size))

Expand Down
7 changes: 3 additions & 4 deletions numaprom/udsink/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from numaprom import get_logger
from numaprom.entities import TrainerPayload
from numaprom.clients.redis import get_redis_client
from numaprom.tools import get_metric_config, save_model, fetch_data
from numaprom.tools import save_model, fetch_data
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -103,9 +104,7 @@ def train(datums: List[Datum]) -> Responses:
responses.append(Response.as_success(_datum.id))
continue

metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)
metric_config = ConfigManager().get_metric_config(payload.composite_keys)
model_cfg = metric_config.numalogic_conf.model

train_df = fetch_data(
Expand Down
8 changes: 3 additions & 5 deletions numaprom/udsink/train_rollout.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
from numaprom import get_logger
from numaprom.entities import TrainerPayload
from numaprom.clients.redis import get_redis_client
from numaprom.tools import get_metric_config, save_model, fetch_data
from numaprom.tools import save_model, fetch_data
from numaprom.watcher import ConfigManager

_LOGGER = get_logger(__name__)

Expand Down Expand Up @@ -117,10 +118,7 @@ def train_rollout(datums: Iterator[Datum]) -> Responses:
responses.append(Response.as_success(_datum.id))
continue

metric_config = get_metric_config(
metric=payload.composite_keys["name"], namespace=payload.composite_keys["namespace"]
)

metric_config = ConfigManager().get_metric_config(payload.composite_keys)
model_cfg = metric_config.numalogic_conf.model

# ToDo: standardize the label name
Expand Down
Loading

0 comments on commit 2d49c4f

Please sign in to comment.