_sockets.py 6.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194
  1. from __future__ import annotations
  2. import socket
  3. from abc import abstractmethod
  4. from collections.abc import Callable, Collection, Mapping
  5. from contextlib import AsyncExitStack
  6. from io import IOBase
  7. from ipaddress import IPv4Address, IPv6Address
  8. from socket import AddressFamily
  9. from types import TracebackType
  10. from typing import Any, Tuple, TypeVar, Union
  11. from .._core._typedattr import (
  12. TypedAttributeProvider,
  13. TypedAttributeSet,
  14. typed_attribute,
  15. )
  16. from ._streams import ByteStream, Listener, UnreliableObjectStream
  17. from ._tasks import TaskGroup
# Anything accepted where an IP address is expected: a string or an ipaddress object.
IPAddressType = Union[str, IPv4Address, IPv6Address]
# An IP socket address as a ``(host, port)`` pair.
IPSockAddrType = Tuple[str, int]
# A socket address: either an IP ``(host, port)`` pair or a plain string
# (e.g. the path used by the UNIX datagram sockets in this module).
SockAddrType = Union[IPSockAddrType, str]
# A UDP datagram as handled by UDPSocket: ``(payload, (host, port))``.
UDPPacketType = Tuple[bytes, IPSockAddrType]
# A UNIX datagram as handled by UNIXDatagramSocket: ``(payload, path)``.
UNIXDatagramPacketType = Tuple[bytes, str]
# Generic return-value type variable; not referenced in the code shown here —
# presumably used by importers of this module (verify before removing).
T_Retval = TypeVar("T_Retval")
  24. class _NullAsyncContextManager:
  25. async def __aenter__(self) -> None:
  26. pass
  27. async def __aexit__(
  28. self,
  29. exc_type: type[BaseException] | None,
  30. exc_val: BaseException | None,
  31. exc_tb: TracebackType | None,
  32. ) -> bool | None:
  33. return None
class SocketAttribute(TypedAttributeSet):
    """
    Typed attributes describing the socket behind a socket-based stream, listener or
    datagram socket.

    For the classes in this module that mix in ``_SocketProvider``, not every
    attribute is always present: ``remote_address`` is only provided when the
    socket has a peer, and ``local_port`` / ``remote_port`` are only provided for
    ``AF_INET`` / ``AF_INET6`` sockets (``remote_port`` additionally needs a peer).
    """

    #: the address family of the underlying socket
    family: AddressFamily = typed_attribute()
    #: the local socket address of the underlying socket
    local_address: SockAddrType = typed_attribute()
    #: for IP addresses, the local port the underlying socket is bound to
    local_port: int = typed_attribute()
    #: the underlying stdlib socket object
    raw_socket: socket.socket = typed_attribute()
    #: the remote address the underlying socket is connected to
    remote_address: SockAddrType = typed_attribute()
    #: for IP addresses, the remote port the underlying socket is connected to
    remote_port: int = typed_attribute()
  47. class _SocketProvider(TypedAttributeProvider):
  48. @property
  49. def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
  50. from .._core._sockets import convert_ipv6_sockaddr as convert
  51. attributes: dict[Any, Callable[[], Any]] = {
  52. SocketAttribute.family: lambda: self._raw_socket.family,
  53. SocketAttribute.local_address: lambda: convert(
  54. self._raw_socket.getsockname()
  55. ),
  56. SocketAttribute.raw_socket: lambda: self._raw_socket,
  57. }
  58. try:
  59. peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
  60. except OSError:
  61. peername = None
  62. # Provide the remote address for connected sockets
  63. if peername is not None:
  64. attributes[SocketAttribute.remote_address] = lambda: peername
  65. # Provide local and remote ports for IP based sockets
  66. if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
  67. attributes[SocketAttribute.local_port] = (
  68. lambda: self._raw_socket.getsockname()[1]
  69. )
  70. if peername is not None:
  71. remote_port = peername[1]
  72. attributes[SocketAttribute.remote_port] = lambda: remote_port
  73. return attributes
  74. @property
  75. @abstractmethod
  76. def _raw_socket(self) -> socket.socket:
  77. pass
  78. class SocketStream(ByteStream, _SocketProvider):
  79. """
  80. Transports bytes over a socket.
  81. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  82. """
  83. class UNIXSocketStream(SocketStream):
  84. @abstractmethod
  85. async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
  86. """
  87. Send file descriptors along with a message to the peer.
  88. :param message: a non-empty bytestring
  89. :param fds: a collection of files (either numeric file descriptors or open file
  90. or socket objects)
  91. """
  92. @abstractmethod
  93. async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
  94. """
  95. Receive file descriptors along with a message from the peer.
  96. :param msglen: length of the message to expect from the peer
  97. :param maxfds: maximum number of file descriptors to expect from the peer
  98. :return: a tuple of (message, file descriptors)
  99. """
  100. class SocketListener(Listener[SocketStream], _SocketProvider):
  101. """
  102. Listens to incoming socket connections.
  103. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  104. """
  105. @abstractmethod
  106. async def accept(self) -> SocketStream:
  107. """Accept an incoming connection."""
  108. async def serve(
  109. self,
  110. handler: Callable[[SocketStream], Any],
  111. task_group: TaskGroup | None = None,
  112. ) -> None:
  113. from .. import create_task_group
  114. async with AsyncExitStack() as stack:
  115. if task_group is None:
  116. task_group = await stack.enter_async_context(create_task_group())
  117. while True:
  118. stream = await self.accept()
  119. task_group.start_soon(handler, stream)
  120. class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
  121. """
  122. Represents an unconnected UDP socket.
  123. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  124. """
  125. async def sendto(self, data: bytes, host: str, port: int) -> None:
  126. """
  127. Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, (host, port))).
  128. """
  129. return await self.send((data, (host, port)))
  130. class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
  131. """
  132. Represents an connected UDP socket.
  133. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  134. """
  135. class UNIXDatagramSocket(
  136. UnreliableObjectStream[UNIXDatagramPacketType], _SocketProvider
  137. ):
  138. """
  139. Represents an unconnected Unix datagram socket.
  140. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  141. """
  142. async def sendto(self, data: bytes, path: str) -> None:
  143. """Alias for :meth:`~.UnreliableObjectSendStream.send` ((data, path))."""
  144. return await self.send((data, path))
  145. class ConnectedUNIXDatagramSocket(UnreliableObjectStream[bytes], _SocketProvider):
  146. """
  147. Represents a connected Unix datagram socket.
  148. Supports all relevant extra attributes from :class:`~SocketAttribute`.
  149. """