queue.py 1.4 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243
  1. # mypy: allow-untyped-defs
  2. import io
  3. import multiprocessing.queues
  4. import pickle
  5. from multiprocessing.reduction import ForkingPickler
  6. class ConnectionWrapper:
  7. """Proxy class for _multiprocessing.Connection which uses ForkingPickler for object serialization."""
  8. def __init__(self, conn):
  9. self.conn = conn
  10. def send(self, obj):
  11. buf = io.BytesIO()
  12. ForkingPickler(buf, pickle.HIGHEST_PROTOCOL).dump(obj)
  13. self.send_bytes(buf.getvalue())
  14. def recv(self):
  15. buf = self.recv_bytes()
  16. return pickle.loads(buf)
  17. def __getattr__(self, name):
  18. if "conn" in self.__dict__:
  19. return getattr(self.conn, name)
  20. raise AttributeError(f"'{type(self).__name__}' object has no attribute 'conn'")
  21. class Queue(multiprocessing.queues.Queue):
  22. def __init__(self, *args, **kwargs):
  23. super().__init__(*args, **kwargs)
  24. self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
  25. self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
  26. self._send = self._writer.send
  27. self._recv = self._reader.recv
  28. class SimpleQueue(multiprocessing.queues.SimpleQueue):
  29. def _make_methods(self):
  30. if not isinstance(self._reader, ConnectionWrapper):
  31. self._reader: ConnectionWrapper = ConnectionWrapper(self._reader)
  32. self._writer: ConnectionWrapper = ConnectionWrapper(self._writer)
  33. super()._make_methods() # type: ignore[misc]