Skip to content

Commit

Permalink
Fix argparse config type inferencer (#2)
Browse files Browse the repository at this point in the history
* update with new ruamel types

* fix bug w custom ruamel yaml types

* fix infer iterable signature and clean up TypeInferencer API
  • Loading branch information
dhpitt authored May 22, 2024
1 parent 46394af commit 85f9160
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 18 deletions.
2 changes: 1 addition & 1 deletion src/configmypy/argparse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ def read_conf(self, config=None, **additional_config):
strict=True
elif self.infer_types == 'fuzzy':
strict=False

print(f"creating inferencer for {key} of type {type(value)}")
type_inferencer = TypeInferencer(orig_type=type(value), strict=strict)
else:
type_inferencer = type(value)
Expand Down
7 changes: 6 additions & 1 deletion src/configmypy/tests/test_argparse_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@


def test_ArgparseConfig(monkeypatch):
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24'])
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24',\
'--data.test_resolutions', '[16]'])
config = Bunch(TEST_CONFIG_DICT['default'])
parser = ArgparseConfig(infer_types='fuzzy')
args, kwargs = parser.read_conf(config)
assert args.data.test_resolutions == [16]
config.data.batch_size = 24
config.data.test_resolutions = [16]
assert config == args
# reset config data test_res
config.data.test_resolutions = [16,32]
assert kwargs == {}

# test ability to infer None, int-->float and iterables
Expand Down
18 changes: 15 additions & 3 deletions src/configmypy/tests/test_pipeline_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
data:
dataset: 'ns'
batch_size: 12
test_batch_sizes: [16,32]
other:
int_to_float: 1
float_to_none: 0.5
test:
opt:
Expand All @@ -20,23 +24,31 @@

TEST_CONFIG_DICT = {
'default': {'opt': {'optimizer': 'adam', 'lr': 0.1},
'data': {'dataset': 'ns', 'batch_size': 12}},
'data': {'dataset': 'ns', 'batch_size': 12, 'test_batch_sizes':[16,16]},
'other': {'int_to_float': 1, 'float_to_none': 0.5}},
'test': {'opt': {'optimizer': 'SGD'}}
}


def test_ConfigPipeline(mocker, monkeypatch):
""""Test for ConfigPipeline"""
mocker.patch("builtins.open", mocker.mock_open(read_data=TEST_CONFIG_FILE))
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', '--config_file', 'config.yaml', '--config_name', 'test'])
monkeypatch.setattr("sys.argv", ['test', '--data.batch_size', '24', \
'--data.test_batch_sizes', '[16,16]',\
'--other.int_to_float', '0.05', \
'--other.float_to_none', 'None', \
'--config_file', 'config.yaml', \
'--config_name', 'test'])

true_config = Bunch(TEST_CONFIG_DICT['default'])
true_config.data.batch_size = 24
true_config.opt.optimizer = 'SGD'
true_config.other.int_to_float = 0.05
true_config.other.float_to_none = None

pipe = ConfigPipeline([
YamlConfig('./config.yaml', config_name='default'),
ArgparseConfig(config_file=None, config_name=None),
ArgparseConfig(config_file=None, config_name=None, infer_types='fuzzy'),
YamlConfig()
])
config = pipe.read_conf()
Expand Down
37 changes: 24 additions & 13 deletions src/configmypy/type_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@
from ast import literal_eval
from typing import Callable

# import custom Yaml types to handle sequences
from ruamel.yaml import CommentedSeq, CommentedMap
from ruamel.yaml.scalarfloat import ScalarFloat
from ruamel.yaml.scalarint import ScalarInt
from ruamel.yaml.scalarbool import ScalarBoolean


def infer_boolean(var, strict: bool=True):
"""
Expand Down Expand Up @@ -56,7 +62,7 @@ def infer_str(var, strict:bool=True):
else:
return str(var)

def infer_iterable(var, inner_type: Callable=None, strict: bool=True):
def infer_iterable(var, strict: bool=True, inner_type: Callable=None):
# Use ast.literal_eval to parse the iterable tree,
# then use custom type handling to infer the inner types
raw_ast_iter = literal_eval(var)
Expand Down Expand Up @@ -103,25 +109,30 @@ def __init__(self, orig_type: Callable, strict: bool=True):
"""
if orig_type == type(None):
self.orig_type = infer_str
self.orig_type = str
else:
self.orig_type = orig_type
self.strict = strict

if self.orig_type == str or self.orig_type == type(None):
self.type_callable = infer_str
elif self.orig_type == bool or self.orig_type == ScalarBoolean:
self.type_callable = infer_boolean
elif self.orig_type == float or self.orig_type == int or self.orig_type == ScalarFloat or self.orig_type == ScalarInt:
self.type_callable = infer_numeric
elif self.orig_type == tuple or self.orig_type == list or self.orig_type == CommentedMap or self.orig_type == CommentedSeq:
self.type_callable = infer_iterable
else:
self.type_callable = self.orig_type

def __call__(self, var):
"""
Callable method passed to argparse's builtin Callable type argument.
var: original variable (any type)
calls the proper type inferencer
"""
if self.orig_type == bool:
return infer_boolean(var, self.strict)
elif self.orig_type == float or self.orig_type == int:
return infer_numeric(var, self.strict)
elif self.orig_type == tuple or self.orig_type == list:
return infer_iterable(var, None, self.strict)
else:
if self.strict:
return infer_str(var)
else:
return self.orig_type(var)
return self.type_callable(var, self.strict)

def __repr__(self):
return f"TypeInferencer[{self.orig_type}]"

0 comments on commit 85f9160

Please sign in to comment.