From c178b1e415582e76ca8b9c866004e87652a60c6d Mon Sep 17 00:00:00 2001 From: fleonce Date: Wed, 27 Nov 2024 10:40:39 +0100 Subject: [PATCH 1/2] Fix mypy typing error in new dataclass call --- with_argparse/configure_argparse.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/with_argparse/configure_argparse.py b/with_argparse/configure_argparse.py index 5402b8e..c9354e9 100644 --- a/with_argparse/configure_argparse.py +++ b/with_argparse/configure_argparse.py @@ -1,6 +1,7 @@ import dataclasses import inspect import logging +import typing import warnings from argparse import ArgumentParser from dataclasses import dataclass, _MISSING_TYPE @@ -107,12 +108,18 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]): if self.dataclass is None: raise ValueError("self.dataclass cannot be None") + field_hints = typing.get_type_hints(self.dataclass) for field in dataclasses.fields(self.dataclass): field_required = isinstance(field.default, _MISSING_TYPE) field_default = field.default if not field_required else None + field_type = field.type + if isinstance(field_type, str): + field_type = field_hints.get(field.name) + if not isinstance(field_type, type): + raise ValueError(f"Cannot determine type of {field.name}, got {field_type}") self._setup_argument( field.name, - field.type, + field_type, field_default, field_required, ) From 28fd9a1bd13464db13598ef52727318d3c2fafd3 Mon Sep 17 00:00:00 2001 From: fleonce Date: Wed, 27 Nov 2024 10:40:39 +0100 Subject: [PATCH 2/2] Fix mypy typing error in new dataclass call --- with_argparse/configure_argparse.py | 56 +++++++++++++++++++++++++++-- 1 file changed, 53 insertions(+), 3 deletions(-) diff --git a/with_argparse/configure_argparse.py b/with_argparse/configure_argparse.py index ec059ff..c9354e9 100644 --- a/with_argparse/configure_argparse.py +++ b/with_argparse/configure_argparse.py @@ -1,8 +1,10 @@ +import dataclasses import inspect import logging +import typing import warnings from argparse import ArgumentParser -from dataclasses import dataclass +from dataclasses import dataclass, _MISSING_TYPE from pathlib import Path from types import NoneType from typing import Any, Set, List, get_origin, get_args, Union, Literal, Optional, Sequence, TypeVar, Iterable, \ @@ -49,9 +51,12 @@ class WithArgparse: allow_custom: Mapping[str, Callable[[Any], Any]] allow_dispatch_custom: bool + func: Callable + dataclass: Optional[type] + def __init__( self, - func, + func_or_dataclass: Union[Callable, tuple[type, Callable]], aliases: Optional[Mapping[str, Sequence[str]]] = None, ignore_rename: Optional[set[str]] = None, ignore_keys: Optional[set[str]] = None, @@ -68,7 +73,16 @@ def __init__( self.allow_custom = allow_custom or dict() self.allow_dispatch_custom = True - self.func = func + if isinstance(func_or_dataclass, tuple): + if not inspect.isclass(func_or_dataclass[0]): + raise ValueError("First argument must be a type") + if not dataclasses.is_dataclass(func_or_dataclass[0]): + raise ValueError("First argument must be a dataclass") + self.dataclass = func_or_dataclass[0] + self.func = func_or_dataclass[1] + else: + self.func = func_or_dataclass + self.dataclass = None self.argparse = ArgumentParser() def _register_mapping(self): ... @@ -86,7 +100,43 @@ def _register_post_parse_type_conversion(self, key: str, func: Callable[[Any], A logger.debug(f"Registering post parse type conversion for {key}: {func.__name__} ({func})") self.post_parse_type_conversions[key].append(func) + def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]): + if args: + raise ValueError("Positional argument overrides are not supported, yet") + if kwargs: + raise ValueError("Keyword argument overrides are not supported, yet") + if self.dataclass is None: + raise ValueError("self.dataclass cannot be None") + + field_hints = typing.get_type_hints(self.dataclass) + for field in dataclasses.fields(self.dataclass): + field_required = isinstance(field.default, _MISSING_TYPE) + field_default = field.default if not field_required else None + field_type = field.type + if isinstance(field_type, str): + field_type = field_hints.get(field.name) + if not isinstance(field_type, type): + raise ValueError(f"Cannot determine type of {field.name}, got {field_type}") + self._setup_argument( + field.name, + field_type, + field_default, + field_required, + ) + + parsed_args = self.argparse.parse_args() + return self.func(self.dataclass(**parsed_args.__dict__)) + def call(self, args: Sequence[Any], kwargs: Mapping[str, Any]): + if self.dataclass is not None: + return self._call_dataclass(args, kwargs) + elif self.func is not None: + return self._call_func(args, kwargs) + else: + raise ValueError("self.dataclass and self.func cannot both be None") + + + def _call_func(self, args: Sequence[Any], kwargs: Mapping[str, Any]): info = inspect.getfullargspec(self.func) callable_args = info.args or [] callable_kwonly = info.kwonlyargs or []