__init__.py 1.8 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556
  1. # mypy: allow-untyped-defs
  2. import threading
  3. import torch._C._lazy
  4. from torch.utils._pytree import tree_flatten, tree_unflatten
  5. from .closure import add_step_closure, run_step_closures
  6. def mark_step(device: str = "", wait=False):
  7. """Triggers a mark step, which amounts to
  8. - collecting a group of 'live' lazy tensors to index into the compilation cache
  9. (lowering/compiling their IR graphs if not cached)
  10. - kicking off execution of the compiled function
  11. - (optionally, wait=True) waiting for cpu-side execution to complete (does not sync the accelerator)
  12. """
  13. # TODO(whc) expand this to include backend hooks and align with XLA backend needs
  14. torch._C._lazy._mark_step(device, [], wait=wait)
  15. run_step_closures()
  16. def wait_device_ops(devices=None):
  17. """Waits for all the async operations on the given devices to complete.
  18. Args:
  19. devices (string..., optional): The devices whose async ops need to be waited
  20. for. If empty, all the local devices will be waited for.
  21. """
  22. if devices is None:
  23. devices = []
  24. torch._C._lazy._wait_device_ops(devices=devices)
  25. def sync_multi(tensors, devices):
  26. """
  27. Sync the list of lazy tensors so there IR get lowered for the activate backend
  28. and the compiled computation graph get cached.
  29. """
  30. torch._C._lazy._sync_multi(tensors, devices)
  31. def get_tensor_id(tensor):
  32. """Return a unique id of the lazy tensor maintained by LTC"""
  33. return torch._C._lazy._get_tensor_id(tensor)
  34. def to_cpu(tensors, devices=None):
  35. devices = devices or ["lazy"]
  36. flattened, spec = tree_flatten(tensors)
  37. sync_multi(flattened, devices)
  38. return tree_unflatten([t.to("cpu") for t in flattened], spec)
  39. def save(tensors, *args, **kwargs):
  40. torch.save(to_cpu(tensors), *args, **kwargs)