log_extract.py 3.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. # mypy: allow-untyped-defs
  2. from contextlib import contextmanager
  3. from typing import Any, List, Tuple, cast
  4. import random
  5. import torch
  6. import time
  7. from torch.utils.benchmark import Timer
  8. def extract_ir(filename: str) -> List[str]:
  9. BEGIN = "<GRAPH_EXPORT>"
  10. END = "</GRAPH_EXPORT>"
  11. pfx = None
  12. current = ""
  13. graphs = []
  14. with open(filename) as f:
  15. split_strs = f.read().split(BEGIN)
  16. for i, split_str in enumerate(split_strs):
  17. if i == 0:
  18. continue
  19. end_loc = split_str.find(END)
  20. if end_loc == -1:
  21. continue
  22. s = split_str[:end_loc]
  23. pfx = split_strs[i - 1].splitlines()[-1]
  24. lines = [x[len(pfx):] for x in s.splitlines(keepends=True)]
  25. graphs.append(''.join(lines))
  26. return graphs
  27. def make_tensor_from_type(inp_type: torch._C.TensorType):
  28. size = inp_type.sizes()
  29. stride = inp_type.strides()
  30. device = inp_type.device()
  31. dtype = inp_type.dtype()
  32. assert size is not None
  33. assert stride is not None
  34. assert device is not None
  35. assert dtype is not None
  36. return torch.empty_strided(size=size, stride=stride, device=device, dtype=dtype)
  37. def load_graph_and_inputs(ir: str) -> Tuple[Any, List[Any]]:
  38. graph = torch._C.parse_ir(ir, parse_tensor_constants=True)
  39. graph.makeMultiOutputIntoTuple()
  40. inputs = []
  41. for inp in graph.inputs():
  42. if isinstance(inp.type(), torch._C.FloatType):
  43. inputs.append(random.uniform(.1, 100))
  44. elif isinstance(inp.type(), torch._C.IntType):
  45. inputs.append(random.randint(1, 100))
  46. elif isinstance(inp.type(), torch._C.TensorType):
  47. tensorType = cast(torch._C.TensorType, inp.type())
  48. inputs.append(make_tensor_from_type(tensorType))
  49. elif isinstance(inp.type(), torch._C.BoolType):
  50. inputs.append(random.randint(0, 1) == 1)
  51. else:
  52. raise NotImplementedError(f"A default value is not implemented for type {inp.type()}")
  53. func = torch._C._create_function_from_graph("forward", graph)
  54. torch._C._jit_pass_erase_shape_information(func.graph)
  55. return (func, inputs)
  56. def time_cuda(fn, inputs, test_runs):
  57. t = Timer(stmt="fn(*inputs)", globals={"fn": fn, "inputs" : inputs})
  58. times = t.blocked_autorange()
  59. return times.median * 1000 # time in ms
  60. def time_cpu(fn, inputs, test_runs):
  61. s = time.perf_counter()
  62. for _ in range(test_runs):
  63. fn(*inputs)
  64. e = time.perf_counter()
  65. return (e - s) / test_runs * 1000 # time in ms
  66. def run_test(ir, inputs, *, warmup_runs=10, test_runs=20) -> float:
  67. graph, _ = load_graph_and_inputs(ir)
  68. for _ in range(warmup_runs):
  69. graph(*inputs)
  70. is_cpu = None
  71. for input in inputs:
  72. if isinstance(input, torch.Tensor):
  73. is_cpu = input.device.type == "cpu"
  74. break
  75. assert is_cpu is not None
  76. out = time_cpu(graph, inputs, test_runs) if is_cpu else time_cuda(graph, inputs, test_runs)
  77. return out
  78. @contextmanager
  79. def no_fuser(*args, **kwargs):
  80. old_optimize = torch._C._get_graph_executor_optimize(False)
  81. try:
  82. yield
  83. finally:
  84. torch._C._get_graph_executor_optimize(old_optimize)
  85. def run_baseline_no_fusion(ir, inputs) -> float:
  86. with no_fuser():
  87. return run_test(ir, inputs)
  88. def run_nnc(ir, inputs, dynamic) -> float:
  89. try:
  90. strat = [("DYNAMIC", 10)] if dynamic else [("STATIC", 10)]
  91. old_strat = torch.jit.set_fusion_strategy(strat)
  92. with torch.jit.fuser("fuser1"):
  93. return run_test(ir, inputs)
  94. finally:
  95. torch.jit.set_fusion_strategy(old_strat)
  96. def run_nvfuser(ir, inputs) -> float:
  97. with torch.jit.fuser("fuser2"):
  98. return run_test(ir, inputs)