diff --git a/src/configmypy/type_inference.py b/src/configmypy/type_inference.py index ee07ce5..3f0a7e6 100644 --- a/src/configmypy/type_inference.py +++ b/src/configmypy/type_inference.py @@ -62,7 +62,7 @@ def infer_str(var, strict:bool=True): else: return str(var) -def infer_iterable(var, inner_type: Callable=None, strict: bool=True): +def infer_iterable(var, strict: bool=True, inner_type: Callable=None): # Use ast.literal_eval to parse the iterable tree, # then use custom type handling to infer the inner types raw_ast_iter = literal_eval(var) @@ -109,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True): """ if orig_type == type(None): - self.orig_type = infer_str + self.orig_type = str else: self.orig_type = orig_type self.strict = strict + + if self.orig_type == str or self.orig_type == type(None): + self.type_callable = infer_str + elif self.orig_type == bool or self.orig_type == ScalarBoolean: + self.type_callable = infer_boolean + elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt: + self.type_callable = infer_numeric + elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq: + self.type_callable = infer_iterable + else: + self.type_callable = self.orig_type def __call__(self, var): """ Callable method passed to argparse's builtin Callable type argument. var: original variable (any type) + calls the proper type inferencer """ - if self.orig_type == bool or self.orig_type == ScalarBoolean: - return infer_boolean(var, self.strict) - elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt: - return infer_numeric(var, self.strict) - elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq: - return infer_iterable(var, None, self.strict) - else: - if self.strict: - return infer_str(var) - else: - return self.orig_type(var) + return self.type_callable(var, self.strict) + + def __repr__(self): + return f"TypeInferencer[{self.orig_type}]"