diff --git a/functorch/experimental/control_flow.py b/functorch/experimental/control_flow.py
index cb6ff2e472490..e24fc61428200 100644
--- a/functorch/experimental/control_flow.py
+++ b/functorch/experimental/control_flow.py
@@ -1,4 +1,8 @@
 from torch import cond  # noqa: F401
 from torch._higher_order_ops.cond import UnsupportedAliasMutationException  # noqa: F401
 
-from ._map import map  # noqa: F401
+from torch._higher_order_ops.map import (  # noqa: F401
+    _stack_pytree,
+    _unstack_pytree,
+    map,
+)
diff --git a/test/functorch/test_control_flow.py b/test/functorch/test_control_flow.py
index 193acee2848ed..e472c7fa59aff 100644
--- a/test/functorch/test_control_flow.py
+++ b/test/functorch/test_control_flow.py
@@ -44,7 +44,7 @@ def from_fun_old(t):
     return t
 
 def _fake_map(f, x, *args):
-    from functorch.experimental._map import _stack_pytree, _unstack_pytree
+    from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree
     x_pytrees = _unstack_pytree(x)
     zs = []
     for xp in x_pytrees:
diff --git a/torch/_export/pass_base.py b/torch/_export/pass_base.py
index 8ea8cca37221e..5336344412ffb 100644
--- a/torch/_export/pass_base.py
+++ b/torch/_export/pass_base.py
@@ -5,8 +5,7 @@
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from functorch.experimental import _map
-from functorch.experimental._map import _unstack_pytree
+from functorch.experimental.control_flow import _unstack_pytree
 from torch import fx
 from torch._dispatch.python import enable_python_dispatcher
 from torch._export.pass_infra.node_metadata import NodeMetadata
@@ -193,7 +192,7 @@ def call_function(
         elif target == torch.ops.higher_order.cond:
             pred, true_fn, false_fn, inputs = args
             return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
-        elif target == _map.map_impl:
+        elif target == torch.ops.higher_order.map_impl:
             f, num_args, *rest = args  # type: ignore[assignment]
             return self.callback.call_map(f, num_args, list(rest), meta)
         # For other unregistered HigherOrderOps, just interpret them blindly
@@ -361,7 +360,7 @@ def call_map(
         assert f_branch is not None
         return self._fx(
             "call_function",
-            _map.map_impl,
+            torch.ops.higher_order.map_impl,
             (f_branch.graph_module, num_args, *args),
             {},
             meta,
diff --git a/functorch/experimental/_map.py b/torch/_higher_order_ops/map.py
similarity index 97%
rename from functorch/experimental/_map.py
rename to torch/_higher_order_ops/map.py
index bc6fbe5ff4a9c..8e578c0e1cd07 100644
--- a/functorch/experimental/_map.py
+++ b/torch/_higher_order_ops/map.py
@@ -38,9 +38,9 @@ def __call__(self, xs, *args):
 map_impl = HigherOrderOperator("map_impl")
 
 dummy_aot_config = AOTConfig(
-    fw_compiler=None,
-    bw_compiler=None,
-    partition_fn=None,
+    fw_compiler=None,  # type: ignore[arg-type]
+    bw_compiler=None,  # type: ignore[arg-type]
+    partition_fn=None,  # type: ignore[arg-type]
     decompositions={},
     num_params_buffers=0,
     aot_id=0,
@@ -185,7 +185,7 @@ def map_wrapper(f, xs, *args):
     out_spec = None
 
     def flat_fn(*flat_args):
-        xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
+        xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
         unflattened_out = f(xs, *flat_args[num_mapped_args:])
         flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)
 
@@ -194,7 +194,7 @@ def flat_fn(*flat_args):
         return flat_out
 
     return pytree.tree_unflatten(
-        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
+        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec  # type: ignore[arg-type]
     )
 
 
@@ -293,6 +293,7 @@ def _stack_pytree(pytrees):
     for pt in pytrees:
         flat_pt, out_spec = pytree.tree_flatten(pt)
         flat_out.append(flat_pt)
+    assert out_spec is not None
     b = zip(*flat_out)
     stacked_out = []
     for leaves in b:
@@ -302,7 +303,7 @@ def _stack_pytree(pytrees):
             # Backward graph can return None output when forward inputs doesn't require grad.
            # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
             # therefore we need to deal with None output.
-            stacked_out.append(None)
+            stacked_out.append(None)  # type: ignore[arg-type]
         else:
             raise RuntimeError(f"Cannot stack {leaves}.")
     return pytree.tree_unflatten(stacked_out, out_spec)
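
A quick sanity check of the resulting layout (a minimal sketch, not part of the diff; it assumes a PyTorch build that includes these changes, and the example shapes and lambda are illustrative):

    import torch
    from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree, map

    xs = torch.randn(3, 4)

    # `map` applies the callable to each slice of `xs` along dim 0 and stacks
    # the per-slice results, so this matches the elementwise sin of `xs`.
    ys = map(lambda x: x.sin(), xs)
    assert torch.allclose(ys, xs.sin())

    # The pytree helpers round-trip: unstack a batched pytree into per-element
    # pytrees, then stack them back into the original batched structure.
    assert torch.equal(_stack_pytree(_unstack_pytree(xs)), xs)

    # The operator now lives in torch/_higher_order_ops and is reachable through
    # the dispatcher namespace, which is what pass_base.py matches on above.
    from torch._higher_order_ops.map import map_impl
    assert map_impl is torch.ops.higher_order.map_impl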