common_mkldnn.py

# mypy: ignore-errors
import contextlib
import functools
import inspect

import torch

# Test whether the hardware BF32 math mode is enabled. It is enabled only
# when:
# - MKLDNN is available
# - BF16 is supported by MKLDNN
def bf32_is_not_fp32():
    if not torch.backends.mkldnn.is_available():
        return False
    if not torch.ops.mkldnn._is_mkldnn_bf16_supported():
        return False
    return True

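# Example (a hedged sketch, not part of the original file): when the check
# above returns True, a float32 matmul run under the "medium"
# float32_matmul_precision setting may differ from one run under "highest".
# The helper name and tensor sizes below are made up for illustration.
def _example_probe_bf32_effect():
    if not bf32_is_not_fp32():
        return None  # BF32 cannot diverge from FP32 on this machine
    a = torch.randn(256, 256)
    b = torch.randn(256, 256)
    old = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision("highest")
        exact = a @ b  # strict FP32 result
        torch.set_float32_matmul_precision("medium")
        approx = a @ b  # may be computed through BF16/BF32
    finally:
        torch.set_float32_matmul_precision(old)
    return (exact - approx).abs().max().item()
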
@contextlib.contextmanager
def bf32_off():
    old_matmul_precision = torch.get_float32_matmul_precision()
    try:
        torch.set_float32_matmul_precision("highest")
        yield
    finally:
        torch.set_float32_matmul_precision(old_matmul_precision)

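# Example (illustrative, not part of the original file): pin a reference
# computation to strict FP32 regardless of the global matmul-precision
# setting. The helper name is hypothetical.
def _example_reference_matmul(a, b):
    with bf32_off():
        # Inside this block float32 matmuls use "highest" precision, so
        # no BF16/BF32 down-conversion can occur.
        return a @ b
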
@contextlib.contextmanager
def bf32_on(self, bf32_precision=1e-5):
    old_matmul_precision = torch.get_float32_matmul_precision()
    old_precision = self.precision
    try:
        torch.set_float32_matmul_precision("medium")
        self.precision = bf32_precision
        yield
    finally:
        torch.set_float32_matmul_precision(old_matmul_precision)
        self.precision = old_precision

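# Example (a hedged sketch, not part of the original file): ``test_case``
# is assumed to expose a ``precision`` attribute, as
# torch.testing._internal.common_utils.TestCase does. The helper name and
# tensor shape are hypothetical.
def _example_run_under_bf32(test_case):
    with bf32_on(test_case, bf32_precision=1e-2):
        # Matmuls here may run in BF32 ("medium" precision), and the test
        # case's comparison tolerance is loosened to match.
        x = torch.randn(4, 4)
        return x @ x
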
# A decorator that wraps a test so it runs twice: once with BF32 disabled
# and once with BF32 enabled. When BF32 is enabled, the test uses the
# reduced tolerance specified by the bf32_precision argument.
def bf32_on_and_off(bf32_precision=1e-5):
    def with_bf32_disabled(self, function_call):
        with bf32_off():
            function_call()

    def with_bf32_enabled(self, function_call):
        with bf32_on(self, bf32_precision):
            function_call()

    def wrapper(f):
        params = inspect.signature(f).parameters
        arg_names = tuple(params.keys())

        @functools.wraps(f)
        def wrapped(*args, **kwargs):
            # Fold positional arguments into kwargs so the test can be
            # re-invoked uniformly via f(**kwargs).
            for k, v in zip(arg_names, args):
                kwargs[k] = v
            # Run both BF32 variants only where BF32 can actually change
            # results: on a CPU device with a float32 dtype.
            cond = bf32_is_not_fp32()
            if "device" in kwargs:
                cond = cond and (torch.device(kwargs["device"]).type == "cpu")
            if "dtype" in kwargs:
                cond = cond and (kwargs["dtype"] == torch.float)
            if cond:
                with_bf32_disabled(kwargs["self"], lambda: f(**kwargs))
                with_bf32_enabled(kwargs["self"], lambda: f(**kwargs))
            else:
                f(**kwargs)

        return wrapped

    return wrapper
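
# Example (an illustrative sketch, not part of the original file): in
# PyTorch the decorated function is normally a TestCase method whose
# ``device`` and ``dtype`` arguments are injected by the device-type test
# framework. The minimal class below is hypothetical and only mimics the
# ``precision`` attribute that bf32_on() relies on.
class _ExampleTest:
    precision = 1e-5  # tolerance attribute expected by bf32_on()

    @bf32_on_and_off(bf32_precision=1e-2)
    def test_mm(self, device, dtype):
        a = torch.randn(4, 4, device=device, dtype=dtype)
        b = torch.randn(4, 4, device=device, dtype=dtype)
        assert torch.allclose(a @ b, torch.mm(a, b), atol=self.precision)

# A call like _ExampleTest().test_mm(device="cpu", dtype=torch.float) runs
# the body twice on BF32-capable hardware (once per math mode) and once
# otherwise.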