web_protocol.py

import asyncio
import asyncio.streams
import traceback
import warnings
from collections import deque
from contextlib import suppress
from html import escape as html_escape
from http import HTTPStatus
from logging import Logger
from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast

import yarl

from .abc import AbstractAccessLogger, AbstractStreamWriter
from .base_protocol import BaseProtocol
from .helpers import CeilTimeout, current_task
from .http import (
    HttpProcessingError,
    HttpRequestParser,
    HttpVersion10,
    RawRequestMessage,
    StreamWriter,
)
from .log import access_logger, server_logger
from .streams import EMPTY_PAYLOAD, StreamReader
from .tcp_helpers import tcp_keepalive
from .web_exceptions import HTTPException
from .web_log import AccessLogger
from .web_request import BaseRequest
from .web_response import Response, StreamResponse

__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:  # pragma: no cover
    from .web_server import Server


_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]

ERROR = RawRequestMessage(
    "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)


class RequestPayloadError(Exception):
    """Payload parsing error."""


class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""


class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP requests. It reads the request line,
    request headers and request payload, then calls the handle_request()
    method. By default it always returns a 404 response.

    RequestHandler handles errors in the incoming request, such as a bad
    status line, bad headers or an incomplete payload. If any error occurs,
    the connection is closed.

    :param keepalive_timeout: number of seconds before closing
                              keep-alive connection
    :type keepalive_timeout: int or None

    :param bool tcp_keepalive: TCP keep-alive is on, default is on

    :param bool debug: enable debug mode

    :param logger: custom logger object
    :type logger: aiohttp.log.server_logger

    :param access_log_class: custom class for access_logger
    :type access_log_class: aiohttp.abc.AbstractAccessLogger

    :param access_log: custom logging object
    :type access_log: aiohttp.log.server_logger

    :param str access_log_format: access log format string

    :param loop: Optional event loop

    :param int max_line_size: Optional maximum header line size

    :param int max_field_size: Optional maximum header field size

    :param int max_headers: Optional maximum header size
    """

    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_error_handler",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "_reading_paused",
        "logger",
        "debug",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        debug: bool = False,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2 ** 16,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request = None  # type: Optional[BaseRequest]
        self._manager = manager  # type: Optional[Server]
        self._request_handler = (
            manager.request_handler
        )  # type: Optional[_RequestHandler]
        self._request_factory = (
            manager.request_factory
        )  # type: Optional[_RequestFactory]

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
        self._message_tail = b""

        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._error_handler = None  # type: Optional[asyncio.Task[None]]
        self._task_handler = None  # type: Optional[asyncio.Task[None]]

        self._upgrade = False
        self._payload_parser = None  # type: Any
        self._request_parser = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
        )  # type: Optional[HttpRequestParser]

        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        if access_log:
            self.access_logger = access_log_class(
                access_log, access_log_format
            )  # type: Optional[AbstractAccessLogger]
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Worker process is about to exit; we need to clean up everything and
        stop accepting requests. This is especially important for keep-alive
        connections."""
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                if self._error_handler is not None and not self._error_handler.done():
                    await self._error_handler

                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        if self._manager is None:
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._error_handler is not None:
            self._error_handler.cancel()
        if self._task_handler is not None:
            self._task_handler.cancel()
        if self._waiter is not None:
            self._waiter.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        pass
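
    # data_received() routes incoming bytes according to the connection state:
    # while no payload parser is installed and no upgrade happened they are fed
    # to the HTTP request parser, after a protocol upgrade they are buffered in
    # _message_tail, and otherwise they go to the payload parser installed via
    # set_parser().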

    def data_received(self, data: bytes) -> None:
        if self._force_close or self._close:
            return
        # parse http messages
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # something happened during parsing
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(
                        StreamWriter(self, self._loop), 400, exc, exc.message
                    )
                )
                self.close()
            except Exception as exc:
                # 500: internal error
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
                )
                self.close()
            else:
                if messages:
                    # sometimes the parser returns no messages
                    for (msg, payload) in messages:
                        self._request_count += 1
                        self._messages.append((msg, payload))

                    waiter = self._waiter
                    if waiter is not None:
                        if not waiter.done():
                            # don't set result twice
                            waiter.set_result(None)

                self._upgrade = upgraded
                if upgraded and tail:
                    self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Stop accepting new pipelining messages and close the
        connection once the handlers are done processing messages."""
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Forcefully close the connection."""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def log_access(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> None:
        if self.access_logger is not None:
            self.access_logger.log(request, response, self._loop.time() - time)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        if self.debug:
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)
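
    # Keep-alive bookkeeping: start() records the moment the handler goes idle
    # in _keepalive_time and schedules _process_keepalive() with loop.call_at().
    # If the handler is still idle (self._waiter is set) after the keep-alive
    # timeout has elapsed, the connection is force-closed; otherwise the check
    # reschedules itself every KEEPALIVE_RESCHEDULE_DELAY seconds.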

    def _process_keepalive(self) -> None:
        if self._force_close or not self._keepalive:
            return

        next = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
        )
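
    # _handle_request() wraps a single call into the user request handler and
    # maps exceptions to responses: HTTPException is rendered with its own
    # status, asyncio.TimeoutError becomes a 504, any other exception becomes a
    # 500, and CancelledError is re-raised. It returns the response plus a flag
    # telling whether the client disconnected before the response was written.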

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
    ) -> Tuple[StreamResponse, bool]:
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await self._request_handler(request)
            finally:
                self._current_request = None
        except HTTPException as exc:
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset
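
    # start() is the per-connection request loop: it sleeps on the _waiter
    # future until data_received() has parsed at least one message, pops
    # pipelined (message, payload) pairs from the _messages deque, builds a
    # BaseRequest via the request factory and runs _handle_request() in a fresh
    # task (so context variables are copied, see #3406). Any unread request
    # payload is drained for up to _lingering_time seconds before the
    # connection is closed or the keep-alive timer is armed.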

    async def start(self) -> None:
        """Process incoming request.

        It reads the request line, request headers and request payload, then
        calls the handle_request() method. Subclasses have to override
        handle_request(). start() handles various exceptions in request or
        response handling. The connection is always closed unless
        keep_alive(True) is specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(self._handle_request(request, start))
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Deprecation warning (See #2415)
                if getattr(resp, "__http_exception__", False):
                    warnings.warn(
                        "returning HTTPException object is deprecated "
                        "(#2415) and will be removed, "
                        "please raise the exception instead",
                        DeprecationWarning,
                    )

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                with CeilTimeout(end_t - now, loop=loop):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection")
                break
            except RuntimeError as exc:
                if self.debug:
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None and self._error_handler is None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """
        Prepare the response and write_eof, then log access. This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        if self._request_parser is not None:
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            if resp is None:
                raise RuntimeError("Missing return statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            self.log_access(request, resp, start_time)
            return True
        else:
            self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns an HTTP response with a specific status code. Logs additional
        information. It always closes the current connection."""
        self.log_exception("Error handling request", exc_info=exc)

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self.debug:
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        # some data already got sent, connection is broken
        if request.writer.output_size > 0 or self.transport is None:
            self.force_close()

        return resp
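
    # handle_parse_error() is scheduled from data_received() when the request
    # parser fails. No real request object exists at that point, so a synthetic
    # BaseRequest built from the module-level ERROR message and an empty
    # payload is used, which lets handle_error() and the access logger run
    # against a well-formed request.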

    async def handle_parse_error(
        self,
        writer: AbstractStreamWriter,
        status: int,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> None:
        task = current_task()
        assert task is not None
        request = BaseRequest(
            ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop  # type: ignore
        )

        resp = self.handle_error(request, status, exc, message)
        await resp.prepare(request)
        await resp.write_eof()

        if self.transport is not None:
            self.transport.close()

        self._error_handler = None