Add error inputs to ModuleInfo (mirroring OpInfo) (pytorch#106325)
Add infrastructure for error inputs to ModuleInfos and migrate the first few error-input tests from test_nn.py (more to come!)

Pull Request resolved: pytorch#106325
Approved by: https://github.com/albanD
mikaylagawarecki authored and pytorchmergebot committed Aug 1, 2023
1 parent 16df542 commit c9be60c
Showing 3 changed files with 94 additions and 20 deletions.
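The new machinery mirrors OpInfo error inputs: each ErrorModuleInput records whether the failure happens at construction or at forward time, the expected exception type, and a regex for its message, and the new TestModule.test_errors (first diff below) asserts exactly that for every entry. As a self-contained illustration of the same checking pattern, here is a minimal sketch; ToyModule and its error messages are placeholders for this explanation, not part of the commit:

import re
import torch
import torch.nn as nn

# ToyModule is a placeholder for illustration only; real entries live in
# torch/testing/_internal/common_modules.py.
class ToyModule(nn.Module):
    def __init__(self, features):
        super().__init__()
        if features <= 0:
            raise ValueError("features must be positive")
        self.features = features

    def forward(self, x):
        if x.shape[-1] != self.features:
            raise RuntimeError(f"expected size {self.features}, got {x.shape[-1]}")
        return x

# Each tuple plays the role of an ErrorModuleInput:
# (when the error is raised, constructor args, forward args, expected type, message regex).
error_cases = [
    ("construction", (-1,), None, ValueError, r"must be positive"),
    ("forward", (4,), (torch.randn(2, 3),), RuntimeError, r"expected size 4"),
]

for when, ctor_args, fwd_args, err_type, regex in error_cases:
    try:
        m = ToyModule(*ctor_args)
        if when == "forward":
            m(*fwd_args)
    except err_type as e:
        assert re.search(regex, str(e)), f"unexpected message: {e}"
    else:
        raise AssertionError("expected an error but none was raised")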
22 changes: 21 additions & 1 deletion test/test_modules.py
@@ -12,7 +12,7 @@
from torch.testing._internal.common_cuda import with_tf32_off
from torch.testing._internal.common_device_type import (
instantiate_device_type_tests, onlyCPU, onlyCUDA, toleranceOverride, tol, skipMeta)
from torch.testing._internal.common_modules import module_db, modules, TrainEvalMode
from torch.testing._internal.common_modules import module_db, modules, ModuleErrorEnum, TrainEvalMode
from torch.testing._internal.common_utils import (
TestCase, run_tests, freeze_rng_state, mock_wrapper, get_tensors_from, gradcheck,
gradgradcheck)
@@ -766,6 +766,26 @@ def test_device_ctx_init(self, device, dtype, module_info, training):
assert_metadata_eq(self.assertEqual, p_meta, p_cpu)


@modules([module for module in module_db if module.module_error_inputs_func is not None])
def test_errors(self, device, dtype, module_info, training):
    module_cls = module_info.module_cls
    error_inputs = module_info.module_error_inputs_func(module_info, device=device, dtype=dtype,
                                                        requires_grad=False, training=training)
    for error_input in error_inputs:
        module_input = error_input.module_error_input
        c_args, c_kwargs = module_input.constructor_input.args, module_input.constructor_input.kwargs
        if error_input.error_on == ModuleErrorEnum.CONSTRUCTION_ERROR:
            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                m = module_cls(*c_args, **c_kwargs)
        elif error_input.error_on == ModuleErrorEnum.FORWARD_ERROR:
            m = module_cls(*c_args, **c_kwargs)
            fw_args, fw_kwargs = module_input.forward_input.args, module_input.forward_input.kwargs
            with self.assertRaisesRegex(error_input.error_type, error_input.error_regex):
                m(*fw_args, **fw_kwargs)
        else:
            raise NotImplementedError(f"Unknown error type {error_input.error_on}")


instantiate_device_type_tests(TestModule, globals(), allow_mps=True)

if __name__ == '__main__':
16 changes: 0 additions & 16 deletions test/test_nn.py
@@ -2876,22 +2876,6 @@ def test_RNN_cell(self):

hx.sum().backward()

def test_RNN_cell_forward_input_size(self):
    input = torch.randn(3, 11)
    hx = torch.randn(3, 20)
    for module in (nn.RNNCell, nn.GRUCell):
        cell = module(10, 20)
        self.assertRaises(Exception, lambda: cell(input, hx))

def test_RNN_cell_forward_hidden_size(self):
    input = torch.randn(3, 10)
    hx = torch.randn(3, 21)
    cell_shared_param = (10, 20)
    for cell in (nn.RNNCell(*cell_shared_param, nonlinearity="relu"),
                 nn.RNNCell(*cell_shared_param, nonlinearity="tanh"),
                 nn.GRUCell(*cell_shared_param)):
        self.assertRaises(Exception, lambda: cell(input, hx))

def test_RNN_cell_forward_zero_hidden_size(self):
input = torch.randn(3, 10)
hx = torch.randn(3, 0)
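The removed checks are now expressed as ErrorModuleInputs for RNNCell/GRUCell (see the common_modules.py diff below). Run standalone, the failure they exercised looks roughly like this sketch; the expected message is taken from the new error_regex entries:

import torch
import torch.nn as nn

cell = nn.RNNCell(10, 20)       # input_size=10, hidden_size=20
bad_input = torch.randn(3, 11)  # last dim 11 does not match input_size=10
hx = torch.randn(3, 20)
try:
    cell(bad_input, hx)
except RuntimeError as e:
    print(e)  # expected to match "input has inconsistent input_size: got 11 expected 10"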
76 changes: 73 additions & 3 deletions torch/testing/_internal/common_modules.py
@@ -166,6 +166,30 @@ def copy_reference_fn(m, *args, **kwargs):

self.reference_fn = copy_reference_fn

class ModuleErrorEnum(Enum):
    """Enumerates when an error is raised while testing modules."""
    CONSTRUCTION_ERROR = 0
    FORWARD_ERROR = 1

class ErrorModuleInput:
    """
    A ModuleInput that will cause the module to throw an error, plus information
    about the resulting error.
    """

    __slots__ = ["module_error_input", "error_on", "error_type", "error_regex"]

    def __init__(self,
                 module_error_input,
                 *,
                 error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
                 error_type=RuntimeError,
                 error_regex):
        self.module_error_input = module_error_input
        self.error_on = error_on
        self.error_type = error_type
        self.error_regex = error_regex


class ModuleInfo:
""" Module information to be used in testing. """
@@ -182,6 +206,7 @@ def __init__(self,
module_memformat_affects_out=False, # whether converting module to channels last will generate
# channels last output
train_and_eval_differ=False, # whether the module has differing behavior between train and eval
module_error_inputs_func=None, # Function to generate module inputs that error
):
self.module_cls = module_cls
self.module_inputs_func = module_inputs_func
@@ -191,6 +216,7 @@ def __init__(self,
self.gradcheck_nondet_tol = gradcheck_nondet_tol
self.module_memformat_affects_out = module_memformat_affects_out
self.train_and_eval_differ = train_and_eval_differ
self.module_error_inputs_func = module_error_inputs_func

def get_decorators(self, test_class, test_name, device, dtype, param_kwargs):
    result = [set_single_threaded_if_parallel_tbb]
@@ -210,6 +236,7 @@ def name(self):
def formatted_name(self):
    return self.name.replace('.', '_')

# Start of module inputs functions.

def module_inputs_torch_nn_Linear(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
@@ -2206,9 +2233,6 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
),
]




# All these operators share similar issues on cuDNN and MIOpen
rnn_gru_lstm_module_info_decorators = (
# RuntimeError: Batching rule not implemented for aten::_cudnn_rnn_backward.
@@ -2243,6 +2267,50 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
)
)

# Start of module error inputs functions.

def module_error_inputs_torch_nn_RNN_GRU_Cell(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    samples = [
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, 20),
                forward_input=FunctionInput(make_input(3, 11), make_input(3, 20)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="input has inconsistent input_size: got 11 expected 10"
        ),
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, 20),
                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
        ),
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, 20, 'relu'),
                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
        ),
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(10, 20, 'tanh'),
                forward_input=FunctionInput(make_input(3, 10), make_input(3, 21)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="hidden0 has inconsistent hidden_size: got 21, expected 20"
        ),
    ]
    return samples

# Database of ModuleInfo entries in alphabetical order.
module_db: List[ModuleInfo] = [
ModuleInfo(torch.nn.AdaptiveAvgPool1d,
@@ -2912,11 +2980,13 @@ def module_inputs_torch_nn_ConstantPad3d(module_info, device, dtype, requires_gr
),
    ModuleInfo(torch.nn.RNNCell,
               module_inputs_func=partial(module_inputs_torch_nn_RNN_GRU_Cell, is_rnn=True),
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
               skips=(
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
               ),
    ModuleInfo(torch.nn.GRUCell,
               module_inputs_func=module_inputs_torch_nn_RNN_GRU_Cell,
               module_error_inputs_func=module_error_inputs_torch_nn_RNN_GRU_Cell,
               skips=(
                   DecorateInfo(skipIfMps, 'TestModule', dtypes=[torch.float64]),)
               ),
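Opting another module into error-input testing follows the same shape as the RNNCell/GRUCell entries above: write a module_error_inputs_* function and pass it as module_error_inputs_func. A hypothetical sketch, written as if it lived in common_modules.py (MyModule, the ValueError, and the messages are placeholders, not real entries):

# Hypothetical example only; "MyModule" and its error behavior are placeholders.
def module_error_inputs_torch_nn_MyModule(module_info, device, dtype, requires_grad, training, **kwargs):
    make_input = partial(make_tensor, device=device, dtype=dtype, requires_grad=requires_grad)
    return [
        # Error expected while constructing the module.
        ErrorModuleInput(
            ModuleInput(constructor_input=FunctionInput(-1)),
            error_on=ModuleErrorEnum.CONSTRUCTION_ERROR,
            error_type=ValueError,
            error_regex="must be positive",
        ),
        # Error expected when calling forward.
        ErrorModuleInput(
            ModuleInput(
                constructor_input=FunctionInput(4),
                forward_input=FunctionInput(make_input(2, 3)),
            ),
            error_on=ModuleErrorEnum.FORWARD_ERROR,
            error_type=RuntimeError,
            error_regex="shape mismatch",
        ),
    ]

# The matching module_db entry would then pass the function along, e.g.:
#     ModuleInfo(torch.nn.MyModule,
#                module_inputs_func=module_inputs_torch_nn_MyModule,
#                module_error_inputs_func=module_error_inputs_torch_nn_MyModule),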
