Skip to content

Commit

Permalink
Refactor injector (#5)
Browse files Browse the repository at this point in the history
  • Loading branch information
mariotaddeucci authored Nov 12, 2024
1 parent bf14a79 commit 94eae3a
Show file tree
Hide file tree
Showing 12 changed files with 213 additions and 307 deletions.
19 changes: 11 additions & 8 deletions examples/simple_app.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,18 +16,20 @@ def get_users_urls():
]


@gyjd(retry_attempts=5)
def get_json(url: str, random_exception: bool = False, logger: Logger = None):
@gyjd(retry_attempts=10)
def get_json(url: str, logger: Logger = None):
logger.debug(f"GET {url}")
if random_exception and random.random() < 0.5:
raise ValueError("Random exception")

if random.random() < 0.5:
logger.warning("Random failure activate")
url = "https://httpbin.org/status/500"

response = requests.get(url)
response.raise_for_status()
return response.json()


@gyjd.command(alias="parallel_requests")
@gyjd
def example_parallel_requests(
strategy: str,
logger: Logger = None,
Expand All @@ -46,13 +48,14 @@ def example_parallel_requests(

end_at = time.monotonic()
elapsed = end_at - start_at
logger.info("10 requests completed with at least 1.5 seconds of delay, sequentially would take at least 15 seconds")
if elapsed >= 15:
logger.info("20 requests completed with at least 1.5 seconds of delay, sequentially would take at least 15 seconds")
logger.info("Random failures are expected, retrying up to 10 times")
if elapsed >= 30:
logger.warning(f"Elapsed {elapsed:.2f}s is greater than 15 seconds")
return

logger.info(f"Elapsed {elapsed:.2f}s, whahoo!")


if __name__ == "__main__":
gyjd.run()
example_parallel_requests("thread_map")
4 changes: 1 addition & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ dynamic = ["version"]
description = "Add your description here"
readme = "README.md"
requires-python = ">=3.11"
dependencies = [
"typer>=0.12.5"
]
dependencies = []

[project.optional-dependencies]
compiler = ["nuitka"]
Expand Down
53 changes: 31 additions & 22 deletions src/gyjd/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,15 +4,28 @@
from functools import partial

from gyjd.config import LoggerConfig
from gyjd.core.cli import CLI
from gyjd.core.config_loader import load_config_file
from gyjd.core.gyjd_callable import GYJDCallable
from gyjd.core.logger import GYJDLogger, get_default_logger
from gyjd.core.simple_injector import inject_dependencies, register_dependency
from gyjd.core.simple_injector import clear_registered_dependencies, inject_dependencies, register_dependency

register_dependency(get_default_logger, cls=GYJDLogger, singleton=True, if_exists="skip")
register_dependency(get_default_logger, cls=logging.Logger, singleton=True, if_exists="skip")
register_dependency(LoggerConfig, singleton=True, if_exists="skip")

def setup_defaults(clear_dependencies: bool = False):
"""
Register default dependencies:
- GYJDLogger
- logging.Logger
- LoggerConfig
If clear_dependencies is True, clear all registered dependencies before registering the default ones.
"""

if clear_dependencies:
clear_registered_dependencies()

register_dependency(get_default_logger, cls=GYJDLogger, reuse_times=-1, if_exists="skip")
register_dependency(get_default_logger, cls=logging.Logger, reuse_times=-1, if_exists="skip")
register_dependency(LoggerConfig, reuse_times=-1, if_exists="skip")


class gyjd:
Expand All @@ -30,7 +43,16 @@ def __new__(
retry_on_exceptions=(Exception,),
) -> GYJDCallable:
if func is None:
return gyjd
wrapper = partial(
gyjd,
return_exception_on_fail=return_exception_on_fail,
retry_attempts=retry_attempts,
retry_delay=retry_delay,
retry_max_delay=retry_max_delay,
retry_backoff=retry_backoff,
retry_on_exceptions=retry_on_exceptions,
)
return wrapper

return GYJDCallable(
func=inject_dependencies(func),
Expand All @@ -42,16 +64,6 @@ def __new__(
retry_on_exceptions=retry_on_exceptions,
)

@classmethod
def command(cls, func: Callable | None = None, *, alias=None):
if func is None:
return partial(cls.command, alias=alias)

alias = alias or getattr(func, "__name__", None)
CLI.registry(inject_dependencies(func), alias)

return func

@classmethod
def _collect_children_config(cls, dataclass_type: type, subtree: str = ""):
for field in fields(dataclass_type):
Expand Down Expand Up @@ -84,7 +96,7 @@ def register_config_file(
subtree=subtree.split("."),
),
cls=config_type,
singleton=True,
reuse_times=-1,
)

for child_subtree, child_type in cls._collect_children_config(config_type):
Expand All @@ -95,13 +107,10 @@ def register_config_file(
subtree=child_subtree.split("."),
),
cls=child_type,
singleton=True,
reuse_times=-1,
if_exists="overwrite",
)

@classmethod
def run(cls):
CLI.run()


setup_defaults()
__all__ = ["gyjd"]
95 changes: 0 additions & 95 deletions src/gyjd/__main__.py

This file was deleted.

47 changes: 0 additions & 47 deletions src/gyjd/core/cli.py

This file was deleted.

5 changes: 4 additions & 1 deletion src/gyjd/core/config_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,10 @@ def load_config_from_toml_file(filepath: str, subtree: list[str] | str | None =


def load_config_file(
config_type: type[T], filepath: str, allow_if_file_not_found: bool = False, subtree: list[str] | str | None = None
config_type: type[T],
filepath: str,
allow_if_file_not_found: bool = False,
subtree: list[str] | str | None = None,
):
try:
data = load_config_from_toml_file(filepath, subtree)
Expand Down
Loading

0 comments on commit 94eae3a

Please sign in to comment.