diff --git a/torch/_inductor/fx_passes/pre_grad.py b/torch/_inductor/fx_passes/pre_grad.py index 870a61c2dcb8f..de7c57176733b 100644 --- a/torch/_inductor/fx_passes/pre_grad.py +++ b/torch/_inductor/fx_passes/pre_grad.py @@ -49,7 +49,7 @@ def lazy_init(): from . import efficient_conv_bn_eval, split_cat # noqa: F401 # noqa: F401 if config.is_fbcode(): - from .fb import split_cat as split_cat_fb # type: ignore[import] # noqa: F401 + from . import fb # type: ignore[import] # noqa: F401 def pre_grad_passes(gm: torch.fx.GraphModule, example_inputs):