Skip to content

Commit

Permalink
Support graphs which return get_attr nodes directly as output (pytorc…
Browse files Browse the repository at this point in the history
…h#107610)

Summary: Currently serializing graphs which return get_attr's directly as output fails. This diff adds support for that only in EXIR serializer while we still support unlifted params.

Test Plan: Added test case.

Differential Revision: D48258552

Pull Request resolved: pytorch#107610
Approved by: https://github.com/angelayi
  • Loading branch information
tarun292 authored and pytorchmergebot committed Aug 22, 2023
1 parent 979e706 commit e8278d6
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions torch/_export/serde/serialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -854,6 +854,14 @@ def deserialize_tensor_meta(
),
)

def deserialize_graph_output(self, output) -> torch.fx.Node:
if isinstance(output.value, TensorArgument):
return self.serialized_name_to_node[output.value.name]
elif isinstance(output.value, (SymIntArgument, SymBoolArgument)):
return self.serialized_name_to_node[output.value.as_name]
else:
raise SerializeError(f"Unable to deserialize output node {output}")

def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
# Handle the tensor metas.
for name, tensor_value in serialized_graph.tensor_values.items():
Expand Down Expand Up @@ -883,13 +891,7 @@ def deserialize_graph(self, serialized_graph: Graph) -> torch.fx.Graph:
# Outputs: convert to a single `output` node.
outputs = []
for output in serialized_graph.outputs:
if isinstance(output.value, TensorArgument):
outputs.append(self.serialized_name_to_node[output.value.name])
elif isinstance(output.value, (SymIntArgument, SymBoolArgument)):
outputs.append(self.serialized_name_to_node[output.value.as_name])
else:
raise SerializeError(f"Unable to deserialize output node {output}")

outputs.append(self.deserialize_graph_output(output))

output_node = self.graph.output(tuple(outputs))
output_node.meta["val"] = tuple(
Expand Down

0 comments on commit e8278d6

Please sign in to comment.