Skip to content

Commit

Permalink
Merge pull request #19 from fleonce/feat/from_dataclass
Browse files Browse the repository at this point in the history
Implement basic support for dataclasses via `@with_dataclass()`
  • Loading branch information
fleonce authored Nov 27, 2024
2 parents a334489 + e849c94 commit 9b2bdcc
Show file tree
Hide file tree
Showing 3 changed files with 72 additions and 4 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "with-argparse"
description = "A simple but handy Python library to generate a `argparse.ArgumentParser` object from a type-annotated method "
version = "1.0.0rc2"
version = "1.0.0rc3"
license = {text = "Apache 2.0"}
readme = "README.md"

Expand Down
56 changes: 53 additions & 3 deletions with_argparse/configure_argparse.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import dataclasses
import inspect
import logging
import typing
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass
from dataclasses import dataclass, _MISSING_TYPE
from pathlib import Path
from types import NoneType
from typing import Any, Set, List, get_origin, get_args, Union, Literal, Optional, Sequence, TypeVar, Iterable, \
Expand Down Expand Up @@ -49,9 +51,12 @@ class WithArgparse:
allow_custom: Mapping[str, Callable[[Any], Any]]
allow_dispatch_custom: bool

func: Callable
dataclass: Optional[type]

def __init__(
self,
func,
func_or_dataclass: Union[Callable, tuple[type, Callable]],
aliases: Optional[Mapping[str, Sequence[str]]] = None,
ignore_rename: Optional[set[str]] = None,
ignore_keys: Optional[set[str]] = None,
Expand All @@ -68,7 +73,16 @@ def __init__(
self.allow_custom = allow_custom or dict()
self.allow_dispatch_custom = True

self.func = func
if isinstance(func_or_dataclass, tuple):
if not inspect.isclass(func_or_dataclass[0]):
raise ValueError("First argument must be a type")
if not dataclasses.is_dataclass(func_or_dataclass[0]):
raise ValueError("First argument must be a dataclass")
self.dataclass = func_or_dataclass[0]
self.func = func_or_dataclass[1]
else:
self.func = func_or_dataclass
self.dataclass = None
self.argparse = ArgumentParser()

def _register_mapping(self): ...
Expand All @@ -86,7 +100,43 @@ def _register_post_parse_type_conversion(self, key: str, func: Callable[[Any], A
logger.debug(f"Registering post parse type conversion for {key}: {func.__name__} ({func})")
self.post_parse_type_conversions[key].append(func)

def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
if args:
raise ValueError("Positional argument overrides are not supported, yet")
if kwargs:
raise ValueError("Keyword argument overrides are not supported, yet")
if self.dataclass is None:
raise ValueError("self.dataclass cannot be None")

field_hints = typing.get_type_hints(self.dataclass)
for field in dataclasses.fields(self.dataclass):
field_required = isinstance(field.default, _MISSING_TYPE)
field_default = field.default if not field_required else None
field_type = field.type
if isinstance(field_type, str):
field_type = field_hints.get(field.name)
if not isinstance(field_type, type):
raise ValueError(f"Cannot determine type of {field.name}, got {field_type}")
self._setup_argument(
field.name,
field_type,
field_default,
field_required,
)

parsed_args = self.argparse.parse_args()
return self.func(self.dataclass(**parsed_args.__dict__))

def call(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
if self.dataclass is not None:
return self._call_dataclass(args, kwargs)
elif self.func is not None:
return self._call_func(args, kwargs)
else:
raise ValueError("self.dataclass and self.func cannot both be None")


def _call_func(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
info = inspect.getfullargspec(self.func)
callable_args = info.args or []
callable_kwonly = info.kwonlyargs or []
Expand Down
18 changes: 18 additions & 0 deletions with_argparse/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,24 @@ def set_config(key: str, state: bool):
config[key] = state


def with_dataclass(
*,
dataclass=None,
):
def wrapper(fn):
@functools.wraps(fn)
def inner(*inner_args, **inner_kwargs):
if not is_enabled():
return fn(*inner_args, **inner_kwargs)

parser = WithArgparse(
(dataclass, fn),
)
return parser.call(inner_args, inner_kwargs)
return inner
return wrapper


@overload
def with_argparse(
*,
Expand Down

0 comments on commit 9b2bdcc

Please sign in to comment.