websockets.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195
  1. from __future__ import annotations
  2. import enum
  3. import json
  4. import typing
  5. from starlette.requests import HTTPConnection
  6. from starlette.responses import Response
  7. from starlette.types import Message, Receive, Scope, Send
  8. class WebSocketState(enum.Enum):
  9. CONNECTING = 0
  10. CONNECTED = 1
  11. DISCONNECTED = 2
  12. RESPONSE = 3
  13. class WebSocketDisconnect(Exception):
  14. def __init__(self, code: int = 1000, reason: str | None = None) -> None:
  15. self.code = code
  16. self.reason = reason or ""
  17. class WebSocket(HTTPConnection):
  18. def __init__(self, scope: Scope, receive: Receive, send: Send) -> None:
  19. super().__init__(scope)
  20. assert scope["type"] == "websocket"
  21. self._receive = receive
  22. self._send = send
  23. self.client_state = WebSocketState.CONNECTING
  24. self.application_state = WebSocketState.CONNECTING
  25. async def receive(self) -> Message:
  26. """
  27. Receive ASGI websocket messages, ensuring valid state transitions.
  28. """
  29. if self.client_state == WebSocketState.CONNECTING:
  30. message = await self._receive()
  31. message_type = message["type"]
  32. if message_type != "websocket.connect":
  33. raise RuntimeError(f'Expected ASGI message "websocket.connect", but got {message_type!r}')
  34. self.client_state = WebSocketState.CONNECTED
  35. return message
  36. elif self.client_state == WebSocketState.CONNECTED:
  37. message = await self._receive()
  38. message_type = message["type"]
  39. if message_type not in {"websocket.receive", "websocket.disconnect"}:
  40. raise RuntimeError(
  41. f'Expected ASGI message "websocket.receive" or "websocket.disconnect", but got {message_type!r}'
  42. )
  43. if message_type == "websocket.disconnect":
  44. self.client_state = WebSocketState.DISCONNECTED
  45. return message
  46. else:
  47. raise RuntimeError('Cannot call "receive" once a disconnect message has been received.')
  48. async def send(self, message: Message) -> None:
  49. """
  50. Send ASGI websocket messages, ensuring valid state transitions.
  51. """
  52. if self.application_state == WebSocketState.CONNECTING:
  53. message_type = message["type"]
  54. if message_type not in {"websocket.accept", "websocket.close", "websocket.http.response.start"}:
  55. raise RuntimeError(
  56. 'Expected ASGI message "websocket.accept", "websocket.close" or "websocket.http.response.start", '
  57. f"but got {message_type!r}"
  58. )
  59. if message_type == "websocket.close":
  60. self.application_state = WebSocketState.DISCONNECTED
  61. elif message_type == "websocket.http.response.start":
  62. self.application_state = WebSocketState.RESPONSE
  63. else:
  64. self.application_state = WebSocketState.CONNECTED
  65. await self._send(message)
  66. elif self.application_state == WebSocketState.CONNECTED:
  67. message_type = message["type"]
  68. if message_type not in {"websocket.send", "websocket.close"}:
  69. raise RuntimeError(
  70. f'Expected ASGI message "websocket.send" or "websocket.close", but got {message_type!r}'
  71. )
  72. if message_type == "websocket.close":
  73. self.application_state = WebSocketState.DISCONNECTED
  74. try:
  75. await self._send(message)
  76. except OSError:
  77. self.application_state = WebSocketState.DISCONNECTED
  78. raise WebSocketDisconnect(code=1006)
  79. elif self.application_state == WebSocketState.RESPONSE:
  80. message_type = message["type"]
  81. if message_type != "websocket.http.response.body":
  82. raise RuntimeError(f'Expected ASGI message "websocket.http.response.body", but got {message_type!r}')
  83. if not message.get("more_body", False):
  84. self.application_state = WebSocketState.DISCONNECTED
  85. await self._send(message)
  86. else:
  87. raise RuntimeError('Cannot call "send" once a close message has been sent.')
  88. async def accept(
  89. self,
  90. subprotocol: str | None = None,
  91. headers: typing.Iterable[tuple[bytes, bytes]] | None = None,
  92. ) -> None:
  93. headers = headers or []
  94. if self.client_state == WebSocketState.CONNECTING: # pragma: no branch
  95. # If we haven't yet seen the 'connect' message, then wait for it first.
  96. await self.receive()
  97. await self.send({"type": "websocket.accept", "subprotocol": subprotocol, "headers": headers})
  98. def _raise_on_disconnect(self, message: Message) -> None:
  99. if message["type"] == "websocket.disconnect":
  100. raise WebSocketDisconnect(message["code"], message.get("reason"))
  101. async def receive_text(self) -> str:
  102. if self.application_state != WebSocketState.CONNECTED:
  103. raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
  104. message = await self.receive()
  105. self._raise_on_disconnect(message)
  106. return typing.cast(str, message["text"])
  107. async def receive_bytes(self) -> bytes:
  108. if self.application_state != WebSocketState.CONNECTED:
  109. raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
  110. message = await self.receive()
  111. self._raise_on_disconnect(message)
  112. return typing.cast(bytes, message["bytes"])
  113. async def receive_json(self, mode: str = "text") -> typing.Any:
  114. if mode not in {"text", "binary"}:
  115. raise RuntimeError('The "mode" argument should be "text" or "binary".')
  116. if self.application_state != WebSocketState.CONNECTED:
  117. raise RuntimeError('WebSocket is not connected. Need to call "accept" first.')
  118. message = await self.receive()
  119. self._raise_on_disconnect(message)
  120. if mode == "text":
  121. text = message["text"]
  122. else:
  123. text = message["bytes"].decode("utf-8")
  124. return json.loads(text)
  125. async def iter_text(self) -> typing.AsyncIterator[str]:
  126. try:
  127. while True:
  128. yield await self.receive_text()
  129. except WebSocketDisconnect:
  130. pass
  131. async def iter_bytes(self) -> typing.AsyncIterator[bytes]:
  132. try:
  133. while True:
  134. yield await self.receive_bytes()
  135. except WebSocketDisconnect:
  136. pass
  137. async def iter_json(self) -> typing.AsyncIterator[typing.Any]:
  138. try:
  139. while True:
  140. yield await self.receive_json()
  141. except WebSocketDisconnect:
  142. pass
  143. async def send_text(self, data: str) -> None:
  144. await self.send({"type": "websocket.send", "text": data})
  145. async def send_bytes(self, data: bytes) -> None:
  146. await self.send({"type": "websocket.send", "bytes": data})
  147. async def send_json(self, data: typing.Any, mode: str = "text") -> None:
  148. if mode not in {"text", "binary"}:
  149. raise RuntimeError('The "mode" argument should be "text" or "binary".')
  150. text = json.dumps(data, separators=(",", ":"), ensure_ascii=False)
  151. if mode == "text":
  152. await self.send({"type": "websocket.send", "text": text})
  153. else:
  154. await self.send({"type": "websocket.send", "bytes": text.encode("utf-8")})
  155. async def close(self, code: int = 1000, reason: str | None = None) -> None:
  156. await self.send({"type": "websocket.close", "code": code, "reason": reason or ""})
  157. async def send_denial_response(self, response: Response) -> None:
  158. if "websocket.http.response" in self.scope.get("extensions", {}):
  159. await response(self.scope, self.receive, self.send)
  160. else:
  161. raise RuntimeError("The server doesn't support the Websocket Denial Response extension.")
  162. class WebSocketClose:
  163. def __init__(self, code: int = 1000, reason: str | None = None) -> None:
  164. self.code = code
  165. self.reason = reason or ""
  166. async def __call__(self, scope: Scope, receive: Receive, send: Send) -> None:
  167. await send({"type": "websocket.close", "code": self.code, "reason": self.reason})