__init__.py 1.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566
  1. # mypy: allow-untyped-defs
  2. import sys
  3. import types
  4. from typing import List
  5. import torch
  6. # This function should correspond to the enums present in c10/core/QEngine.h
  7. def _get_qengine_id(qengine: str) -> int:
  8. if qengine == "none" or qengine == "" or qengine is None:
  9. ret = 0
  10. elif qengine == "fbgemm":
  11. ret = 1
  12. elif qengine == "qnnpack":
  13. ret = 2
  14. elif qengine == "onednn":
  15. ret = 3
  16. elif qengine == "x86":
  17. ret = 4
  18. else:
  19. ret = -1
  20. raise RuntimeError(f"{qengine} is not a valid value for quantized engine")
  21. return ret
  22. # This function should correspond to the enums present in c10/core/QEngine.h
  23. def _get_qengine_str(qengine: int) -> str:
  24. all_engines = {0: "none", 1: "fbgemm", 2: "qnnpack", 3: "onednn", 4: "x86"}
  25. return all_engines.get(qengine, "*undefined")
  26. class _QEngineProp:
  27. def __get__(self, obj, objtype) -> str:
  28. return _get_qengine_str(torch._C._get_qengine())
  29. def __set__(self, obj, val: str) -> None:
  30. torch._C._set_qengine(_get_qengine_id(val))
  31. class _SupportedQEnginesProp:
  32. def __get__(self, obj, objtype) -> List[str]:
  33. qengines = torch._C._supported_qengines()
  34. return [_get_qengine_str(qe) for qe in qengines]
  35. def __set__(self, obj, val) -> None:
  36. raise RuntimeError("Assignment not supported")
  37. class QuantizedEngine(types.ModuleType):
  38. def __init__(self, m, name):
  39. super().__init__(name)
  40. self.m = m
  41. def __getattr__(self, attr):
  42. return self.m.__getattribute__(attr)
  43. engine = _QEngineProp()
  44. supported_engines = _SupportedQEnginesProp()
  45. # This is the sys.modules replacement trick, see
  46. # https://stackoverflow.com/questions/2447353/getattr-on-a-module/7668273#7668273
  47. sys.modules[__name__] = QuantizedEngine(sys.modules[__name__], __name__)
  48. engine: str
  49. supported_engines: List[str]