Skip to content

Commit

Permalink
Add support for GET_YIELD_FROM_ITER, YIELD_FROM, SEND (pytorch#106986)
Browse files Browse the repository at this point in the history
Pull Request resolved: pytorch#106986
Approved by: https://github.com/jansel
  • Loading branch information
voznesenskym authored and pytorchmergebot committed Aug 19, 2023
1 parent 4f3284e commit 02c2b75
Show file tree
Hide file tree
Showing 2 changed files with 108 additions and 0 deletions.
61 changes: 61 additions & 0 deletions test/dynamo/test_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6632,6 +6632,67 @@ def fn(q, a, b):
self.assertEqual(counter.frame_count, 1)
self.assertTrue(isinstance(compiled, torch.Tensor))

def test_yield_from(self):
def yield_from_fn(t_list, k):
def yield_from_gen(l):
l2 = [t * k for t in l]
yield from l2

return [t * k for t in yield_from_gen(t_list)]

t_list = [torch.randn([2, 3])] * 3
multiplier = torch.tensor([10])
eager = yield_from_fn(t_list, 2)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(yield_from_fn)(t_list, 2)
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 1)

def test_yield_gen_and_from(self):
def populate_and_multiply_sequence(n, multiplier):
# Inline generator
def tensor_generator():
for i in range(n):
yield torch.tensor([i])

# Use 'yield from' to iterate over tensors and multiply
t_list = [tensor * multiplier for tensor in tensor_generator()]

def yield_from_gen():
yield from t_list

return [t for t in yield_from_gen()]

multiplier = torch.tensor([10])
eager = populate_and_multiply_sequence(5, multiplier)
counter = CompileCounter()
compiled = torch._dynamo.optimize(counter)(populate_and_multiply_sequence)(
5, multiplier
)
self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 1)

def test_yield_send_to_subgenerator_graph_break(self):
def subgenerator(tensor):
multiplier = yield
yield tensor * multiplier

def main_generator(t_list):
for tensor in t_list:
subgen = subgenerator(tensor)
next(subgen)
yield from subgen.send(torch.tensor([10]))

t_list = [torch.tensor([i]) for i in range(5)]
eager = list(main_generator(t_list))

counter = CompileCounter()
compiled_fn = torch._dynamo.optimize(counter)(main_generator)
compiled = list(compiled_fn(t_list))

self.assertEqual(eager, compiled)
self.assertEqual(counter.frame_count, 0)


class TestTracer(JitTestCase):
def test_jit_save(self):
Expand Down
47 changes: 47 additions & 0 deletions torch/_dynamo/symbolic_convert.py
Original file line number Diff line number Diff line change
Expand Up @@ -2444,3 +2444,50 @@ def YIELD_VALUE(self, inst: Instruction):
self.generated_items.append(self.pop())
# TODO(jansel): figure out why this is needed, it isn't in the docs for YIELD_VALUE
self.push(ConstantVariable(None))

def GET_YIELD_FROM_ITER(self, inst):
tos = self.stack[-1]
if not isinstance(tos, ListIteratorVariable):
self.pop()
res = BuiltinVariable(iter).call_function(self, [tos], {})
self.push(res)
return self.YIELD_FROM(inst)

def YIELD_FROM(self, inst):
while True:
tos = self.stack[-1]
if isinstance(tos, ConstantVariable) and tos.value is None:
self.pop()
return
if isinstance(tos, ListIteratorVariable):
self.output.guards.update(tos.guards)
try:
val, next_iter = tos.next_variables()
self.replace_all(tos, next_iter)
self.push(val)
# TODO(voz): Unclear if we need the push None in YIELD_VALUE?
self.YIELD_VALUE(inst)
self.pop()
self.push(next_iter)
except StopIteration:
return
else:
unimplemented(f"YIELD_FROM {typestr(tos)}")

def SEND(self, inst):
assert len(self.stack) >= 2
val = self.pop()
tos = self.stack[-1]
if isinstance(tos, ListIteratorVariable):
if isinstance(val, ConstantVariable) and val.value is None:
self.push(val)
self.instruction_pointer = self.indexof[inst.target]
else:
# invoke send
# Unreachable code - if you hit this, you are implementing generator support and have
# lifted the `unimplemented("generator")` in frame conversion. This codepath handles
# subgenerator and lines up with this line in Python 3.11
# https://github.com/python/cpython/blob/3.11/Python/ceval.c#L2597
unimplemented("Unreachable sub-generator code")
else:
unimplemented(f"SEND {typestr(tos)}")

0 comments on commit 02c2b75

Please sign in to comment.