Skip to content

Commit

Permalink
Adds type inference to argparse (#1)
Browse files Browse the repository at this point in the history
* infer boolean or numeric

* add typeinferencer class

* type inferencer

* default to infer string

* fuzzy read_conf

* remove fuzzy tag

* add support for iterable types for depth 1

* fix default type to str

* add basic iterable inference w/support for nested iterables, not strict

* add tests for inferring iterables & string-to-bool

* add infer types choices

* update argparse config signature

* set argparse config to default to fuzzy

* add int --> float test
  • Loading branch information
dhpitt authored Apr 27, 2024
1 parent a7baa4a commit 836382b
Show file tree
Hide file tree
Showing 3 changed files with 178 additions and 17 deletions.
39 changes: 30 additions & 9 deletions src/configmypy/argparse_config.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import argparse
from copy import deepcopy
from ctypes import ArgumentError
from .bunch import Bunch
from .utils import iter_nested_dict_flat
from .utils import update_nested_dict_from_flat

from .type_inference import TypeInferencer

class ArgparseConfig:
"""Read config from the command-line using argparse
Expand All @@ -12,8 +13,10 @@ class ArgparseConfig:
Parameters
----------
infer_types : bool, default is True
if True, use type(value) to indicate expected type
infer_types : False or literal['fuzzy', 'strict']
if False, use python's typecasting from type(value) to indicate expected type
if fuzzy, use custom type handling and allow all types to take values of None
if strict, use custom type handling but do not allow passing bool or None to other types.
overwrite_nested_config : bool, default is True
if True, users can set values at different levels of nesting
Expand All @@ -29,9 +32,12 @@ class ArgparseConfig:
**additional_config : dict
key, values to read from command-line and pass on to the next config
"""
def __init__(self, infer_types=True, overwrite_nested_config=False, **additional_config):
def __init__(self, infer_types="fuzzy", overwrite_nested_config=False, **additional_config):
self.additional_config = Bunch(additional_config)
self.infer_types = infer_types
if infer_types in [False, "fuzzy", "strict"]:
self.infer_types = infer_types
else:
raise ArgumentError("Error: infer_types only takes values False, \'fuzzy\', \'strict\'")
self.overwrite_nested_config = overwrite_nested_config

def read_conf(self, config=None, **additional_config):
Expand All @@ -44,6 +50,12 @@ def read_conf(self, config=None, **additional_config):
----------
config : dict, default is None
if not None, a dict config to update
infer_types: bool, default is False
if True, uses custom type handling
where all values, regardless of type,
may take the value None.
If false, uses default argparse
typecasting
additional_config : dict
Returns
Expand All @@ -61,10 +73,19 @@ def read_conf(self, config=None, **additional_config):

parser = argparse.ArgumentParser(description='Read the config from the commandline.')
for key, value in iter_nested_dict_flat(config, return_intermediate_keys=self.overwrite_nested_config):
if self.infer_types and value is not None:
parser.add_argument(f'--{key}', type=type(value), default=value)
else:
parser.add_argument(f'--{key}', default=value)
# smartly infer types if infer_types is turned on
# otherwise force default typecasting
if self.infer_types:
if self.infer_types == 'strict':
strict=True
elif self.infer_types == 'fuzzy':
strict=False

type_inferencer = TypeInferencer(orig_type=type(value), strict=strict)
else:
type_inferencer = type(value)

parser.add_argument(f'--{key}', type=type_inferencer, default=value)

args = parser.parse_args()
self.config = Bunch(args.__dict__)
Expand Down
29 changes: 21 additions & 8 deletions src/configmypy/tests/test_argparse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,33 +3,46 @@


TEST_CONFIG_DICT = {
'default': {'opt': {'optimizer': 'adam', 'lr': 0.1},
'data': {'dataset': 'ns', 'batch_size': 12}},
'test': {'opt': {'optimizer': 'SGD'}}
'default': {'opt': {'optimizer': 'adam', 'lr': 1, 'regularizer': True},
'data': {'dataset': 'ns', 'batch_size': 12, 'test_resolutions':[16,32]}},
'test': {'opt': {'optimizer': 'SGD'}}
}


def test_ArgparseConfig(monkeypatch):
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24'])
config = Bunch(TEST_CONFIG_DICT['default'])
parser = ArgparseConfig(infer_types=True)
parser = ArgparseConfig(infer_types='fuzzy')
args, kwargs = parser.read_conf(config)
config.data.batch_size = 24
assert config == args
assert kwargs == {}

monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_name', 'test'])

# test ability to infer None, int-->float and iterables
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\
'--data.test_resolutions', '[8, None]',\
'--opt.lr', '0.01',\
'--opt.regularizer', 'False',\
'--config_name', 'test'])
config = Bunch(TEST_CONFIG_DICT['default'])
parser = ArgparseConfig(infer_types=True, config_name=None)
parser = ArgparseConfig(infer_types='fuzzy', config_name=None)
args, kwargs = parser.read_conf(config)
config.data.batch_size = 24
assert args.opt.lr == 0.01 # ensure no casting to int
assert args.data.test_resolutions == [8, None]
assert not args.opt.regularizer #boolean False, not 'False'


args.data.test_resolutions = [16,32]
args.opt.regularizer = True
args.opt.lr = 1
assert config == args
assert kwargs == Bunch(dict(config_name='test'))

# Test overwriting entire nested argument
monkeypatch.setattr("sys.argv", ['test', '--data', '0', '--config_name', 'test'])
config = Bunch(TEST_CONFIG_DICT['default'])
parser = ArgparseConfig(infer_types=True, config_name=None, overwrite_nested_config=True)
parser = ArgparseConfig(infer_types='fuzzy', config_name=None, overwrite_nested_config=True)
args, kwargs = parser.read_conf(config)
config['data'] = 0
assert config == args
Expand Down
127 changes: 127 additions & 0 deletions src/configmypy/type_inference.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
# Infer types of argparse arguments intelligently

from argparse import ArgumentTypeError
from ast import literal_eval
from typing import Callable


def infer_boolean(var, strict: bool=True):
"""
accept argparse inputs of string, correctly
convert into bools. Without this behavior,
bool('False') becomes True by default.
Parameters
----------
var: input from parser.args
"""
if var.lower() == 'true':
return True
elif var.lower() == 'false':
return False
elif var.lower() == 'None':
return None
elif strict:
raise ArgumentTypeError()
else:
return str(var)

def infer_numeric(var, strict: bool=True):
# int if possible -> float -> NoneType -> Err
if var.isnumeric():
return int(var)
elif '.' in list(var):
decimal_left_right = var.split('.')
if len(decimal_left_right) == 2:
if sum([x.isnumeric() for x in decimal_left_right]) == 2: # True, True otherwise False
return float(var)
elif var.lower() == 'none':
return None
elif strict:
raise ArgumentTypeError()
else:
return str(var)

def infer_str(var, strict:bool=True):
"""
infer string values and handle fuzzy None and bool types (optional)
var: str input
strict: whether to allow "fuzzy" None types
"""
if strict:
return str(var)
elif var.lower() == 'None':
return None
else:
return str(var)

def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
# Use ast.literal_eval to parse the iterable tree,
# then use custom type handling to infer the inner types
raw_ast_iter = literal_eval(var)
if inner_type is not None:
return iterable_helper(raw_ast_iter, inner_type, strict)
else:
# currently argparse config cannot support inferring
# more granular types than list or tuple
return raw_ast_iter


def iterable_helper(var, inner_type: Callable, strict: bool=True):
"""
recursively loop through iterable and apply custom type
callables to each inner variable to conform to strictness
"""
if isinstance(var, list):
return [iterable_helper(x,inner_type, strict) for x in var]
elif isinstance(var, tuple):
return tuple([iterable_helper(x,inner_type, strict) for x in var])
else:
return inner_type(str(var),strict)


class TypeInferencer(object):
def __init__(self, orig_type: Callable, strict: bool=True):
"""
TypeInferencer mediates between argparse
and ArgparseConfig
orig_type: Callable type
type of original var from config
cannot be NoneType - defaults to infer_str
strict: bool, default True
whether to use type inferencing. If False,
default to simply applying default type converter.
This may cause issues for some types:
ex. if argparseConfig takes a boolean for arg 'x',
passing --x False into the command line will return a value
of True, since argparse works with strings and bool('False') = True.
strict: bool, default True
if True, raise ArgumentTypeError when the value cannot be cast to
the allowed types. Otherwise default to str.
"""
if orig_type == type(None):
self.orig_type = infer_str
else:
self.orig_type = orig_type
self.strict = strict

def __call__(self, var):
"""
Callable method passed to argparse's builtin Callable type argument.
var: original variable (any type)
"""
if self.orig_type == bool:
return infer_boolean(var, self.strict)
elif self.orig_type == float or self.orig_type == int:
return infer_numeric(var, self.strict)
elif self.orig_type == tuple or self.orig_type == list:
return infer_iterable(var, None, self.strict)
else:
if self.strict:
return infer_str(var)
else:
return self.orig_type(var)

0 comments on commit 836382b

Please sign in to comment.