_unflatten.py 741 B

123456789101112131415161718192021222324252627
  1. # mypy: allow-untyped-defs
  2. # Copyright (c) Meta Platforms, Inc. and affiliates
  3. from typing import Dict
  4. import torch
  5. from torch.export.unflatten import _ModuleFrame
  6. def _outline_submodules(orig_graph: torch.fx.Graph):
  7. # Create an empty GraphModule to hold the outlined modules
  8. new_module = torch.fx.GraphModule(torch.nn.Module(), torch.fx.Graph())
  9. seen_nodes: Dict[str, torch.fx.Node] = {}
  10. seen_modules: Dict[int, torch.nn.Module] = {}
  11. _ModuleFrame(
  12. orig_graph,
  13. tuple(orig_graph.nodes),
  14. seen_nodes,
  15. seen_modules,
  16. None,
  17. [""],
  18. "",
  19. {},
  20. module=new_module,
  21. ).run_outer()
  22. new_module.graph.lint()
  23. new_module.recompile()
  24. return new_module