| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960 |
- # mypy: allow-untyped-defs
- from typing import Dict, Optional
- import torch
- from torch._logging import LazyString
def lazy_format_graph_code(name, gm, maybe_id=None, **kwargs):
    """
    Wrap the formatting of a graph module's code in a LazyString.

    The banner title is ``name`` alone, or ``"{name} {maybe_id}"`` when
    ``maybe_id`` is not None. Any extra keyword arguments are forwarded to
    ``gm.print_readable``; ``print_output`` is forced to ``False`` unless the
    caller supplied it explicitly.
    """
    # Only fill in print_output when the caller did not pass it.
    kwargs.setdefault("print_output", False)

    def render():
        # Everything below runs only when the LazyString invokes this
        # callable, not when lazy_format_graph_code itself is called.
        title = name if maybe_id is None else f"{name} {maybe_id}"
        return _format_graph_code(
            f"===== {title} =====\n",
            gm.forward.__code__.co_filename,
            gm.print_readable(**kwargs),
        )

    return LazyString(render)
- def _format_graph_code(name, filename, graph_str):
- """
- Returns a string that formats the graph code.
- """
- return f"TRACED GRAPH\n {name} {filename} {graph_str}\n"
def first_call_function_nn_module_stack(graph: torch.fx.Graph) -> Optional[Dict]:
    """
    Find the ``nn_module_stack`` metadata of the earliest qualifying node.

    Nodes are scanned in ``graph.nodes`` order. A node qualifies when its
    ``op`` is ``"call_function"`` and its ``meta`` contains the key
    ``"nn_module_stack"``. Returns that entry for the first qualifying node,
    or ``None`` when no node qualifies.
    """
    matching_stacks = (
        candidate.meta["nn_module_stack"]
        for candidate in graph.nodes
        if candidate.op == "call_function" and "nn_module_stack" in candidate.meta
    )
    # The generator is lazy, so iteration stops at the first match.
    return next(matching_stacks, None)
def get_node_context(node, num_nodes=2) -> str:
    """
    Format ``node`` together with the nodes that precede it.

    Starting at ``node`` and following ``.prev`` links, collects
    ``format_node()`` output for up to ``num_nodes`` nodes, stopping early
    once a node whose ``op`` is ``"root"`` has been included. The collected
    lines are joined with newlines in forward order, so ``node`` itself is
    the last line.
    """
    collected = []
    current = node
    for _ in range(num_nodes):
        collected.append(current.format_node())
        # The root node is included in the output but never walked past.
        if current.op == "root":
            break
        current = current.prev
    # Lines were gathered walking backwards; emit them in forward order.
    return "\n".join(reversed(collected))
|