device_context.py 661 B

1234567891011121314151617181920212223242526
  1. # mypy: allow-untyped-defs
  2. import threading
  3. from typing import Any, Dict
  4. import torch._C._lazy
  5. class DeviceContext:
  6. _CONTEXTS: Dict[str, Any] = dict()
  7. _CONTEXTS_LOCK = threading.Lock()
  8. def __init__(self, device):
  9. self.device = device
  10. def get_device_context(device=None):
  11. if device is None:
  12. device = torch._C._lazy._get_default_device_type()
  13. else:
  14. device = str(device)
  15. with DeviceContext._CONTEXTS_LOCK:
  16. devctx = DeviceContext._CONTEXTS.get(device, None)
  17. if devctx is None:
  18. devctx = DeviceContext(device)
  19. DeviceContext._CONTEXTS[device] = devctx
  20. return devctx