payload.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448
  1. import asyncio
  2. import enum
  3. import io
  4. import json
  5. import mimetypes
  6. import os
  7. import warnings
  8. from abc import ABC, abstractmethod
  9. from itertools import chain
  10. from typing import (
  11. IO,
  12. TYPE_CHECKING,
  13. Any,
  14. ByteString,
  15. Dict,
  16. Iterable,
  17. Optional,
  18. Text,
  19. TextIO,
  20. Tuple,
  21. Type,
  22. Union,
  23. )
  24. from multidict import CIMultiDict
  25. from . import hdrs
  26. from .abc import AbstractStreamWriter
  27. from .helpers import (
  28. PY_36,
  29. content_disposition_header,
  30. guess_filename,
  31. parse_mimetype,
  32. sentinel,
  33. )
  34. from .streams import StreamReader
  35. from .typedefs import JSONEncoder, _CIMultiDict
  36. __all__ = (
  37. "PAYLOAD_REGISTRY",
  38. "get_payload",
  39. "payload_type",
  40. "Payload",
  41. "BytesPayload",
  42. "StringPayload",
  43. "IOBasePayload",
  44. "BytesIOPayload",
  45. "BufferedReaderPayload",
  46. "TextIOPayload",
  47. "StringIOPayload",
  48. "JsonPayload",
  49. "AsyncIterablePayload",
  50. )
  51. TOO_LARGE_BYTES_BODY = 2 ** 20 # 1 MB
  52. if TYPE_CHECKING: # pragma: no cover
  53. from typing import List
  54. class LookupError(Exception):
  55. pass
  56. class Order(str, enum.Enum):
  57. normal = "normal"
  58. try_first = "try_first"
  59. try_last = "try_last"
  60. def get_payload(data: Any, *args: Any, **kwargs: Any) -> "Payload":
  61. return PAYLOAD_REGISTRY.get(data, *args, **kwargs)
  62. def register_payload(
  63. factory: Type["Payload"], type: Any, *, order: Order = Order.normal
  64. ) -> None:
  65. PAYLOAD_REGISTRY.register(factory, type, order=order)
  66. class payload_type:
  67. def __init__(self, type: Any, *, order: Order = Order.normal) -> None:
  68. self.type = type
  69. self.order = order
  70. def __call__(self, factory: Type["Payload"]) -> Type["Payload"]:
  71. register_payload(factory, self.type, order=self.order)
  72. return factory
  73. class PayloadRegistry:
  74. """Payload registry.
  75. note: we need zope.interface for more efficient adapter search
  76. """
  77. def __init__(self) -> None:
  78. self._first = [] # type: List[Tuple[Type[Payload], Any]]
  79. self._normal = [] # type: List[Tuple[Type[Payload], Any]]
  80. self._last = [] # type: List[Tuple[Type[Payload], Any]]
  81. def get(
  82. self, data: Any, *args: Any, _CHAIN: Any = chain, **kwargs: Any
  83. ) -> "Payload":
  84. if isinstance(data, Payload):
  85. return data
  86. for factory, type in _CHAIN(self._first, self._normal, self._last):
  87. if isinstance(data, type):
  88. return factory(data, *args, **kwargs)
  89. raise LookupError()
  90. def register(
  91. self, factory: Type["Payload"], type: Any, *, order: Order = Order.normal
  92. ) -> None:
  93. if order is Order.try_first:
  94. self._first.append((factory, type))
  95. elif order is Order.normal:
  96. self._normal.append((factory, type))
  97. elif order is Order.try_last:
  98. self._last.append((factory, type))
  99. else:
  100. raise ValueError(f"Unsupported order {order!r}")
  101. class Payload(ABC):
  102. _default_content_type = "application/octet-stream" # type: str
  103. _size = None # type: Optional[int]
  104. def __init__(
  105. self,
  106. value: Any,
  107. headers: Optional[
  108. Union[_CIMultiDict, Dict[str, str], Iterable[Tuple[str, str]]]
  109. ] = None,
  110. content_type: Optional[str] = sentinel,
  111. filename: Optional[str] = None,
  112. encoding: Optional[str] = None,
  113. **kwargs: Any,
  114. ) -> None:
  115. self._encoding = encoding
  116. self._filename = filename
  117. self._headers = CIMultiDict() # type: _CIMultiDict
  118. self._value = value
  119. if content_type is not sentinel and content_type is not None:
  120. self._headers[hdrs.CONTENT_TYPE] = content_type
  121. elif self._filename is not None:
  122. content_type = mimetypes.guess_type(self._filename)[0]
  123. if content_type is None:
  124. content_type = self._default_content_type
  125. self._headers[hdrs.CONTENT_TYPE] = content_type
  126. else:
  127. self._headers[hdrs.CONTENT_TYPE] = self._default_content_type
  128. self._headers.update(headers or {})
  129. @property
  130. def size(self) -> Optional[int]:
  131. """Size of the payload."""
  132. return self._size
  133. @property
  134. def filename(self) -> Optional[str]:
  135. """Filename of the payload."""
  136. return self._filename
  137. @property
  138. def headers(self) -> _CIMultiDict:
  139. """Custom item headers"""
  140. return self._headers
  141. @property
  142. def _binary_headers(self) -> bytes:
  143. return (
  144. "".join([k + ": " + v + "\r\n" for k, v in self.headers.items()]).encode(
  145. "utf-8"
  146. )
  147. + b"\r\n"
  148. )
  149. @property
  150. def encoding(self) -> Optional[str]:
  151. """Payload encoding"""
  152. return self._encoding
  153. @property
  154. def content_type(self) -> str:
  155. """Content type"""
  156. return self._headers[hdrs.CONTENT_TYPE]
  157. def set_content_disposition(
  158. self, disptype: str, quote_fields: bool = True, **params: Any
  159. ) -> None:
  160. """Sets ``Content-Disposition`` header."""
  161. self._headers[hdrs.CONTENT_DISPOSITION] = content_disposition_header(
  162. disptype, quote_fields=quote_fields, **params
  163. )
  164. @abstractmethod
  165. async def write(self, writer: AbstractStreamWriter) -> None:
  166. """Write payload.
  167. writer is an AbstractStreamWriter instance:
  168. """
  169. class BytesPayload(Payload):
  170. def __init__(self, value: ByteString, *args: Any, **kwargs: Any) -> None:
  171. if not isinstance(value, (bytes, bytearray, memoryview)):
  172. raise TypeError(
  173. "value argument must be byte-ish, not {!r}".format(type(value))
  174. )
  175. if "content_type" not in kwargs:
  176. kwargs["content_type"] = "application/octet-stream"
  177. super().__init__(value, *args, **kwargs)
  178. if isinstance(value, memoryview):
  179. self._size = value.nbytes
  180. else:
  181. self._size = len(value)
  182. if self._size > TOO_LARGE_BYTES_BODY:
  183. if PY_36:
  184. kwargs = {"source": self}
  185. else:
  186. kwargs = {}
  187. warnings.warn(
  188. "Sending a large body directly with raw bytes might"
  189. " lock the event loop. You should probably pass an "
  190. "io.BytesIO object instead",
  191. ResourceWarning,
  192. **kwargs,
  193. )
  194. async def write(self, writer: AbstractStreamWriter) -> None:
  195. await writer.write(self._value)
  196. class StringPayload(BytesPayload):
  197. def __init__(
  198. self,
  199. value: Text,
  200. *args: Any,
  201. encoding: Optional[str] = None,
  202. content_type: Optional[str] = None,
  203. **kwargs: Any,
  204. ) -> None:
  205. if encoding is None:
  206. if content_type is None:
  207. real_encoding = "utf-8"
  208. content_type = "text/plain; charset=utf-8"
  209. else:
  210. mimetype = parse_mimetype(content_type)
  211. real_encoding = mimetype.parameters.get("charset", "utf-8")
  212. else:
  213. if content_type is None:
  214. content_type = "text/plain; charset=%s" % encoding
  215. real_encoding = encoding
  216. super().__init__(
  217. value.encode(real_encoding),
  218. encoding=real_encoding,
  219. content_type=content_type,
  220. *args,
  221. **kwargs,
  222. )
  223. class StringIOPayload(StringPayload):
  224. def __init__(self, value: IO[str], *args: Any, **kwargs: Any) -> None:
  225. super().__init__(value.read(), *args, **kwargs)
  226. class IOBasePayload(Payload):
  227. def __init__(
  228. self, value: IO[Any], disposition: str = "attachment", *args: Any, **kwargs: Any
  229. ) -> None:
  230. if "filename" not in kwargs:
  231. kwargs["filename"] = guess_filename(value)
  232. super().__init__(value, *args, **kwargs)
  233. if self._filename is not None and disposition is not None:
  234. if hdrs.CONTENT_DISPOSITION not in self.headers:
  235. self.set_content_disposition(disposition, filename=self._filename)
  236. async def write(self, writer: AbstractStreamWriter) -> None:
  237. loop = asyncio.get_event_loop()
  238. try:
  239. chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16)
  240. while chunk:
  241. await writer.write(chunk)
  242. chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16)
  243. finally:
  244. await loop.run_in_executor(None, self._value.close)
  245. class TextIOPayload(IOBasePayload):
  246. def __init__(
  247. self,
  248. value: TextIO,
  249. *args: Any,
  250. encoding: Optional[str] = None,
  251. content_type: Optional[str] = None,
  252. **kwargs: Any,
  253. ) -> None:
  254. if encoding is None:
  255. if content_type is None:
  256. encoding = "utf-8"
  257. content_type = "text/plain; charset=utf-8"
  258. else:
  259. mimetype = parse_mimetype(content_type)
  260. encoding = mimetype.parameters.get("charset", "utf-8")
  261. else:
  262. if content_type is None:
  263. content_type = "text/plain; charset=%s" % encoding
  264. super().__init__(
  265. value,
  266. content_type=content_type,
  267. encoding=encoding,
  268. *args,
  269. **kwargs,
  270. )
  271. @property
  272. def size(self) -> Optional[int]:
  273. try:
  274. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  275. except OSError:
  276. return None
  277. async def write(self, writer: AbstractStreamWriter) -> None:
  278. loop = asyncio.get_event_loop()
  279. try:
  280. chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16)
  281. while chunk:
  282. await writer.write(chunk.encode(self._encoding))
  283. chunk = await loop.run_in_executor(None, self._value.read, 2 ** 16)
  284. finally:
  285. await loop.run_in_executor(None, self._value.close)
  286. class BytesIOPayload(IOBasePayload):
  287. @property
  288. def size(self) -> int:
  289. position = self._value.tell()
  290. end = self._value.seek(0, os.SEEK_END)
  291. self._value.seek(position)
  292. return end - position
  293. class BufferedReaderPayload(IOBasePayload):
  294. @property
  295. def size(self) -> Optional[int]:
  296. try:
  297. return os.fstat(self._value.fileno()).st_size - self._value.tell()
  298. except OSError:
  299. # data.fileno() is not supported, e.g.
  300. # io.BufferedReader(io.BytesIO(b'data'))
  301. return None
  302. class JsonPayload(BytesPayload):
  303. def __init__(
  304. self,
  305. value: Any,
  306. encoding: str = "utf-8",
  307. content_type: str = "application/json",
  308. dumps: JSONEncoder = json.dumps,
  309. *args: Any,
  310. **kwargs: Any,
  311. ) -> None:
  312. super().__init__(
  313. dumps(value).encode(encoding),
  314. content_type=content_type,
  315. encoding=encoding,
  316. *args,
  317. **kwargs,
  318. )
  319. if TYPE_CHECKING: # pragma: no cover
  320. from typing import AsyncIterable, AsyncIterator
  321. _AsyncIterator = AsyncIterator[bytes]
  322. _AsyncIterable = AsyncIterable[bytes]
  323. else:
  324. from collections.abc import AsyncIterable, AsyncIterator
  325. _AsyncIterator = AsyncIterator
  326. _AsyncIterable = AsyncIterable
  327. class AsyncIterablePayload(Payload):
  328. _iter = None # type: Optional[_AsyncIterator]
  329. def __init__(self, value: _AsyncIterable, *args: Any, **kwargs: Any) -> None:
  330. if not isinstance(value, AsyncIterable):
  331. raise TypeError(
  332. "value argument must support "
  333. "collections.abc.AsyncIterablebe interface, "
  334. "got {!r}".format(type(value))
  335. )
  336. if "content_type" not in kwargs:
  337. kwargs["content_type"] = "application/octet-stream"
  338. super().__init__(value, *args, **kwargs)
  339. self._iter = value.__aiter__()
  340. async def write(self, writer: AbstractStreamWriter) -> None:
  341. if self._iter:
  342. try:
  343. # iter is not None check prevents rare cases
  344. # when the case iterable is used twice
  345. while True:
  346. chunk = await self._iter.__anext__()
  347. await writer.write(chunk)
  348. except StopAsyncIteration:
  349. self._iter = None
  350. class StreamReaderPayload(AsyncIterablePayload):
  351. def __init__(self, value: StreamReader, *args: Any, **kwargs: Any) -> None:
  352. super().__init__(value.iter_any(), *args, **kwargs)
  353. PAYLOAD_REGISTRY = PayloadRegistry()
  354. PAYLOAD_REGISTRY.register(BytesPayload, (bytes, bytearray, memoryview))
  355. PAYLOAD_REGISTRY.register(StringPayload, str)
  356. PAYLOAD_REGISTRY.register(StringIOPayload, io.StringIO)
  357. PAYLOAD_REGISTRY.register(TextIOPayload, io.TextIOBase)
  358. PAYLOAD_REGISTRY.register(BytesIOPayload, io.BytesIO)
  359. PAYLOAD_REGISTRY.register(BufferedReaderPayload, (io.BufferedReader, io.BufferedRandom))
  360. PAYLOAD_REGISTRY.register(IOBasePayload, io.IOBase)
  361. PAYLOAD_REGISTRY.register(StreamReaderPayload, StreamReader)
  362. # try_last for giving a chance to more specialized async interables like
  363. # multidict.BodyPartReaderPayload override the default
  364. PAYLOAD_REGISTRY.register(AsyncIterablePayload, AsyncIterable, order=Order.try_last)