Revert "[inductor] fix crash issue when input is a view tensor (pytor…
Browse files Browse the repository at this point in the history
…ch#90150)" (pytorch#94329)

Had to resolve merge conflicts with pytorch#94118 while reverting.

This was causing internal test failures that look similar to:
```
in clone_preserve_strides
    x.size(), x.stride(), x.storage_offset()
AttributeError: 'KeyedJaggedTensor' object has no attribute 'size'
```

See https://fburl.com/testinfra/nc0du2sp for more information
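
For context, a minimal sketch of the failure mode, assuming only what the traceback shows: the clone helper reads Tensor metadata up front, so a graph input that is not a plain Tensor breaks it. `clone_like_tensor` and `NotATensor` below are illustrative stand-ins, not the real `torch._prims_common.clone_preserve_strides` or the real `KeyedJaggedTensor`:

```python
import torch


def clone_like_tensor(x):
    # Same attribute access as the failing frame above: assumes x is a Tensor
    # and queries size()/stride()/storage_offset() before copying.
    size, stride, offset = x.size(), x.stride(), x.storage_offset()
    out = torch.empty_strided(size, stride, dtype=x.dtype, device=x.device)
    out.copy_(x)
    return out  # the real helper also re-applies `offset`; elided in this sketch


t = torch.arange(12.0).reshape(3, 4)
print(clone_like_tensor(t).stride())  # (4, 1) -- fine for plain tensors


class NotATensor:  # hypothetical stand-in for torchrec's KeyedJaggedTensor
    def __init__(self, values):
        self.values = values


clone_like_tensor(NotATensor(t))
# AttributeError: 'NotATensor' object has no attribute 'size'
```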

This reverts commit pytorch#90150

@jansel can you help @blzheng with re-landing this as a co-development diff?

Pull Request resolved: pytorch#94329
Approved by: https://github.com/jansel
seemethere authored and pytorchmergebot committed Feb 7, 2023
1 parent 7b3217e commit 567e615
Showing 9 changed files with 4 additions and 137 deletions.
69 changes: 0 additions & 69 deletions test/inductor/test_torchinductor.py
@@ -6337,75 +6337,6 @@ def fn(a):
        if simdlen != 1:
            assert metrics.generated_cpp_vec_kernel_count == 1

    def test_inplace_unsqueeze(self):
        @torch._dynamo.optimize("inductor")
        def fn(a):
            unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
            return unsqueeze_

        for dynamic_shapes in [True, False]:
            args = [
                (
                    (1, 1, 1, 12, 11, 3),
                    (396, 396, 396, 33, 3, 1),
                    torch.int64,
                    "cpu",
                )
            ]
            args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
            torch._dynamo.config.dynamic_shapes = dynamic_shapes
            with torch.no_grad():
                out = fn(*args)
            assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
            assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
            assert out.equal(args[0])

    def test_inplace_unsqueeze2(self):
        @torch._dynamo.optimize("inductor")
        def fn(a):
            unsqueeze_ = torch.ops.aten.unsqueeze_.default(a, 0)
            res = unsqueeze_ + 1
            return res

        for dynamic_shapes in [True, False]:
            args = [
                (
                    (1, 1, 1, 12, 11, 3),
                    (396, 396, 396, 33, 3, 1),
                    torch.int64,
                    "cpu",
                )
            ]
            args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
            torch._dynamo.config.dynamic_shapes = dynamic_shapes
            with torch.no_grad():
                out = fn(*args)
            assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
            assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)
            assert out.equal(args[0] + 1)

    def test_inplace_unsqueeze3(self):
        @torch._dynamo.optimize("inductor")
        def fn(a):
            torch.ops.aten.unsqueeze_.default(a, 0)
            return 0

        for dynamic_shapes in [True, False]:
            args = [
                (
                    (1, 1, 1, 12, 11, 3),
                    (396, 396, 396, 33, 3, 1),
                    torch.int64,
                    "cpu",
                )
            ]
            args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
            torch._dynamo.config.dynamic_shapes = dynamic_shapes
            with torch.no_grad():
                fn(*args)
            assert args[0].shape == (1, 1, 1, 1, 12, 11, 3)
            assert args[0].stride() == (396, 396, 396, 396, 33, 3, 1)


if HAS_CUDA and not TEST_WITH_ASAN:
    import triton
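
The deleted tests above assert the in-place metadata mutation that motivated pytorch#90150: `unsqueeze_` grows a graph input's shape and strides in place. A standalone sketch of that behavior (illustrative only, not part of the test suite):

```python
import torch

# Contiguous strides for this shape are (396, 396, 396, 33, 3, 1), matching
# the rand_strided arguments used in the deleted tests.
a = torch.zeros((1, 1, 1, 12, 11, 3), dtype=torch.int64)
torch.ops.aten.unsqueeze_.default(a, 0)  # mutates a's metadata in place
print(a.shape)     # torch.Size([1, 1, 1, 1, 12, 11, 3])
print(a.stride())  # (396, 396, 396, 396, 33, 3, 1)
```
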
38 changes: 0 additions & 38 deletions torch/_dynamo/variables/builder.py
@@ -142,44 +142,6 @@ def get_fake_examples(self):
            assert isinstance(
                self.fake_tensor, torch._subclasses.fake_tensor.FakeTensor
            )
            # For inplace ops changing the input's shape (unsqueeze_)
            if not config.dynamic_shapes and (
                self.fake_tensor.shape != self.example.shape
                or self.fake_tensor.stride() != self.example.stride()
            ):
                converter = torch._subclasses.fake_tensor.FakeTensorConverter()
                self.fake_tensor = converter.from_real_tensor(
                    self.fake_tensor.fake_mode, self.example
                )
            elif config.dynamic_shapes:
                (
                    size,
                    stride,
                    _,
                ) = self.fake_tensor.fake_mode.shape_env.create_symbolic_sizes_strides_storage_offset(
                    self.example, self.source
                )
                if (
                    torch.Size(size) != self.fake_tensor.shape
                    or tuple(stride) != self.fake_tensor.stride()
                ):
                    self.fake_tensor.fake_mode.converter = (
                        torch._subclasses.fake_tensor.FakeTensorConverter()
                    )
                    self.fake_tensor.fake_mode.shape_env = (
                        torch.fx.experimental.symbolic_shapes.ShapeEnv()
                    )
                    ignore_subclass = (
                        True
                        if type(self.example) in config.traceable_tensor_subclasses
                        else False
                    )
                    self.fake_tensor = self.fake_tensor.fake_mode.from_tensor(
                        self.example.clone(),
                        static_shapes=False,
                        ignore_subclass=ignore_subclass,
                        source=self.source,
                    )
            return [self.fake_tensor]

    def __len__(self):
5 changes: 1 addition & 4 deletions torch/_functorch/aot_autograd.py
@@ -1049,11 +1049,8 @@ class AOTConfig:


def aot_dispatch_base(flat_fn, flat_args: List[Tensor], aot_config: AOTConfig):
    # flat_args is used by make_fx and aot_config.fw_compiler
    # clone flat_args to avoid flat_args shape changed by inplace ops (unsqueeze_)
    tmp_flat_args = [torch._prims_common.clone_preserve_strides(x) for x in flat_args]
    with enable_python_dispatcher():
        fw_module = make_fx(flat_fn, aot_config.decompositions)(*tmp_flat_args)
        fw_module = make_fx(flat_fn, aot_config.decompositions)(*flat_args)
    if config.debug_graphs:
        log.debug(f"====== Forward (only) graph {aot_config.aot_id} ======")
        log.debug(fw_module.print_readable(print_output=False))
6 changes: 0 additions & 6 deletions torch/_inductor/codegen/wrapper.py
@@ -512,10 +512,6 @@ def generate(self):
            # these lines will be pointless
            self.lines.pop()

        for name, value in V.graph.graph_inputs.items():
            if isinstance(value.data, ir.ReinterpretView):
                self.wrapper_call.writeline(value.data.codegen_reference_mutation())

        # codegen allocations in two passes
        planning_state = MemoryPlanningState()
        for i in range(len(self.lines)):
@@ -585,8 +581,6 @@ def add_fake_input(name, shape, stride, device, dtype):
            )

        for name, value in V.graph.graph_inputs.items():
            if isinstance(value.data, ir.ReinterpretView):
                value = value.data.data
            shape = [V.graph.sizevars.size_hint(x) for x in value.get_size()]
            stride = [V.graph.sizevars.size_hint(x) for x in value.get_stride()]
            add_fake_input(
2 changes: 0 additions & 2 deletions torch/_inductor/graph.py
@@ -378,8 +378,6 @@ def output(self, target, args, kwargs):
            value.realize()
            assert isinstance(value, TensorBox)
            value = value.data
            if isinstance(value, ir.ReinterpretView):
                continue
            assert isinstance(value, ir.StorageBox)
            value_storage_box = value
            value = value.data
8 changes: 0 additions & 8 deletions torch/_inductor/ir.py
@@ -1473,14 +1473,6 @@ def codegen_reference(self):
return f"{as_strided}({self.get_name()}, {size}, {stride}, {offset})"
return f"{as_strided}({self.get_name()}, {size}, {stride})"

def codegen_reference_mutation(self):
size = V.graph.sizevars.codegen_shape_tuple(self.layout.size)
stride = V.graph.sizevars.codegen_shape_tuple(self.layout.stride)
offset = V.graph.sizevars.codegen_sizevar(self.layout.offset)
if offset != "0":
return f"{self.get_name()}.as_strided_({size}, {stride}, {offset})"
return f"{self.get_name()}.as_strided_({size}, {stride})"


class SliceView(View):
@classmethod
5 changes: 2 additions & 3 deletions torch/_inductor/scheduler.py
@@ -1016,9 +1016,8 @@ def free_buffers(self):
                    V.graph.wrapper_code.codegen_free(node.node)
            elif name in V.graph.graph_inputs:
                storage = V.graph.graph_inputs[name].data
                if not isinstance(storage, ir.ReinterpretView):
                    assert storage.is_input_buffer()
                    V.graph.wrapper_code.codegen_free(storage.data)
                assert storage.is_input_buffer()
                V.graph.wrapper_code.codegen_free(storage.data)

        self.buffer_names_to_free.clear()

4 changes: 0 additions & 4 deletions torch/_inductor/sizevars.py
@@ -458,8 +458,6 @@ def strideof(name):
        needed = set(self.var_to_val.keys()) - set(self.replacements.keys())

        for name, value in graph_inputs.items():
            if isinstance(value.data, ir.ReinterpretView):
                value = value.data.data
            shapes = value.get_size()
            for dim, shape in enumerate(shapes):
                shape = self.simplify(shape)
@@ -470,8 +468,6 @@ def strideof(name):
                )

        for name, value in graph_inputs.items():
            if isinstance(value.data, ir.ReinterpretView):
                value = value.data.data
            shapes = value.get_stride()
            for dim, shape in enumerate(shapes):
                shape = self.simplify(shape)
4 changes: 1 addition & 3 deletions torch/fx/passes/shape_prop.py
@@ -182,6 +182,4 @@ def propagate(self, *args):
        Returns:
            Any: The value returned from executing the Module
        """
        # clone inputs to avoid side effects caused by inplace ops during run_node
        new_args = [torch._prims_common.clone_preserve_strides(x) for x in args]
        return super().run(*new_args)
        return super().run(*args)

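With the clone removed, `ShapeProp.propagate` once again runs in-place ops directly on the caller's tensors; a minimal sketch of that side effect (the module below is illustrative, not from this PR):

```python
import torch
from torch.fx import symbolic_trace
from torch.fx.passes.shape_prop import ShapeProp


class M(torch.nn.Module):
    def forward(self, x):
        x.unsqueeze_(0)  # in-place metadata mutation recorded in the traced graph
        return x * 2


t = torch.zeros(2, 3)
ShapeProp(symbolic_trace(M())).propagate(t)
print(t.shape)  # torch.Size([1, 2, 3]) -- the caller's tensor was mutated during propagation
```

Re-landing pytorch#90150 will need to restore this isolation (or an equivalent) without assuming every graph input exposes Tensor metadata, so tensor-like containers such as `KeyedJaggedTensor` keep working.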