# compare.py
  1. # mypy: allow-untyped-defs
  2. """Example of Timer and Compare APIs:
  3. $ python -m examples.compare
  4. """
  5. import pickle
  6. import sys
  7. import time
  8. import torch
  9. import torch.utils.benchmark as benchmark_utils
  10. class FauxTorch:
  11. """Emulate different versions of pytorch.
  12. In normal circumstances this would be done with multiple processes
  13. writing serialized measurements, but this simplifies that model to
  14. make the example clearer.
  15. """
  16. def __init__(self, real_torch, extra_ns_per_element):
  17. self._real_torch = real_torch
  18. self._extra_ns_per_element = extra_ns_per_element
  19. def extra_overhead(self, result):
  20. # time.sleep has a ~65 us overhead, so only fake a
  21. # per-element overhead if numel is large enough.
  22. numel = int(result.numel())
  23. if numel > 5000:
  24. time.sleep(numel * self._extra_ns_per_element * 1e-9)
  25. return result
  26. def add(self, *args, **kwargs):
  27. return self.extra_overhead(self._real_torch.add(*args, **kwargs))
  28. def mul(self, *args, **kwargs):
  29. return self.extra_overhead(self._real_torch.mul(*args, **kwargs))
  30. def cat(self, *args, **kwargs):
  31. return self.extra_overhead(self._real_torch.cat(*args, **kwargs))
  32. def matmul(self, *args, **kwargs):
  33. return self.extra_overhead(self._real_torch.matmul(*args, **kwargs))
  34. def main():
  35. tasks = [
  36. ("add", "add", "torch.add(x, y)"),
  37. ("add", "add (extra +0)", "torch.add(x, y + zero)"),
  38. ]
  39. serialized_results = []
  40. repeats = 2
  41. timers = [
  42. benchmark_utils.Timer(
  43. stmt=stmt,
  44. globals={
  45. "torch": torch if branch == "master" else FauxTorch(torch, overhead_ns),
  46. "x": torch.ones((size, 4)),
  47. "y": torch.ones((1, 4)),
  48. "zero": torch.zeros(()),
  49. },
  50. label=label,
  51. sub_label=sub_label,
  52. description=f"size: {size}",
  53. env=branch,
  54. num_threads=num_threads,
  55. )
  56. for branch, overhead_ns in [("master", None), ("my_branch", 1), ("severe_regression", 5)]
  57. for label, sub_label, stmt in tasks
  58. for size in [1, 10, 100, 1000, 10000, 50000]
  59. for num_threads in [1, 4]
  60. ]
  61. for i, timer in enumerate(timers * repeats):
  62. serialized_results.append(pickle.dumps(
  63. timer.blocked_autorange(min_run_time=0.05)
  64. ))
  65. print(f"\r{i + 1} / {len(timers) * repeats}", end="")
  66. sys.stdout.flush()
  67. print()
  68. comparison = benchmark_utils.Compare([
  69. pickle.loads(i) for i in serialized_results
  70. ])
  71. print("== Unformatted " + "=" * 80 + "\n" + "/" * 95 + "\n")
  72. comparison.print()
  73. print("== Formatted " + "=" * 80 + "\n" + "/" * 93 + "\n")
  74. comparison.trim_significant_figures()
  75. comparison.colorize()
  76. comparison.print()
  77. if __name__ == "__main__":
  78. main()