Skip to content

Commit

Permalink
fix infer iterable signature and clean up TypeInferencer API
Browse files Browse the repository at this point in the history
  • Loading branch information
dhpitt committed May 2, 2024
1 parent c4e6f5a commit 2642efc
Showing 1 changed file with 18 additions and 13 deletions.
31 changes: 18 additions & 13 deletions src/configmypy/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ def infer_str(var, strict:bool=True):
else:
return str(var)

def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
def infer_iterable(var, strict: bool=True, inner_type: Callable=None):
# Use ast.literal_eval to parse the iterable tree,
# then use custom type handling to infer the inner types
raw_ast_iter = literal_eval(var)
Expand Down Expand Up @@ -109,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
"""
if orig_type == type(None):
self.orig_type = infer_str
self.orig_type = str
else:
self.orig_type = orig_type
self.strict = strict

if self.orig_type == str or self.orig_type == type(None):
self.type_callable = infer_str
elif self.orig_type == bool or self.orig_type == ScalarBoolean:
self.type_callable = infer_boolean
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
self.type_callable = infer_numeric
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
self.type_callable = infer_iterable
else:
self.type_callable = self.orig_type

def __call__(self, var):
"""
Callable method passed to argparse's builtin Callable type argument.
var: original variable (any type)
calls the proper type inferencer
"""
if self.orig_type == bool or self.orig_type == ScalarBoolean:
return infer_boolean(var, self.strict)
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
return infer_numeric(var, self.strict)
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
return infer_iterable(var, None, self.strict)
else:
if self.strict:
return infer_str(var)
else:
return self.orig_type(var)
return self.type_callable(var, self.strict)

def __repr__(self):
return f"TypeInferencer[{self.orig_type}]"

0 comments on commit 2642efc

Please sign in to comment.