_utils.py 1.5 KB

1234567891011121314151617181920212223242526272829303132333435363738
  1. # mypy: allow-untyped-defs
  2. from contextlib import contextmanager
  3. from typing import cast
  4. import logging
  5. from . import api
  6. from . import TensorPipeAgent
# Module-scoped logger, named after this module via ``__name__``.
logger = logging.getLogger(__name__)
  8. @contextmanager
  9. def _group_membership_management(store, name, is_join):
  10. token_key = "RpcGroupManagementToken"
  11. join_or_leave = "join" if is_join else "leave"
  12. my_token = f"Token_for_{name}_{join_or_leave}"
  13. while True:
  14. # Retrieve token from store to signal start of rank join/leave critical section
  15. returned = store.compare_set(token_key, "", my_token).decode()
  16. if returned == my_token:
  17. # Yield to the function this context manager wraps
  18. yield
  19. # Finished, now exit and release token
  20. # Update from store to signal end of rank join/leave critical section
  21. store.set(token_key, "")
  22. # Other will wait for this token to be set before they execute
  23. store.set(my_token, "Done")
  24. break
  25. else:
  26. # Store will wait for the token to be released
  27. try:
  28. store.wait([returned])
  29. except RuntimeError:
  30. logger.error("Group membership token %s timed out waiting for %s to be released.", my_token, returned)
  31. raise
  32. def _update_group_membership(worker_info, my_devices, reverse_device_map, is_join):
  33. agent = cast(TensorPipeAgent, api._get_current_rpc_agent())
  34. ret = agent._update_group_membership(worker_info, my_devices, reverse_device_map, is_join)
  35. return ret