store.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. from datetime import timedelta
  9. from typing import List
  10. from contextlib import contextmanager
# Suffix appended to a barrier's key prefix to form the key of the shared
# arrival counter (incremented once per member via ``store.add``).
_NUM_MEMBERS = "/num_members"
# Suffix appended to a barrier's key prefix to form the key that the member
# bringing the counter to ``world_size`` sets; members wait on this key.
_LAST_MEMBER_CHECKIN = "/last_member"
# Public API of this module.
__all__ = ["store_timeout", "get_all", "synchronize", "barrier"]
  14. @contextmanager
  15. def store_timeout(store, timeout: float):
  16. """
  17. This sets the timeout and then restores the old timeout when the context
  18. manager exits.
  19. Args:
  20. store: the store to set the timeout on
  21. timeout: the timeout to set
  22. """
  23. old_timeout = store.timeout
  24. store.set_timeout(timedelta(seconds=timeout))
  25. yield
  26. store.set_timeout(old_timeout)
  27. def get_all(store, rank: int, prefix: str, world_size: int):
  28. r"""
  29. Given a store and a prefix, the method goes through the array of keys
  30. of the following format: ``{prefix}{idx}``, where idx is in a range
  31. from 0 to size, and tries to retrieve the data.
  32. The Rank0 process waits at the end to make sure all other processes
  33. finished the procedure before exiting.
  34. Usage
  35. ::
  36. values = get_all(store, 'torchelastic/data', 3)
  37. value1 = values[0] # retrieves the data for key torchelastic/data0
  38. value2 = values[1] # retrieves the data for key torchelastic/data1
  39. value3 = values[2] # retrieves the data for key torchelastic/data2
  40. """
  41. data_arr = store.multi_get(
  42. [f"{prefix}{idx}" for idx in range(world_size)]
  43. )
  44. barrier_key = _barrier_nonblocking(
  45. store=store,
  46. world_size=world_size,
  47. key_prefix=f"{prefix}/finished",
  48. )
  49. if rank == 0:
  50. # Rank0 runs the TCPStore daemon, as a result it needs to exit last.
  51. # Otherwise, the barrier may timeout if rank0 process finished the work
  52. # before other processes finished `get_all` method
  53. store.get(barrier_key)
  54. return data_arr
  55. def synchronize(
  56. store,
  57. data: bytes,
  58. rank: int,
  59. world_size: int,
  60. key_prefix: str,
  61. timeout: float = 300,
  62. ) -> List[bytes]:
  63. """
  64. Synchronizes ``world_size`` agents between each other using the underlying c10d store.
  65. The ``data`` will be available on each of the agents.
  66. Note: The data on the path is not deleted, as a result there can be stale data if
  67. you use the same key_prefix twice.
  68. Time complexity: O(N) per worker, O(N^2) globally.
  69. """
  70. with store_timeout(store, timeout):
  71. store.set(f"{key_prefix}{rank}", data)
  72. agent_data = get_all(store, rank, key_prefix, world_size)
  73. return agent_data
  74. def _barrier_nonblocking(store, world_size: int, key_prefix: str) -> str:
  75. """
  76. Does all the non-blocking operations for a barrier and returns the final key
  77. that can be waited on.
  78. """
  79. num_members_key = key_prefix + _NUM_MEMBERS
  80. last_member_key = key_prefix + _LAST_MEMBER_CHECKIN
  81. idx = store.add(num_members_key, 1)
  82. if idx == world_size:
  83. store.set(last_member_key, "<val_ignored>")
  84. return last_member_key
  85. def barrier(
  86. store, world_size: int, key_prefix: str, barrier_timeout: float = 300
  87. ) -> None:
  88. """
  89. A global lock between agents. This will pause all workers until at least
  90. ``world_size`` workers respond.
  91. This uses a fast incrementing index to assign waiting ranks and a success
  92. flag set by the last worker.
  93. Time complexity: O(1) per worker, O(N) globally.
  94. Note: Since the data is not removed from the store, the barrier can be used
  95. once per unique ``key_prefix``.
  96. """
  97. with store_timeout(store, barrier_timeout):
  98. last_member_key = _barrier_nonblocking(store=store, world_size=world_size, key_prefix=key_prefix)
  99. store.get(last_member_key)