Skip to content

Commit

Permalink
[Inductor] Optimize finding users of buffers for mutation (pytorch#10…
Browse files Browse the repository at this point in the history
…5882)

Rather than visiting all nodes in the current environment to determine the users of a buffer, register the users of a buffer after node execution.

Pull Request resolved: pytorch#105882
Approved by: https://github.com/jansel
  • Loading branch information
mlazos authored and pytorchmergebot committed Jul 29, 2023
1 parent 9b94dcf commit 7b14a14
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
40 changes: 27 additions & 13 deletions torch/_inductor/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,9 @@
import re
import sys
import time
from collections import defaultdict
from contextlib import contextmanager
from typing import Dict, List, Optional, Set, Tuple
from typing import DefaultDict, Dict, List, Optional, Set, Tuple

import sympy

Expand Down Expand Up @@ -192,6 +193,7 @@ def __init__(
self.mutated_input_idxs: List[int] = []
self.unaligned_buffers: Set[str] = set()
self.name_to_buffer: Dict[str, ir.ComputedBuffer] = {}
self.name_to_users: DefaultDict[str, List[ir.IRNode]] = defaultdict(list)
self.creation_time = time.time()
self.name = "GraphLowering"
self.cpp_wrapper = cpp_wrapper
Expand Down Expand Up @@ -449,6 +451,24 @@ def register_list(self, buffer_names: List[str]):
self.lists[name] = buffer_names
return name

def register_users_of(self, node_output):
    """Record ``node_output`` as a user of every buffer name it reads.

    ``node_output`` may be a single value or an arbitrarily nested
    list/tuple of values.  Each ``ir.IRNode`` found whose ``.data`` and
    ``.data.data`` are both ``ir.IRNode`` instances is appended to
    ``self.name_to_users[read_name]`` for every name returned by its
    ``get_read_names()``.  All other values are ignored.

    Returns ``None``; the only effect is the update of
    ``self.name_to_users``.
    """

    def register(value):
        if isinstance(value, (list, tuple)):
            for item in value:
                register(item)
            return

        if not isinstance(value, ir.IRNode):
            return

        # Only nodes wrapping two further levels of IRNode are tracked.
        # Both levels are probed with getattr so that a wrapper whose
        # inner node has no ``data`` attribute is skipped instead of
        # raising AttributeError.
        inner = getattr(value, "data", None)
        if not isinstance(inner, ir.IRNode):
            return
        if not isinstance(getattr(inner, "data", None), ir.IRNode):
            return

        for read_name in value.get_read_names():
            self.name_to_users[read_name].append(value)

    register(node_output)

def mark_buffer_mutated(self, name: str):
"""
When a buffer is mutated we need to make sure all the reads to
Expand All @@ -457,19 +477,11 @@ def mark_buffer_mutated(self, name: str):
assert isinstance(name, str)
self.mutated_buffers.add(name)

def visit(value):
if isinstance(value, (list, tuple)):
return [visit(x) for x in value]
if isinstance(value, ir.IRNode):
if value.is_user_of(name):
value.realize()
return value
if name not in self.name_to_users:
return

for value in self.env.values():
try:
visit(value)
except Exception:
log.warning("error in mark_buffer_mutated", exc_info=True)
for user in self.name_to_users[name]:
user.realize()

def add_tensor_constant(self, data):
def allocate():
Expand Down Expand Up @@ -802,6 +814,8 @@ def run_node(self, n: torch.fx.Node):
if isinstance(result.data.data.inputs[0], ir.Buffer):
result.data.data.inputs[0].origin_node = n

self.register_users_of(result)

return result

def check_cpp_codegen_disabled(self):
Expand Down
6 changes: 5 additions & 1 deletion torch/_inductor/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,7 +293,11 @@ def str_helper(self, lines):
return f"{type(self).__name__}(\n{lines}\n)"

def is_user_of(self, name):
return any(name == dep.name for dep in self.get_reads())
return name in self.get_read_names()

@cache_on_self
def get_read_names(self):
    """Return the set of ``name`` attributes of the deps from ``get_reads()``.

    NOTE(review): decorated with ``cache_on_self``, which presumably
    memoizes the result per instance -- confirm against its definition.
    """
    read_names = set()
    for dependency in self.get_reads():
        read_names.add(dependency.name)
    return read_names

def get_numel(self):
return sympy_product(self.get_size())
Expand Down

0 comments on commit 7b14a14

Please sign in to comment.