http_writer.py 5.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182
  1. """Http related parsers and protocol."""
  2. import asyncio
  3. import collections
  4. import zlib
  5. from typing import Any, Awaitable, Callable, Optional, Union # noqa
  6. from multidict import CIMultiDict
  7. from .abc import AbstractStreamWriter
  8. from .base_protocol import BaseProtocol
  9. from .helpers import NO_EXTENSIONS
  10. __all__ = ("StreamWriter", "HttpVersion", "HttpVersion10", "HttpVersion11")
  11. HttpVersion = collections.namedtuple("HttpVersion", ["major", "minor"])
  12. HttpVersion10 = HttpVersion(1, 0)
  13. HttpVersion11 = HttpVersion(1, 1)
  14. _T_OnChunkSent = Optional[Callable[[bytes], Awaitable[None]]]
  15. class StreamWriter(AbstractStreamWriter):
  16. def __init__(
  17. self,
  18. protocol: BaseProtocol,
  19. loop: asyncio.AbstractEventLoop,
  20. on_chunk_sent: _T_OnChunkSent = None,
  21. ) -> None:
  22. self._protocol = protocol
  23. self._transport = protocol.transport
  24. self.loop = loop
  25. self.length = None
  26. self.chunked = False
  27. self.buffer_size = 0
  28. self.output_size = 0
  29. self._eof = False
  30. self._compress = None # type: Any
  31. self._drain_waiter = None
  32. self._on_chunk_sent = on_chunk_sent # type: _T_OnChunkSent
  33. @property
  34. def transport(self) -> Optional[asyncio.Transport]:
  35. return self._transport
  36. @property
  37. def protocol(self) -> BaseProtocol:
  38. return self._protocol
  39. def enable_chunking(self) -> None:
  40. self.chunked = True
  41. def enable_compression(self, encoding: str = "deflate") -> None:
  42. zlib_mode = 16 + zlib.MAX_WBITS if encoding == "gzip" else zlib.MAX_WBITS
  43. self._compress = zlib.compressobj(wbits=zlib_mode)
  44. def _write(self, chunk: bytes) -> None:
  45. size = len(chunk)
  46. self.buffer_size += size
  47. self.output_size += size
  48. if self._transport is None or self._transport.is_closing():
  49. raise ConnectionResetError("Cannot write to closing transport")
  50. self._transport.write(chunk)
  51. async def write(
  52. self, chunk: bytes, *, drain: bool = True, LIMIT: int = 0x10000
  53. ) -> None:
  54. """Writes chunk of data to a stream.
  55. write_eof() indicates end of stream.
  56. writer can't be used after write_eof() method being called.
  57. write() return drain future.
  58. """
  59. if self._on_chunk_sent is not None:
  60. await self._on_chunk_sent(chunk)
  61. if isinstance(chunk, memoryview):
  62. if chunk.nbytes != len(chunk):
  63. # just reshape it
  64. chunk = chunk.cast("c")
  65. if self._compress is not None:
  66. chunk = self._compress.compress(chunk)
  67. if not chunk:
  68. return
  69. if self.length is not None:
  70. chunk_len = len(chunk)
  71. if self.length >= chunk_len:
  72. self.length = self.length - chunk_len
  73. else:
  74. chunk = chunk[: self.length]
  75. self.length = 0
  76. if not chunk:
  77. return
  78. if chunk:
  79. if self.chunked:
  80. chunk_len_pre = ("%x\r\n" % len(chunk)).encode("ascii")
  81. chunk = chunk_len_pre + chunk + b"\r\n"
  82. self._write(chunk)
  83. if self.buffer_size > LIMIT and drain:
  84. self.buffer_size = 0
  85. await self.drain()
  86. async def write_headers(
  87. self, status_line: str, headers: "CIMultiDict[str]"
  88. ) -> None:
  89. """Write request/response status and headers."""
  90. # status + headers
  91. buf = _serialize_headers(status_line, headers)
  92. self._write(buf)
  93. async def write_eof(self, chunk: bytes = b"") -> None:
  94. if self._eof:
  95. return
  96. if chunk and self._on_chunk_sent is not None:
  97. await self._on_chunk_sent(chunk)
  98. if self._compress:
  99. if chunk:
  100. chunk = self._compress.compress(chunk)
  101. chunk = chunk + self._compress.flush()
  102. if chunk and self.chunked:
  103. chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
  104. chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
  105. else:
  106. if self.chunked:
  107. if chunk:
  108. chunk_len = ("%x\r\n" % len(chunk)).encode("ascii")
  109. chunk = chunk_len + chunk + b"\r\n0\r\n\r\n"
  110. else:
  111. chunk = b"0\r\n\r\n"
  112. if chunk:
  113. self._write(chunk)
  114. await self.drain()
  115. self._eof = True
  116. self._transport = None
  117. async def drain(self) -> None:
  118. """Flush the write buffer.
  119. The intended use is to write
  120. await w.write(data)
  121. await w.drain()
  122. """
  123. if self._protocol.transport is not None:
  124. await self._protocol._drain_helper()
  125. def _py_serialize_headers(status_line: str, headers: "CIMultiDict[str]") -> bytes:
  126. line = (
  127. status_line
  128. + "\r\n"
  129. + "".join([k + ": " + v + "\r\n" for k, v in headers.items()])
  130. )
  131. return line.encode("utf-8") + b"\r\n"
  132. _serialize_headers = _py_serialize_headers
  133. try:
  134. import aiohttp._http_writer as _http_writer # type: ignore
  135. _c_serialize_headers = _http_writer._serialize_headers
  136. if not NO_EXTENSIONS:
  137. _serialize_headers = _c_serialize_headers
  138. except ImportError:
  139. pass