From 836382b597e89f5895ee02e7829789b522969c3a Mon Sep 17 00:00:00 2001 From: David Pitt Date: Sat, 27 Apr 2024 10:08:48 -0700 Subject: [PATCH] Adds type inference to argparse (#1) * infer boolean or numeric * add typeinferencer class * type inferencer * default to infer string * fuzzy read_conf * remove fuzzy tag * add support for iterable types for depth 1 * fix default type to str * add basic iterable inference w/support for nested iterables, not strict * add tests for inferring iterables & string-to-bool * add infer types choices * update argparse config signature * set argparse config to default to fuzzy * add int --> float test --- src/configmypy/argparse_config.py | 39 ++++-- src/configmypy/tests/test_argparse_config.py | 29 +++-- src/configmypy/type_inference.py | 127 +++++++++++++++++++ 3 files changed, 178 insertions(+), 17 deletions(-) create mode 100644 src/configmypy/type_inference.py diff --git a/src/configmypy/argparse_config.py b/src/configmypy/argparse_config.py index ff62812..bba683c 100644 --- a/src/configmypy/argparse_config.py +++ b/src/configmypy/argparse_config.py @@ -1,9 +1,10 @@ import argparse from copy import deepcopy +from ctypes import ArgumentError from .bunch import Bunch from .utils import iter_nested_dict_flat from .utils import update_nested_dict_from_flat - +from .type_inference import TypeInferencer class ArgparseConfig: """Read config from the command-line using argparse @@ -12,8 +13,10 @@ class ArgparseConfig: Parameters ---------- - infer_types : bool, default is True - if True, use type(value) to indicate expected type + infer_types : False or literal['fuzzy', 'strict'] + if False, use python's typecasting from type(value) to indicate expected type + if fuzzy, use custom type handling and allow all types to take values of None + if strict, use custom type handling but do not allow passing bool or None to other types. overwrite_nested_config : bool, default is True if True, users can set values at different levels of nesting @@ -29,9 +32,12 @@ class ArgparseConfig: **additional_config : dict key, values to read from command-line and pass on to the next config """ - def __init__(self, infer_types=True, overwrite_nested_config=False, **additional_config): + def __init__(self, infer_types="fuzzy", overwrite_nested_config=False, **additional_config): self.additional_config = Bunch(additional_config) - self.infer_types = infer_types + if infer_types in [False, "fuzzy", "strict"]: + self.infer_types = infer_types + else: + raise ArgumentError("Error: infer_types only takes values False, \'fuzzy\', \'strict\'") self.overwrite_nested_config = overwrite_nested_config def read_conf(self, config=None, **additional_config): @@ -44,6 +50,12 @@ def read_conf(self, config=None, **additional_config): ---------- config : dict, default is None if not None, a dict config to update + infer_types: bool, default is False + if True, uses custom type handling + where all values, regardless of type, + may take the value None. + If false, uses default argparse + typecasting additional_config : dict Returns @@ -61,10 +73,19 @@ def read_conf(self, config=None, **additional_config): parser = argparse.ArgumentParser(description='Read the config from the commandline.') for key, value in iter_nested_dict_flat(config, return_intermediate_keys=self.overwrite_nested_config): - if self.infer_types and value is not None: - parser.add_argument(f'--{key}', type=type(value), default=value) - else: - parser.add_argument(f'--{key}', default=value) + # smartly infer types if infer_types is turned on + # otherwise force default typecasting + if self.infer_types: + if self.infer_types == 'strict': + strict=True + elif self.infer_types == 'fuzzy': + strict=False + + type_inferencer = TypeInferencer(orig_type=type(value), strict=strict) + else: + type_inferencer = type(value) + + parser.add_argument(f'--{key}', type=type_inferencer, default=value) args = parser.parse_args() self.config = Bunch(args.__dict__) diff --git a/src/configmypy/tests/test_argparse_config.py b/src/configmypy/tests/test_argparse_config.py index ce97d51..99a5560 100644 --- a/src/configmypy/tests/test_argparse_config.py +++ b/src/configmypy/tests/test_argparse_config.py @@ -3,33 +3,46 @@ TEST_CONFIG_DICT = { - 'default': {'opt': {'optimizer': 'adam', 'lr': 0.1}, - 'data': {'dataset': 'ns', 'batch_size': 12}}, - 'test': {'opt': {'optimizer': 'SGD'}} + 'default': {'opt': {'optimizer': 'adam', 'lr': 1, 'regularizer': True}, + 'data': {'dataset': 'ns', 'batch_size': 12, 'test_resolutions':[16,32]}}, + 'test': {'opt': {'optimizer': 'SGD'}} } def test_ArgparseConfig(monkeypatch): monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24']) config = Bunch(TEST_CONFIG_DICT['default']) - parser = ArgparseConfig(infer_types=True) + parser = ArgparseConfig(infer_types='fuzzy') args, kwargs = parser.read_conf(config) config.data.batch_size = 24 assert config == args assert kwargs == {} - - monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_name', 'test']) + + # test ability to infer None, int-->float and iterables + monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\ + '--data.test_resolutions', '[8, None]',\ + '--opt.lr', '0.01',\ + '--opt.regularizer', 'False',\ + '--config_name', 'test']) config = Bunch(TEST_CONFIG_DICT['default']) - parser = ArgparseConfig(infer_types=True, config_name=None) + parser = ArgparseConfig(infer_types='fuzzy', config_name=None) args, kwargs = parser.read_conf(config) config.data.batch_size = 24 + assert args.opt.lr == 0.01 # ensure no casting to int + assert args.data.test_resolutions == [8, None] + assert not args.opt.regularizer #boolean False, not 'False' + + + args.data.test_resolutions = [16,32] + args.opt.regularizer = True + args.opt.lr = 1 assert config == args assert kwargs == Bunch(dict(config_name='test')) # Test overwriting entire nested argument monkeypatch.setattr("sys.argv", ['test', '--data', '0', '--config_name', 'test']) config = Bunch(TEST_CONFIG_DICT['default']) - parser = ArgparseConfig(infer_types=True, config_name=None, overwrite_nested_config=True) + parser = ArgparseConfig(infer_types='fuzzy', config_name=None, overwrite_nested_config=True) args, kwargs = parser.read_conf(config) config['data'] = 0 assert config == args diff --git a/src/configmypy/type_inference.py b/src/configmypy/type_inference.py new file mode 100644 index 0000000..7d6c5ed --- /dev/null +++ b/src/configmypy/type_inference.py @@ -0,0 +1,127 @@ +# Infer types of argparse arguments intelligently + +from argparse import ArgumentTypeError +from ast import literal_eval +from typing import Callable + + +def infer_boolean(var, strict: bool=True): + """ + accept argparse inputs of string, correctly + convert into bools. Without this behavior, + bool('False') becomes True by default. + + Parameters + ---------- + var: input from parser.args + """ + if var.lower() == 'true': + return True + elif var.lower() == 'false': + return False + elif var.lower() == 'None': + return None + elif strict: + raise ArgumentTypeError() + else: + return str(var) + +def infer_numeric(var, strict: bool=True): + # int if possible -> float -> NoneType -> Err + if var.isnumeric(): + return int(var) + elif '.' in list(var): + decimal_left_right = var.split('.') + if len(decimal_left_right) == 2: + if sum([x.isnumeric() for x in decimal_left_right]) == 2: # True, True otherwise False + return float(var) + elif var.lower() == 'none': + return None + elif strict: + raise ArgumentTypeError() + else: + return str(var) + +def infer_str(var, strict:bool=True): + """ + infer string values and handle fuzzy None and bool types (optional) + + var: str input + strict: whether to allow "fuzzy" None types + """ + if strict: + return str(var) + elif var.lower() == 'None': + return None + else: + return str(var) + +def infer_iterable(var, inner_type: Callable=None, strict: bool=True): + # Use ast.literal_eval to parse the iterable tree, + # then use custom type handling to infer the inner types + raw_ast_iter = literal_eval(var) + if inner_type is not None: + return iterable_helper(raw_ast_iter, inner_type, strict) + else: + # currently argparse config cannot support inferring + # more granular types than list or tuple + return raw_ast_iter + + +def iterable_helper(var, inner_type: Callable, strict: bool=True): + """ + recursively loop through iterable and apply custom type + callables to each inner variable to conform to strictness + """ + if isinstance(var, list): + return [iterable_helper(x,inner_type, strict) for x in var] + elif isinstance(var, tuple): + return tuple([iterable_helper(x,inner_type, strict) for x in var]) + else: + return inner_type(str(var),strict) + + +class TypeInferencer(object): + def __init__(self, orig_type: Callable, strict: bool=True): + """ + TypeInferencer mediates between argparse + and ArgparseConfig + + orig_type: Callable type + type of original var from config + cannot be NoneType - defaults to infer_str + strict: bool, default True + whether to use type inferencing. If False, + default to simply applying default type converter. + This may cause issues for some types: + ex. if argparseConfig takes a boolean for arg 'x', + passing --x False into the command line will return a value + of True, since argparse works with strings and bool('False') = True. + strict: bool, default True + if True, raise ArgumentTypeError when the value cannot be cast to + the allowed types. Otherwise default to str. + + """ + if orig_type == type(None): + self.orig_type = infer_str + else: + self.orig_type = orig_type + self.strict = strict + + def __call__(self, var): + """ + Callable method passed to argparse's builtin Callable type argument. + var: original variable (any type) + + """ + if self.orig_type == bool: + return infer_boolean(var, self.strict) + elif self.orig_type == float or self.orig_type == int: + return infer_numeric(var, self.strict) + elif self.orig_type == tuple or self.orig_type == list: + return infer_iterable(var, None, self.strict) + else: + if self.strict: + return infer_str(var) + else: + return self.orig_type(var)