Skip to content

Commit

Permalink
Add allow_glob to dataclass argparse
Browse files Browse the repository at this point in the history
  • Loading branch information
fleonce committed Jan 24, 2025
1 parent a66972f commit 7594b69
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 18 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[project]
name = "with-argparse"
description = "A simple but handy Python library to generate a `argparse.ArgumentParser` object from a type-annotated method "
version = "1.0.1"
version = "1.0.2"
license = {text = "Apache 2.0"}
readme = "README.md"

Expand Down
14 changes: 14 additions & 0 deletions test/test_argparse.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,26 @@
import logging
import unittest
from typing import Optional, Set

from tools import sys_args, foreach
from with_argparse import with_argparse

logging.basicConfig(
level="DEBUG"
)


class ArgparseTestCase(unittest.TestCase):

def test_multi_type_union(self):
@with_argparse
def func(inp: int | float):
return inp

value = 123.456
with sys_args(inp=value):
self.assertEqual(value, func())

def test_python3_11_type_union(self):
@with_argparse
def func(inp: int | None):
Expand Down
40 changes: 26 additions & 14 deletions with_argparse/configure_argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import warnings
from argparse import ArgumentParser
from dataclasses import dataclass, MISSING
from functools import partial
from pathlib import Path
from types import NoneType, UnionType, GenericAlias
from types import NoneType, UnionType
from typing import (
Any, Set, List, get_origin, get_args, Union, Literal, Optional, Sequence, TypeVar, Iterable,
Callable, MutableMapping, Mapping
)
from with_argparse.utils import glob_to_path_list, flatten
from with_argparse.utils import flatten, glob_to_paths

SET_TYPES = {set, Set}
LIST_TYPES = {list, List}
Expand Down Expand Up @@ -128,6 +129,7 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]):

parsed_args = self.argparse.parse_args()
args_dict = self._apply_name_mapping(parsed_args.__dict__, None)
args_dict = self._apply_post_parse_conversions(args_dict, dict())
return self.func(self.dataclass(**args_dict), **kwargs)

def call(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
Expand Down Expand Up @@ -235,16 +237,7 @@ def _call_func(self, args: Sequence[Any], kwargs: Mapping[str, Any]):
overriden_kwargs[field_name] = kwargs[field_name]

args_dict = self._apply_name_mapping(parsed_args.__dict__, args_dict)
for key, conversion_functions in self.post_parse_type_conversions.items():
initial_value = args_dict[key]
if initial_value is None:
args_dict[key] = initial_value
continue

value = initial_value
for conversion_func in conversion_functions:
value = conversion_func(value)
args_dict[key] = value
args_dict = self._apply_post_parse_conversions(args_dict, args_dict)

for key, value in overriden_kwargs.items():
if key in args_dict:
Expand All @@ -269,6 +262,25 @@ def _apply_name_mapping(
out[key] = value
return out

def _apply_post_parse_conversions(
self,
parsed_args: Mapping[str, Any],
out: MutableMapping[str, Any] | None
) -> MutableMapping[str, Any]:
out = out or dict()
out.update(parsed_args)
for key, conversion_functions in self.post_parse_type_conversions.items():
initial_value = parsed_args[key]
if initial_value is None:
out[key] = initial_value
continue

value = initial_value
for conversion_func in conversion_functions:
value = conversion_func(value)
out[key] = value
return out

def _setup_argument(
self,
arg_name: str,
Expand Down Expand Up @@ -521,7 +533,7 @@ def first_working_inner_type(inp):
else:
orig_arg_name = self._resolve_orig_arg_name(arg_name)
if (
arg_type == Path
arg_type in {Path, str}
and orig_arg_name in self.allow_glob
):
self._register_post_parse_type_conversion(
Expand All @@ -531,7 +543,7 @@ def first_working_inner_type(inp):

return _Argument(
arg_name,
glob_to_path_list,
partial(glob_to_paths, func=arg_type),
arg_default,
arg_required,
False,
Expand Down
2 changes: 2 additions & 0 deletions with_argparse/impl.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def set_config(key: str, state: bool):
def with_dataclass(
*,
dataclass=None,
allow_glob: Optional[set[str]] = None,
):
def wrapper(fn):
@functools.wraps(fn)
Expand All @@ -61,6 +62,7 @@ def inner(*inner_args, **inner_kwargs):

parser = WithArgparse(
(dataclass, fn),
allow_glob=allow_glob,
)
return parser.call(inner_args, inner_kwargs)
return inner
Expand Down
7 changes: 4 additions & 3 deletions with_argparse/utils.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from glob import glob
from pathlib import Path
from typing import Any
from typing import Any, TypeVar, Callable


def flatten(input: list[list[Any]]) -> list[Any]:
return [a for b in input for a in b]

T = TypeVar('T')

def glob_to_path_list(input: str) -> list[Path]:
return list(map(Path, glob(input)))
def glob_to_paths(inp: str, func: Callable[[str], T]) -> list[T]:
return list(map(func, glob(inp)))

0 comments on commit 7594b69

Please sign in to comment.