diff --git a/pyproject.toml b/pyproject.toml index 92b4a4a..16382d5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,7 +1,7 @@ [project] name = "with-argparse" description = "A simple but handy Python library to generate a `argparse.ArgumentParser` object from a type-annotated method " -version = "1.0.1" +version = "1.0.2" license = {text = "Apache 2.0"} readme = "README.md" diff --git a/test/test_argparse.py b/test/test_argparse.py index 6e756c4..53834ee 100644 --- a/test/test_argparse.py +++ b/test/test_argparse.py @@ -1,12 +1,26 @@ +import logging import unittest from typing import Optional, Set from tools import sys_args, foreach from with_argparse import with_argparse +logging.basicConfig( + level="DEBUG" +) + class ArgparseTestCase(unittest.TestCase): + def test_multi_type_union(self): + @with_argparse + def func(inp: int | float): + return inp + + value = 123.456 + with sys_args(inp=value): + self.assertEqual(value, func()) + def test_python3_11_type_union(self): @with_argparse def func(inp: int | None): diff --git a/with_argparse/configure_argparse.py b/with_argparse/configure_argparse.py index c5343d0..fe3c29f 100644 --- a/with_argparse/configure_argparse.py +++ b/with_argparse/configure_argparse.py @@ -5,13 +5,14 @@ import warnings from argparse import ArgumentParser from dataclasses import dataclass, MISSING +from functools import partial from pathlib import Path -from types import NoneType, UnionType, GenericAlias +from types import NoneType, UnionType from typing import ( Any, Set, List, get_origin, get_args, Union, Literal, Optional, Sequence, TypeVar, Iterable, Callable, MutableMapping, Mapping ) -from with_argparse.utils import glob_to_path_list, flatten +from with_argparse.utils import flatten, glob_to_paths SET_TYPES = {set, Set} LIST_TYPES = {list, List} @@ -128,6 +129,7 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]): parsed_args = self.argparse.parse_args() args_dict = self._apply_name_mapping(parsed_args.__dict__, None) + args_dict = self._apply_post_parse_conversions(args_dict, dict()) return self.func(self.dataclass(**args_dict), **kwargs) def call(self, args: Sequence[Any], kwargs: Mapping[str, Any]): @@ -235,16 +237,7 @@ def _call_func(self, args: Sequence[Any], kwargs: Mapping[str, Any]): overriden_kwargs[field_name] = kwargs[field_name] args_dict = self._apply_name_mapping(parsed_args.__dict__, args_dict) - for key, conversion_functions in self.post_parse_type_conversions.items(): - initial_value = args_dict[key] - if initial_value is None: - args_dict[key] = initial_value - continue - - value = initial_value - for conversion_func in conversion_functions: - value = conversion_func(value) - args_dict[key] = value + args_dict = self._apply_post_parse_conversions(args_dict, args_dict) for key, value in overriden_kwargs.items(): if key in args_dict: @@ -269,6 +262,25 @@ def _apply_name_mapping( out[key] = value return out + def _apply_post_parse_conversions( + self, + parsed_args: Mapping[str, Any], + out: MutableMapping[str, Any] | None + ) -> MutableMapping[str, Any]: + out = out or dict() + out.update(parsed_args) + for key, conversion_functions in self.post_parse_type_conversions.items(): + initial_value = parsed_args[key] + if initial_value is None: + out[key] = initial_value + continue + + value = initial_value + for conversion_func in conversion_functions: + value = conversion_func(value) + out[key] = value + return out + def _setup_argument( self, arg_name: str, @@ -521,7 +533,7 @@ def first_working_inner_type(inp): else: orig_arg_name = self._resolve_orig_arg_name(arg_name) if ( - arg_type == Path + arg_type in {Path, str} and orig_arg_name in self.allow_glob ): self._register_post_parse_type_conversion( @@ -531,7 +543,7 @@ def first_working_inner_type(inp): return _Argument( arg_name, - glob_to_path_list, + partial(glob_to_paths, func=arg_type), arg_default, arg_required, False, diff --git a/with_argparse/impl.py b/with_argparse/impl.py index 0aa471b..3fb814f 100644 --- a/with_argparse/impl.py +++ b/with_argparse/impl.py @@ -52,6 +52,7 @@ def set_config(key: str, state: bool): def with_dataclass( *, dataclass=None, + allow_glob: Optional[set[str]] = None, ): def wrapper(fn): @functools.wraps(fn) @@ -61,6 +62,7 @@ def inner(*inner_args, **inner_kwargs): parser = WithArgparse( (dataclass, fn), + allow_glob=allow_glob, ) return parser.call(inner_args, inner_kwargs) return inner diff --git a/with_argparse/utils.py b/with_argparse/utils.py index bbfb5b4..b84e759 100644 --- a/with_argparse/utils.py +++ b/with_argparse/utils.py @@ -1,11 +1,12 @@ from glob import glob from pathlib import Path -from typing import Any +from typing import Any, TypeVar, Callable def flatten(input: list[list[Any]]) -> list[Any]: return [a for b in input for a in b] +T = TypeVar('T') -def glob_to_path_list(input: str) -> list[Path]: - return list(map(Path, glob(input))) +def glob_to_paths(inp: str, func: Callable[[str], T]) -> list[T]: + return list(map(func, glob(inp)))