etcd_server.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247
  1. #!/usr/bin/env python3
  2. # mypy: allow-untyped-defs
  3. # Copyright (c) Facebook, Inc. and its affiliates.
  4. # All rights reserved.
  5. #
  6. # This source code is licensed under the BSD-style license found in the
  7. # LICENSE file in the root directory of this source tree.
  8. import atexit
  9. import logging
  10. import os
  11. import shlex
  12. import shutil
  13. import socket
  14. import subprocess
  15. import tempfile
  16. import time
  17. from typing import Optional, TextIO, Union
  18. try:
  19. import etcd # type: ignore[import]
  20. except ModuleNotFoundError:
  21. pass
  22. logger = logging.getLogger(__name__)
  23. def find_free_port():
  24. """
  25. Find a free port and binds a temporary socket to it so that the port can be "reserved" until used.
  26. .. note:: the returned socket must be closed before using the port,
  27. otherwise a ``address already in use`` error will happen.
  28. The socket should be held and closed as close to the
  29. consumer of the port as possible since otherwise, there
  30. is a greater chance of race-condition where a different
  31. process may see the port as being free and take it.
  32. Returns: a socket binded to the reserved free port
  33. Usage::
  34. sock = find_free_port()
  35. port = sock.getsockname()[1]
  36. sock.close()
  37. use_port(port)
  38. """
  39. addrs = socket.getaddrinfo(
  40. host="localhost", port=None, family=socket.AF_UNSPEC, type=socket.SOCK_STREAM
  41. )
  42. for addr in addrs:
  43. family, type, proto, _, _ = addr
  44. try:
  45. s = socket.socket(family, type, proto)
  46. s.bind(("localhost", 0))
  47. s.listen(0)
  48. return s
  49. except OSError as e:
  50. s.close() # type: ignore[possibly-undefined]
  51. print(f"Socket creation attempt failed: {e}")
  52. raise RuntimeError("Failed to create a socket")
  53. def stop_etcd(subprocess, data_dir: Optional[str] = None):
  54. if subprocess and subprocess.poll() is None:
  55. logger.info("stopping etcd server")
  56. subprocess.terminate()
  57. subprocess.wait()
  58. if data_dir:
  59. logger.info("deleting etcd data dir: %s", data_dir)
  60. shutil.rmtree(data_dir, ignore_errors=True)
  61. class EtcdServer:
  62. """
  63. .. note:: tested on etcd server v3.4.3.
  64. Starts and stops a local standalone etcd server on a random free
  65. port. Useful for single node, multi-worker launches or testing,
  66. where a sidecar etcd server is more convenient than having to
  67. separately setup an etcd server.
  68. This class registers a termination handler to shutdown the etcd
  69. subprocess on exit. This termination handler is NOT a substitute for
  70. calling the ``stop()`` method.
  71. The following fallback mechanism is used to find the etcd binary:
  72. 1. Uses env var TORCHELASTIC_ETCD_BINARY_PATH
  73. 2. Uses ``<this file root>/bin/etcd`` if one exists
  74. 3. Uses ``etcd`` from ``PATH``
  75. Usage
  76. ::
  77. server = EtcdServer("/usr/bin/etcd", 2379, "/tmp/default.etcd")
  78. server.start()
  79. client = server.get_client()
  80. # use client
  81. server.stop()
  82. Args:
  83. etcd_binary_path: path of etcd server binary (see above for fallback path)
  84. """
  85. def __init__(self, data_dir: Optional[str] = None):
  86. self._port = -1
  87. self._host = "localhost"
  88. root = os.path.dirname(__file__)
  89. default_etcd_bin = os.path.join(root, "bin/etcd")
  90. self._etcd_binary_path = os.environ.get(
  91. "TORCHELASTIC_ETCD_BINARY_PATH", default_etcd_bin
  92. )
  93. if not os.path.isfile(self._etcd_binary_path):
  94. self._etcd_binary_path = "etcd"
  95. self._base_data_dir = (
  96. data_dir if data_dir else tempfile.mkdtemp(prefix="torchelastic_etcd_data")
  97. )
  98. self._etcd_cmd = None
  99. self._etcd_proc: Optional[subprocess.Popen] = None
  100. def _get_etcd_server_process(self) -> subprocess.Popen:
  101. if not self._etcd_proc:
  102. raise RuntimeError(
  103. "No etcd server process started. Call etcd_server.start() first"
  104. )
  105. else:
  106. return self._etcd_proc
  107. def get_port(self) -> int:
  108. """Return the port the server is running on."""
  109. return self._port
  110. def get_host(self) -> str:
  111. """Return the host the server is running on."""
  112. return self._host
  113. def get_endpoint(self) -> str:
  114. """Return the etcd server endpoint (host:port)."""
  115. return f"{self._host}:{self._port}"
  116. def start(
  117. self,
  118. timeout: int = 60,
  119. num_retries: int = 3,
  120. stderr: Union[int, TextIO, None] = None,
  121. ) -> None:
  122. """
  123. Start the server, and waits for it to be ready. When this function returns the sever is ready to take requests.
  124. Args:
  125. timeout: time (in seconds) to wait for the server to be ready
  126. before giving up.
  127. num_retries: number of retries to start the server. Each retry
  128. will wait for max ``timeout`` before considering it as failed.
  129. stderr: the standard error file handle. Valid values are
  130. `subprocess.PIPE`, `subprocess.DEVNULL`, an existing file
  131. descriptor (a positive integer), an existing file object, and
  132. `None`.
  133. Raises:
  134. TimeoutError: if the server is not ready within the specified timeout
  135. """
  136. curr_retries = 0
  137. while True:
  138. try:
  139. data_dir = os.path.join(self._base_data_dir, str(curr_retries))
  140. os.makedirs(data_dir, exist_ok=True)
  141. return self._start(data_dir, timeout, stderr)
  142. except Exception as e:
  143. curr_retries += 1
  144. stop_etcd(self._etcd_proc)
  145. logger.warning(
  146. "Failed to start etcd server, got error: %s, retrying", str(e)
  147. )
  148. if curr_retries >= num_retries:
  149. shutil.rmtree(self._base_data_dir, ignore_errors=True)
  150. raise
  151. atexit.register(stop_etcd, self._etcd_proc, self._base_data_dir)
  152. def _start(
  153. self, data_dir: str, timeout: int = 60, stderr: Union[int, TextIO, None] = None
  154. ) -> None:
  155. sock = find_free_port()
  156. sock_peer = find_free_port()
  157. self._port = sock.getsockname()[1]
  158. peer_port = sock_peer.getsockname()[1]
  159. etcd_cmd = shlex.split(
  160. " ".join(
  161. [
  162. self._etcd_binary_path,
  163. "--enable-v2",
  164. "--data-dir",
  165. data_dir,
  166. "--listen-client-urls",
  167. f"http://{self._host}:{self._port}",
  168. "--advertise-client-urls",
  169. f"http://{self._host}:{self._port}",
  170. "--listen-peer-urls",
  171. f"http://{self._host}:{peer_port}",
  172. ]
  173. )
  174. )
  175. logger.info("Starting etcd server: [%s]", etcd_cmd)
  176. sock.close()
  177. sock_peer.close()
  178. self._etcd_proc = subprocess.Popen(etcd_cmd, close_fds=True, stderr=stderr)
  179. self._wait_for_ready(timeout)
  180. def get_client(self):
  181. """Return an etcd client object that can be used to make requests to this server."""
  182. return etcd.Client(
  183. host=self._host, port=self._port, version_prefix="/v2", read_timeout=10
  184. )
  185. def _wait_for_ready(self, timeout: int = 60) -> None:
  186. client = etcd.Client(
  187. host=f"{self._host}", port=self._port, version_prefix="/v2", read_timeout=5
  188. )
  189. max_time = time.time() + timeout
  190. while time.time() < max_time:
  191. if self._get_etcd_server_process().poll() is not None:
  192. # etcd server process finished
  193. exitcode = self._get_etcd_server_process().returncode
  194. raise RuntimeError(
  195. f"Etcd server process exited with the code: {exitcode}"
  196. )
  197. try:
  198. logger.info("etcd server ready. version: %s", client.version)
  199. return
  200. except Exception:
  201. time.sleep(1)
  202. raise TimeoutError("Timed out waiting for etcd server to be ready!")
  203. def stop(self) -> None:
  204. """Stop the server and cleans up auto generated resources (e.g. data dir)."""
  205. logger.info("EtcdServer stop method called")
  206. stop_etcd(self._etcd_proc, self._base_data_dir)