fuzzer.py 2.6 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586
  1. # mypy: allow-untyped-defs
  2. """Example of the Timer and Fuzzer APIs:
  3. $ python -m examples.fuzzer
  4. """
  5. import sys
  6. import torch.utils.benchmark as benchmark_utils
  def main():
      """Fuzz ``x + y`` over random tensor shapes/layouts and print timings.

      Builds a seeded ``benchmark_utils.Fuzzer`` with:
        * three parameters ``k0``, ``k1``, ``k2`` drawn log-uniformly between
          16 and 16 * 1024,
        * a parameter ``d`` that is 2 with probability 0.6 and 3 with
          probability 0.4 (used as the tensors' ``dim_parameter``),
        * two tensors ``x`` and ``y`` sized from ``("k0", "k1", "k2")`` with
          between 64 * 1024 and 128 * 1024 elements.

      For each of ``n`` drawn configurations it times ``x + y`` with
      ``Timer.blocked_autorange(min_run_time=0.1)``, showing a progress counter.
      It then ranks all measurements by median time per element and prints the
      fastest and slowest ``top_k`` of them.
      """
      add_fuzzer = benchmark_utils.Fuzzer(
          parameters=[
              [
                  benchmark_utils.FuzzedParameter(
                      name=f"k{i}",
                      minval=16,
                      maxval=16 * 1024,
                      distribution="loguniform",
                  ) for i in range(3)
              ],
              # `d` is handed to each FuzzedTensor as `dim_parameter` below;
              # presumably it picks how many of k0..k2 a tensor uses (2 or 3).
              benchmark_utils.FuzzedParameter(
                  name="d",
                  distribution={2: 0.6, 3: 0.4},
              ),
          ],
          tensors=[
              [
                  benchmark_utils.FuzzedTensor(
                      name=name,
                      size=("k0", "k1", "k2"),
                      dim_parameter="d",
                      probability_contiguous=0.75,
                      min_elements=64 * 1024,
                      max_elements=128 * 1024,
                  ) for name in ("x", "y")
              ],
          ],
          seed=0,
      )

      n = 250
      top_k = 15  # rows printed in each of the "Best" / "Worst" tables
      measurements = []
      for i, (tensors, tensor_properties, _) in enumerate(add_fuzzer.take(n=n)):
          x, x_order = tensors["x"], str(tensor_properties["x"]["order"])
          y, y_order = tensors["y"], str(tensor_properties["y"]["order"])
          numel = x.numel()
          # `dim` (not `i`) so the comprehension does not appear to clobber
          # the progress counter printed at the end of this iteration.
          shape = ", ".join(f"{dim:>4}" for dim in x.shape)

          # Fixed-width row text; the fuzzer-reported order is shown only for
          # tensors that are not contiguous.
          description = "".join([
              f"{numel:>7} | {shape:<16} | ",
              f"{'contiguous' if x.is_contiguous() else x_order:<12} | ",
              f"{'contiguous' if y.is_contiguous() else y_order:<12} | ",
          ])

          # `tensors` serves as the timer's globals so `stmt` can name x and y.
          timer = benchmark_utils.Timer(
              stmt="x + y",
              globals=tensors,
              description=description,
          )

          measurement = timer.blocked_autorange(min_run_time=0.1)
          # Keep the element count so results can be normalized per element.
          measurement.metadata = {"numel": numel}
          measurements.append(measurement)

          # "\r" rewrites the same console line as a live progress counter.
          print(f"\r{i + 1} / {n}", end="")
          sys.stdout.flush()
      print()

      # More string munging to make pretty output.
      # NOTE(review): assumes rejection_rate is the fraction of fuzzer draws
      # that were rejected, so 1 / (1 - rate) is the mean attempts per config.
      print(f"Average attempts per valid config: {1. / (1. - add_fuzzer.rejection_rate):.1f}")

      def time_fn(m):
          """Return the median runtime of ``m`` divided by its element count."""
          return m.median / m.metadata["numel"]

      def print_rows(rows):
          """Print one line per measurement: ns per element, then its description."""
          for m in rows:
              # median is scaled by 1e9 to report nanoseconds per element.
              print(f"{time_fn(m) * 1e9:>4.1f} ns / element {m.description}")

      measurements.sort(key=time_fn)

      template = f"{{:>6}}{' ' * 19}Size Shape{' ' * 13}X order Y order\n{'-' * 80}"
      print(template.format("Best:"))
      print_rows(measurements[:top_k])

      print("\n" + template.format("Worst:"))
      print_rows(measurements[-top_k:])
  if __name__ == "__main__":
      # Run the benchmark only when executed as a script, not on import.
      main()