_utils.py 1.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960
  1. # mypy: allow-untyped-defs
  2. from typing import Dict, Optional
  3. import torch
  4. from torch._logging import LazyString
  5. def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
  6. """
  7. Returns a LazyString that formats the graph code.
  8. """
  9. def format_name():
  10. if maybe_id is not None:
  11. return f"{name} {maybe_id}"
  12. else:
  13. return name
  14. if "print_output" not in kwargs:
  15. kwargs["print_output"] = False
  16. return LazyString(
  17. lambda: _format_graph_code(
  18. f"===== {format_name()} =====\n",
  19. gm.forward.__code__.co_filename,
  20. gm.print_readable(**kwargs),
  21. )
  22. )
  23. def _format_graph_code(name, filename, graph_str):
  24. """
  25. Returns a string that formats the graph code.
  26. """
  27. return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
  28. def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
  29. """
  30. Returns the nn_module_stack of the first call_function node.
  31. """
  32. for node in graph.nodes:
  33. if node.op == "call_function" and "nn_module_stack" in node.meta:
  34. return node.meta["nn_module_stack"]
  35. return None
  36. def get_node_context(node, num_nodes=2) -> str:
  37. """
  38. Returns a string of the last num_nodes nodes in the graph.
  39. """
  40. node_contexts = []
  41. cur = node
  42. for i in range(num_nodes):
  43. node_contexts.append(cur.format_node())
  44. if cur.op == "root":
  45. break
  46. cur = cur.prev
  47. return "\n".join(node_contexts[::-1])