diff --git a/src/configmypy/argparse_config.py b/src/configmypy/argparse_config.py index a740ccf..a8fe6bc 100644 --- a/src/configmypy/argparse_config.py +++ b/src/configmypy/argparse_config.py @@ -83,7 +83,7 @@ def read_conf(self, config=None, **additional_config): strict=True elif self.infer_types == 'fuzzy': strict=False - + print(f"creating inferencer for {key} of type {type(value)}") type_inferencer = TypeInferencer(orig_type=type(value), strict=strict) else: type_inferencer = type(value) diff --git a/src/configmypy/tests/test_argparse_config.py b/src/configmypy/tests/test_argparse_config.py index 99a5560..ff76097 100644 --- a/src/configmypy/tests/test_argparse_config.py +++ b/src/configmypy/tests/test_argparse_config.py @@ -10,12 +10,17 @@ def test_ArgparseConfig(monkeypatch): - monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24']) + monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\ + '--data.test_resolutions', '[16]']) config = Bunch(TEST_CONFIG_DICT['default']) parser = ArgparseConfig(infer_types='fuzzy') args, kwargs = parser.read_conf(config) + assert args.data.test_resolutions == [16] config.data.batch_size = 24 + config.data.test_resolutions = [16] assert config == args + # reset config data test_res + config.data.test_resolutions = [16,32] assert kwargs == {} # test ability to infer None, int-->float and iterables diff --git a/src/configmypy/tests/test_pipeline_config.py b/src/configmypy/tests/test_pipeline_config.py index a5c262d..f9e450f 100644 --- a/src/configmypy/tests/test_pipeline_config.py +++ b/src/configmypy/tests/test_pipeline_config.py @@ -12,6 +12,10 @@ data: dataset: 'ns' batch_size: 12 + test_batch_sizes: [16,32] + other: + int_to_float: 1 + float_to_none: 0.5 test: opt: @@ -20,7 +24,8 @@ TEST_CONFIG_DICT = { 'default': {'opt': {'optimizer': 'adam', 'lr': 0.1}, - 'data': {'dataset': 'ns', 'batch_size': 12}}, + 'data': {'dataset': 'ns', 'batch_size': 12, 'test_batch_sizes':[16,16]}, + 'other': {'int_to_float': 1, 'float_to_none': 0.5}}, 'test': {'opt': {'optimizer': 'SGD'}} } @@ -28,15 +33,22 @@ def test_ConfigPipeline(mocker, monkeypatch): """"Test for ConfigPipeline""" mocker.patch("builtins.open", mocker.mock_open(read_data=TEST_CONFIG_FILE)) - monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_file', 'config.yaml', '--config_name', 'test']) + monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', \ + '--data.test_batch_sizes', '[16,16]',\ + '--other.int_to_float', '0.05', \ + '--other.float_to_none', 'None', \ + '--config_file', 'config.yaml', \ + '--config_name', 'test']) true_config = Bunch(TEST_CONFIG_DICT['default']) true_config.data.batch_size = 24 true_config.opt.optimizer = 'SGD' + true_config.other.int_to_float = 0.05 + true_config.other.float_to_none = None pipe = ConfigPipeline([ YamlConfig('./config.yaml', config_name='default'), - ArgparseConfig(config_file=None, config_name=None), + ArgparseConfig(config_file=None, config_name=None, infer_types='fuzzy'), YamlConfig() ]) config = pipe.read_conf() diff --git a/src/configmypy/type_inference.py b/src/configmypy/type_inference.py index 7d6c5ed..3f0a7e6 100644 --- a/src/configmypy/type_inference.py +++ b/src/configmypy/type_inference.py @@ -4,6 +4,12 @@ from ast import literal_eval from typing import Callable +# import custom Yaml types to handle sequences +from ruamel.yaml import CommentedSeq, CommentedMap +from ruamel.yaml.scalarfloat import ScalarFloat +from ruamel.yaml.scalarint import ScalarInt +from ruamel.yaml.scalarbool import ScalarBoolean + def infer_boolean(var, strict: bool=True): """ @@ -56,7 +62,7 @@ def infer_str(var, strict:bool=True): else: return str(var) -def infer_iterable(var, inner_type: Callable=None, strict: bool=True): +def infer_iterable(var, strict: bool=True, inner_type: Callable=None): # Use ast.literal_eval to parse the iterable tree, # then use custom type handling to infer the inner types raw_ast_iter = literal_eval(var) @@ -103,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True): """ if orig_type == type(None): - self.orig_type = infer_str + self.orig_type = str else: self.orig_type = orig_type self.strict = strict + + if self.orig_type == str or self.orig_type == type(None): + self.type_callable = infer_str + elif self.orig_type == bool or self.orig_type == ScalarBoolean: + self.type_callable = infer_boolean + elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt: + self.type_callable = infer_numeric + elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq: + self.type_callable = infer_iterable + else: + self.type_callable = self.orig_type def __call__(self, var): """ Callable method passed to argparse's builtin Callable type argument. var: original variable (any type) + calls the proper type inferencer """ - if self.orig_type == bool: - return infer_boolean(var, self.strict) - elif self.orig_type == float or self.orig_type == int: - return infer_numeric(var, self.strict) - elif self.orig_type == tuple or self.orig_type == list: - return infer_iterable(var, None, self.strict) - else: - if self.strict: - return infer_str(var) - else: - return self.orig_type(var) + return self.type_callable(var, self.strict) + + def __repr__(self): + return f"TypeInferencer[{self.orig_type}]"