base_protocol.py 2.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. import asyncio
  2. from typing import Optional, cast
  3. from .tcp_helpers import tcp_nodelay
  4. class BaseProtocol(asyncio.Protocol):
  5. __slots__ = (
  6. "_loop",
  7. "_paused",
  8. "_drain_waiter",
  9. "_connection_lost",
  10. "_reading_paused",
  11. "transport",
  12. )
  13. def __init__(self, loop: asyncio.AbstractEventLoop) -> None:
  14. self._loop = loop # type: asyncio.AbstractEventLoop
  15. self._paused = False
  16. self._drain_waiter = None # type: Optional[asyncio.Future[None]]
  17. self._connection_lost = False
  18. self._reading_paused = False
  19. self.transport = None # type: Optional[asyncio.Transport]
  20. def pause_writing(self) -> None:
  21. assert not self._paused
  22. self._paused = True
  23. def resume_writing(self) -> None:
  24. assert self._paused
  25. self._paused = False
  26. waiter = self._drain_waiter
  27. if waiter is not None:
  28. self._drain_waiter = None
  29. if not waiter.done():
  30. waiter.set_result(None)
  31. def pause_reading(self) -> None:
  32. if not self._reading_paused and self.transport is not None:
  33. try:
  34. self.transport.pause_reading()
  35. except (AttributeError, NotImplementedError, RuntimeError):
  36. pass
  37. self._reading_paused = True
  38. def resume_reading(self) -> None:
  39. if self._reading_paused and self.transport is not None:
  40. try:
  41. self.transport.resume_reading()
  42. except (AttributeError, NotImplementedError, RuntimeError):
  43. pass
  44. self._reading_paused = False
  45. def connection_made(self, transport: asyncio.BaseTransport) -> None:
  46. tr = cast(asyncio.Transport, transport)
  47. tcp_nodelay(tr, True)
  48. self.transport = tr
  49. def connection_lost(self, exc: Optional[BaseException]) -> None:
  50. self._connection_lost = True
  51. # Wake up the writer if currently paused.
  52. self.transport = None
  53. if not self._paused:
  54. return
  55. waiter = self._drain_waiter
  56. if waiter is None:
  57. return
  58. self._drain_waiter = None
  59. if waiter.done():
  60. return
  61. if exc is None:
  62. waiter.set_result(None)
  63. else:
  64. waiter.set_exception(exc)
  65. async def _drain_helper(self) -> None:
  66. if self._connection_lost:
  67. raise ConnectionResetError("Connection lost")
  68. if not self._paused:
  69. return
  70. waiter = self._drain_waiter
  71. assert waiter is None or waiter.cancelled()
  72. waiter = self._loop.create_future()
  73. self._drain_waiter = waiter
  74. await waiter