# mypy: allow-untyped-defs
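"""Measure TorchDynamo graph coverage with torch.profiler.

`fx_insert_profiling` is a Dynamo backend that wraps each compiled graph in a
"TORCHDYNAMO" record_function region; `Profiler` then reports how many
operators and how much CPU time landed inside those regions versus the run as
a whole.
"""
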
import dataclasses
import os
from typing import Any, List

import torch

from .utils import print_once


@dataclasses.dataclass
class ProfileMetrics:
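    """Counters for a profiled run: CPU microseconds, operator, fusion, and graph counts.

    Dividing one ProfileMetrics by another (or by an int) gives element-wise
    ratios; __str__ formats those ratios as op/time percentages.
    """
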
    microseconds: float = 0.0
    operators: int = 0
    fusions: int = 0
    graphs: int = 0

    def __iadd__(self, other: "ProfileMetrics"):
        self.microseconds += other.microseconds
        self.operators += other.operators
        self.fusions += other.fusions
        return self

    def __add__(self, other: "ProfileMetrics"):
        assert isinstance(other, ProfileMetrics)
        return ProfileMetrics(
            self.microseconds + other.microseconds,
            self.operators + other.operators,
            self.fusions + other.fusions,
        )

    def __truediv__(self, other):
        if isinstance(other, int):
            other = ProfileMetrics(other, other, other)
        return ProfileMetrics(
            self.microseconds / max(1, other.microseconds),
            self.operators / max(1, other.operators),
            self.fusions / max(1, other.fusions),
        )

    def __str__(self):
        return f"{self.operators:4.0%} ops {self.microseconds:4.0%} time"

    def tocsv(self):
        return [self.operators, self.microseconds]


class ProfileResult:
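    """Captured-vs-total ProfileMetrics for a run, plus the number of unique graphs compiled."""
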
    def __init__(self, captured, total, unique_graphs):
        self.captured: ProfileMetrics = captured or ProfileMetrics()
        self.total: ProfileMetrics = total or ProfileMetrics()
        self.unique_graphs: int = unique_graphs

    def __iadd__(self, other: "ProfileResult"):
        self.captured += other.captured
        self.total += other.total
        self.unique_graphs += other.unique_graphs
        return self

    def percent(self):
        return self.captured / self.total

    def __str__(self):
        return (
            f"{self.unique_graphs:2} graphs {self.captured.graphs:2} graph calls "
            f"{self.captured.operators:4}/{self.total.operators:4} = "
            + str(self.percent())
        )

    def tocsv(self):
        return [
            self.unique_graphs,
            self.captured.graphs,
            self.captured.operators,
            self.total.operators,
        ] + self.percent().tocsv()
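

# Setting TORCHDYNAMO_PRINT_MISSING=1 makes Profiler collect stack traces and
# report (via print_missing) the user frames of ops that ran outside any
# captured graph.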
def should_print_missing():
    return os.environ.get("TORCHDYNAMO_PRINT_MISSING") == "1"


def print_missing(stack):
    if any("/torch/autograd/profiler.py" in x for x in stack):
        return
    stack = [
        x for x in stack if ("<built-in" not in x and "site-packages/torch/" not in x)
    ]
    print_once("MISSING", " >> ".join(stack[-3:]))


class Profiler:
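    """Measure how much of a run's operators and CPU time fall inside
    TORCHDYNAMO-captured regions, using torch.profiler (CPU activity only).
    """
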
    unique_graphs = 0

    def __init__(self):
        self.prof = torch.profiler.profile(
            activities=[torch.profiler.ProfilerActivity.CPU],
            with_stack=should_print_missing(),
        )

    def results(self):
        captured_regions = 0
        captured_ops = 0
        captured_microseconds = 0
        total_ops = 0
        total_microseconds = 0

        last_op_end_time = -1
        captured_region_end_time = -1
        events = sorted(self.prof.events(), key=lambda x: x.time_range.start)
        for e in events:
            if e.name == "TORCHDYNAMO":
                captured_region_end_time = e.time_range.end
                captured_regions += 1
                # ignore `handle = torch.zeros(1)` in record_function.__init__()
                total_ops -= 1
            elif e.time_range.start >= last_op_end_time:
                # top-level op: it starts after the previous top-level op ended
                last_op_end_time = e.time_range.end
                if e.time_range.end <= captured_region_end_time:
                    # op finished inside the most recent TORCHDYNAMO region
                    captured_ops += 1
                    captured_microseconds += e.time_range.elapsed_us()
                elif should_print_missing():
                    print_missing(e.stack)
                total_ops += 1
                total_microseconds += e.time_range.elapsed_us()
            else:
                pass  # ops recursively called from other ops (ignored)

        unique_graphs = Profiler.unique_graphs
        Profiler.unique_graphs = 0
        # we counted one extra op that is part of the profiler setup code
        total_ops -= 1

        return ProfileResult(
            captured=ProfileMetrics(
                microseconds=captured_microseconds,
                operators=captured_ops,
                fusions=captured_ops - captured_regions,
                graphs=captured_regions,
            ),
            total=ProfileMetrics(
                microseconds=total_microseconds,
                operators=total_ops,
                fusions=total_ops - 1,
            ),
            unique_graphs=unique_graphs,
        )


def fx_insert_profiling(gm: torch.fx.GraphModule, example_inputs: List[Any]):
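    """A TorchDynamo backend: wraps the compiled GraphModule's forward in a
    "TORCHDYNAMO" record_function region so Profiler.results() can attribute
    its ops to captured graphs, and bumps Profiler.unique_graphs."""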
    def _wrapped(*args):
        with torch.profiler.record_function("TORCHDYNAMO"):
            return gm.forward(*args)

    Profiler.unique_graphs += 1
    return _wrapped
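

if __name__ == "__main__":
    # Usage sketch (not part of the original module): pass `fx_insert_profiling`
    # to Dynamo as the backend, run the model under the profiler, and print the
    # captured-vs-total coverage. The toy model and inputs below are
    # illustrative assumptions, not anything this module requires.
    import torch._dynamo
    import torch.nn as nn

    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 1))
    inputs = (torch.randn(4, 8),)

    prof = Profiler()
    opt_model = torch._dynamo.optimize(fx_insert_profiling)(model)
    with prof.prof:
        opt_model(*inputs)
    print(prof.results())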