diff --git a/with_argparse/configure_argparse.py b/with_argparse/configure_argparse.py index 7a75fc4..e3d437b 100644 --- a/with_argparse/configure_argparse.py +++ b/with_argparse/configure_argparse.py @@ -116,9 +116,10 @@ def _call_dataclass(self, args: Sequence[Any], kwargs: Mapping[str, Any]): if isinstance(field_type, str): field_type = typing.cast(type, field_hints.get(field.name)) - known_types = {type, Literal, GenericAlias} - if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types: - raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}") + # known_types = {type, Literal, GenericAlias, UnionType} + # if type(field_type) not in known_types and typing.get_origin(field_type) not in known_types: + # raise ValueError(f"Cannot determine type of {field.name}, got {field_type} {type(field_type)}") + # raises on typing.Optional[typing.Literal['epsilon', 'v_prediction']] self._setup_argument( field.name, field_type,