[HigherOrderOp] Move _map.py to _higher_order_ops (pytorch#111152)
Differential Revision: [D50332159](https://our.internmc.facebook.com/intern/diff/D50332159)
Pull Request resolved: pytorch#111152
Approved by: https://github.com/zou3519
ydwu4 authored and pytorchmergebot committed Nov 16, 2023
1 parent 1364f84 commit 6703111
Showing 4 changed files with 16 additions and 12 deletions.
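The practical effect of this move for downstream code is an import-path change for the private helpers, while the public functorch.experimental.control_flow entry point keeps re-exporting the same symbols. A minimal before/after sketch (illustrative only; the control_flow path remains the supported entry point):

# Old internal location, removed by this commit:
#   from functorch.experimental._map import map, _stack_pytree, _unstack_pytree

# New internal location:
from torch._higher_order_ops.map import map, _stack_pytree, _unstack_pytree

# Public entry point, unchanged:
from functorch.experimental.control_flow import map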
6 changes: 5 additions & 1 deletion functorch/experimental/control_flow.py
@@ -1,4 +1,8 @@
from torch import cond # noqa: F401
from torch._higher_order_ops.cond import UnsupportedAliasMutationException # noqa: F401

-from ._map import map  # noqa: F401
+from torch._higher_order_ops.map import (  # noqa: F401
+    _stack_pytree,
+    _unstack_pytree,
+    map,
+)
2 changes: 1 addition & 1 deletion test/functorch/test_control_flow.py
@@ -44,7 +44,7 @@ def from_fun_old(t):
    return t

def _fake_map(f, x, *args):
-    from functorch.experimental._map import _stack_pytree, _unstack_pytree
+    from functorch.experimental.control_flow import _stack_pytree, _unstack_pytree
    x_pytrees = _unstack_pytree(x)
    zs = []
    for xp in x_pytrees:
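_fake_map is the test suite's eager reference for the map op: it unstacks the mapped input along dim 0, runs the body on each slice, and stacks the results back together. A self-contained sketch of the same pattern using plain tensor ops (names here are illustrative; the real helper works on arbitrary pytrees via _unstack_pytree/_stack_pytree):

import torch

def eager_map_reference(f, xs, *args):
    # Split xs along dim 0, apply f to each slice, then re-stack the outputs.
    slices = torch.unbind(xs, dim=0)
    outs = [f(x, *args) for x in slices]
    return torch.stack(outs, dim=0)

xs = torch.randn(4, 3)
bias = torch.randn(3)
out = eager_map_reference(lambda x, b: x + b, xs, bias)
assert out.shape == (4, 3)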
7 changes: 3 additions & 4 deletions torch/_export/pass_base.py
@@ -5,8 +5,7 @@
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

import torch
-from functorch.experimental import _map
-from functorch.experimental._map import _unstack_pytree
+from functorch.experimental.control_flow import _unstack_pytree
from torch import fx
from torch._dispatch.python import enable_python_dispatcher
from torch._export.pass_infra.node_metadata import NodeMetadata
@@ -193,7 +192,7 @@ def call_function(
            elif target == torch.ops.higher_order.cond:
                pred, true_fn, false_fn, inputs = args
                return self.callback.call_cond(pred, true_fn, false_fn, inputs, meta)
-            elif target == _map.map_impl:
+            elif target == torch.ops.higher_order.map_impl:
                f, num_args, *rest = args  # type: ignore[assignment]
                return self.callback.call_map(f, num_args, list(rest), meta)
            # For other unregistered HigherOrderOps, just interpret them blindly
@@ -361,7 +360,7 @@ def call_map(
        assert f_branch is not None
        return self._fx(
            "call_function",
-            _map.map_impl,
+            torch.ops.higher_order.map_impl,
            (f_branch.graph_module, num_args, *args),
            {},
            meta,
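The pass no longer imports the private module to recognize the op: a HigherOrderOperator registered under a given name is reachable as an attribute of torch.ops.higher_order, and since comparison on these operators is object identity, matching against the public handle is equivalent to matching against the module-level map_impl. A small check illustrating that assumption (for the PyTorch version of this commit; importing the module first ensures the operator is registered):

import torch
from torch._higher_order_ops.map import map_impl

# The registered operator object and its public handle are the same instance.
assert torch.ops.higher_order.map_impl is map_impl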
13 changes: 7 additions & 6 deletions functorch/experimental/_map.py → torch/_higher_order_ops/map.py
@@ -38,9 +38,9 @@ def __call__(self, xs, *args):
map_impl = HigherOrderOperator("map_impl")

dummy_aot_config = AOTConfig(
-    fw_compiler=None,
-    bw_compiler=None,
-    partition_fn=None,
+    fw_compiler=None,  # type: ignore[arg-type]
+    bw_compiler=None,  # type: ignore[arg-type]
+    partition_fn=None,  # type: ignore[arg-type]
    decompositions={},
    num_params_buffers=0,
    aot_id=0,
@@ -185,7 +185,7 @@ def map_wrapper(f, xs, *args):
    out_spec = None

    def flat_fn(*flat_args):
-        xs = pytree.tree_unflatten(flat_args[:num_mapped_args], xs_spec)
+        xs = pytree.tree_unflatten(list(flat_args[:num_mapped_args]), xs_spec)
        unflattened_out = f(xs, *flat_args[num_mapped_args:])
        flat_out, tmp_out_spec = pytree.tree_flatten(unflattened_out)

@@ -194,7 +194,7 @@ def flat_fn(*flat_args):
        return flat_out

    return pytree.tree_unflatten(
-        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec
+        map_impl(flat_fn, num_mapped_args, *flat_xs, *args), out_spec  # type: ignore[arg-type]
    )


@@ -293,6 +293,7 @@ def _stack_pytree(pytrees):
    for pt in pytrees:
        flat_pt, out_spec = pytree.tree_flatten(pt)
        flat_out.append(flat_pt)
+    assert out_spec is not None
    b = zip(*flat_out)
    stacked_out = []
    for leaves in b:
@@ -302,7 +303,7 @@ def _stack_pytree(pytrees):
            # Backward graph can return None output when forward inputs doesn't require grad.
            # When we eagerly execute backward graph, we need to call _stack_pytree on its output,
            # therefore we need to deal with None output.
-            stacked_out.append(None)
+            stacked_out.append(None)  # type: ignore[arg-type]
        else:
            raise RuntimeError(f"Cannot stack {leaves}.")
    return pytree.tree_unflatten(stacked_out, out_spec)
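For readers unfamiliar with the op being moved: map_wrapper flattens the mapped pytree, map_impl runs the body over each slice along dim 0, and the stacked result is unflattened back into the original structure. A minimal eager usage sketch via the public entry point (assuming eager execution of the map op, which the tests above exercise against _fake_map):

import torch
from functorch.experimental import control_flow

def body(x, y):
    # x is one slice of xs (shape (3,)); y is passed unchanged to every iteration.
    return x.sin() + y

xs = torch.randn(4, 3)  # mapped over dim 0 -> 4 slices of shape (3,)
y = torch.randn(3)
out = control_flow.map(body, xs, y)  # results stacked back to shape (4, 3)
assert out.shape == (4, 3)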
