# op_benchmark.py (4.1 KB)

  1. # mypy: allow-untyped-defs
  2. """Example use of Timer and op fuzzers to measure kernel performance.
  3. $ python -m examples.op_benchmark
  4. """
  5. import numpy as np
  6. import torch
  7. from torch.utils.benchmark import Timer
  8. from torch.utils.benchmark.op_fuzzers.binary import BinaryOpFuzzer
  9. from torch.utils.benchmark.op_fuzzers.unary import UnaryOpFuzzer
  10. import operator
# Passed as `min_run_time` to `Timer.blocked_autorange` for every measurement
# taken in `run` (presumably seconds, per the PyTorch benchmark docs -- confirm).
_MEASURE_TIME = 1.0
  12. def assert_dicts_equal(dict_0, dict_1):
  13. """Builtin dict comparison will not compare numpy arrays.
  14. e.g.
  15. x = {"a": np.ones((2, 1))}
  16. x == x # Raises ValueError
  17. """
  18. assert set(dict_0.keys()) == set(dict_0.keys())
  19. assert all(np.all(v == dict_1[k]) for k, v in dict_0.items() if k != "dtype")
  20. def run(n, stmt, fuzzer_cls):
  21. float_iter = fuzzer_cls(seed=0, dtype=torch.float32).take(n)
  22. int_iter = fuzzer_cls(seed=0, dtype=torch.int32).take(n)
  23. raw_results = []
  24. for i, (float_values, int_values) in enumerate(zip(float_iter, int_iter)):
  25. float_tensors, float_tensor_params, float_params = float_values
  26. int_tensors, int_tensor_params, int_params = int_values
  27. # This benchmark assumes that the two fuzzers generate identically
  28. # sized and strided Tensors, since the same seed is used.
  29. assert_dicts_equal(float_params, int_params)
  30. assert_dicts_equal(float_tensor_params["x"], int_tensor_params["x"])
  31. float_measurement, int_measurement = (
  32. Timer(
  33. stmt,
  34. globals=tensors,
  35. ).blocked_autorange(min_run_time=_MEASURE_TIME)
  36. for tensors in (float_tensors, int_tensors)
  37. )
  38. descriptions = []
  39. for name in float_tensors:
  40. shape_str = "(" + ", ".join([
  41. f"2 ** {int(np.log2(i))}"
  42. if 2 ** int(np.log2(i)) == i and i > 1
  43. else str(i)
  44. for i in float_tensors[name].shape
  45. ]) + ")"
  46. order = float_tensor_params[name]["order"]
  47. order_str = ("" if all(order == np.arange(len(order))) else str(tuple(order)))
  48. steps = float_tensor_params[name]["steps"]
  49. steps_str = str(steps) if sum(steps) > len(steps) else ""
  50. descriptions.append((name, shape_str, order_str, steps_str))
  51. raw_results.append((float_measurement, int_measurement, descriptions))
  52. print(f"\r{i + 1} / {n}", end="")
  53. print()
  54. parsed_results, name_len, shape_len, order_len, steps_len = [], 0, 0, 0, 0
  55. for float_measurement, int_measurement, descriptions in raw_results:
  56. t_float = float_measurement.median * 1e6
  57. t_int = int_measurement.median * 1e6
  58. rel_diff = abs(t_float - t_int) / (t_float + t_int) * 2
  59. parsed_results.append((t_float, t_int, rel_diff, descriptions))
  60. for name, shape, order, steps in descriptions:
  61. name_len = max(name_len, len(name))
  62. shape_len = max(shape_len, len(shape))
  63. order_len = max(order_len, len(order))
  64. steps_len = max(steps_len, len(steps))
  65. parsed_results.sort(key=operator.itemgetter(2))
  66. print(f"stmt: {stmt}")
  67. print(f" diff faster{'':>17}{' ' * name_len} ", end="")
  68. print(f"{'shape'.ljust(shape_len)}{'':>16}{'order'.ljust(order_len)}", end="")
  69. print(f" steps\n{'-' * 100}")
  70. for results, spacer in [(parsed_results[:10], "..."), (parsed_results[-10:], "")]:
  71. for t_float, t_int, rel_diff, descriptions in results:
  72. time_str = [f"{rel_diff * 100:>4.1f}% {'int' if t_int < t_float else 'float':<20}"]
  73. time_str.extend(["".ljust(len(time_str[0])) for _ in descriptions[:-1]])
  74. for t_str, (name, shape, order, steps) in zip(time_str, descriptions):
  75. name = f"{name}:".ljust(name_len + 1)
  76. shape = shape.ljust(shape_len + 10)
  77. order = order.ljust(order_len)
  78. print(f"{t_str} {name} {shape}| {order} | {steps}")
  79. print(spacer)
  80. def main():
  81. run(n=100, stmt="torch.median(x, dim=0)", fuzzer_cls=UnaryOpFuzzer)
  82. run(n=100, stmt="torch.square(x)", fuzzer_cls=UnaryOpFuzzer)
  83. run(n=100, stmt="x + y", fuzzer_cls=BinaryOpFuzzer)
  84. if __name__ == "__main__":
  85. main()