- import asyncio
- import asyncio.streams
- import traceback
- import warnings
- from collections import deque
- from contextlib import suppress
- from html import escape as html_escape
- from http import HTTPStatus
- from logging import Logger
- from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Tuple, Type, cast
- import yarl
- from .abc import AbstractAccessLogger, AbstractStreamWriter
- from .base_protocol import BaseProtocol
- from .helpers import CeilTimeout, current_task
- from .http import (
- HttpProcessingError,
- HttpRequestParser,
- HttpVersion10,
- RawRequestMessage,
- StreamWriter,
- )
- from .log import access_logger, server_logger
- from .streams import EMPTY_PAYLOAD, StreamReader
- from .tcp_helpers import tcp_keepalive
- from .web_exceptions import HTTPException
- from .web_log import AccessLogger
- from .web_request import BaseRequest
- from .web_response import Response, StreamResponse
# Public names exported by this module.
__all__ = ("RequestHandler", "RequestPayloadError", "PayloadAccessError")

if TYPE_CHECKING:  # pragma: no cover
    # Imported only for type checking to avoid a circular import at runtime.
    from .web_server import Server

# Factory signature used by the server to build a BaseRequest for each
# parsed message: (message, payload, protocol, writer, handler task) -> request.
_RequestFactory = Callable[
    [
        RawRequestMessage,
        StreamReader,
        "RequestHandler",
        AbstractStreamWriter,
        "asyncio.Task[None]",
    ],
    BaseRequest,
]

# Signature of the user-supplied request handler coroutine.
_RequestHandler = Callable[[BaseRequest], Awaitable[StreamResponse]]

# Placeholder message used when a request could not be parsed at all
# (see handle_parse_error); fields are deliberately dummy values.
ERROR = RawRequestMessage(
    "UNKNOWN", "/", HttpVersion10, {}, {}, True, False, False, False, yarl.URL("/")
)
class RequestPayloadError(Exception):
    """Payload parsing error."""
class PayloadAccessError(Exception):
    """Payload was accessed after response was sent."""
class RequestHandler(BaseProtocol):
    """HTTP protocol implementation.

    RequestHandler handles incoming HTTP request. It reads request line,
    request headers and request payload and calls handle_request() method.
    By default it always returns with 404 response.

    RequestHandler handles errors in incoming request, like bad
    status line, bad headers or incomplete payload. If any error occurs,
    connection gets closed.

    :param keepalive_timeout: number of seconds before closing
                              keep-alive connection
    :type keepalive_timeout: int or None

    :param bool tcp_keepalive: TCP keep-alive is on, default is on

    :param bool debug: enable debug mode

    :param logger: custom logger object
    :type logger: aiohttp.log.server_logger

    :param access_log_class: custom class for access_logger
    :type access_log_class: aiohttp.abc.AbstractAccessLogger

    :param access_log: custom logging object
    :type access_log: aiohttp.log.server_logger

    :param str access_log_format: access log format string

    :param loop: Optional event loop

    :param int max_line_size: Optional maximum header line size

    :param int max_field_size: Optional maximum header field size

    :param int max_headers: Optional maximum header size
    """

    # Interval (seconds) used by _process_keepalive to re-check an
    # idle connection while a handler is still busy.
    KEEPALIVE_RESCHEDULE_DELAY = 1

    __slots__ = (
        "_request_count",
        "_keepalive",
        "_manager",
        "_request_handler",
        "_request_factory",
        "_tcp_keepalive",
        "_keepalive_time",
        "_keepalive_handle",
        "_keepalive_timeout",
        "_lingering_time",
        "_messages",
        "_message_tail",
        "_waiter",
        "_error_handler",
        "_task_handler",
        "_upgrade",
        "_payload_parser",
        "_request_parser",
        "_reading_paused",
        "logger",
        "debug",
        "access_log",
        "access_logger",
        "_close",
        "_force_close",
        "_current_request",
    )

    def __init__(
        self,
        manager: "Server",
        *,
        loop: asyncio.AbstractEventLoop,
        keepalive_timeout: float = 75.0,  # NGINX default is 75 secs
        tcp_keepalive: bool = True,
        logger: Logger = server_logger,
        access_log_class: Type[AbstractAccessLogger] = AccessLogger,
        access_log: Logger = access_logger,
        access_log_format: str = AccessLogger.LOG_FORMAT,
        debug: bool = False,
        max_line_size: int = 8190,
        max_headers: int = 32768,
        max_field_size: int = 8190,
        lingering_time: float = 10.0,
        read_bufsize: int = 2 ** 16,
    ):
        super().__init__(loop)

        self._request_count = 0
        self._keepalive = False
        self._current_request = None  # type: Optional[BaseRequest]
        self._manager = manager  # type: Optional[Server]
        self._request_handler = (
            manager.request_handler
        )  # type: Optional[_RequestHandler]
        self._request_factory = (
            manager.request_factory
        )  # type: Optional[_RequestFactory]

        self._tcp_keepalive = tcp_keepalive
        # placeholder to be replaced on keepalive timeout setup
        self._keepalive_time = 0.0
        self._keepalive_handle = None  # type: Optional[asyncio.Handle]
        self._keepalive_timeout = keepalive_timeout
        self._lingering_time = float(lingering_time)

        # Pipeline of parsed (message, payload) pairs awaiting handling.
        self._messages = deque()  # type: Any  # Python 3.5 has no typing.Deque
        self._message_tail = b""

        # Future the start() loop awaits on while idle between requests.
        self._waiter = None  # type: Optional[asyncio.Future[None]]
        self._error_handler = None  # type: Optional[asyncio.Task[None]]
        self._task_handler = None  # type: Optional[asyncio.Task[None]]

        self._upgrade = False
        self._payload_parser = None  # type: Any
        self._request_parser = HttpRequestParser(
            self,
            loop,
            read_bufsize,
            max_line_size=max_line_size,
            max_field_size=max_field_size,
            max_headers=max_headers,
            payload_exception=RequestPayloadError,
        )  # type: Optional[HttpRequestParser]

        self.logger = logger
        self.debug = debug
        self.access_log = access_log
        if access_log:
            self.access_logger = access_log_class(
                access_log, access_log_format
            )  # type: Optional[AbstractAccessLogger]
        else:
            self.access_logger = None

        self._close = False
        self._force_close = False

    def __repr__(self) -> str:
        """Short debug representation showing connection state."""
        return "<{} {}>".format(
            self.__class__.__name__,
            "connected" if self.transport is not None else "disconnected",
        )

    @property
    def keepalive_timeout(self) -> float:
        """Seconds an idle keep-alive connection stays open."""
        return self._keepalive_timeout

    async def shutdown(self, timeout: Optional[float] = 15.0) -> None:
        """Worker process is about to exit, we need cleanup everything and
        stop accepting requests. It is especially important for keep-alive
        connections."""
        self._force_close = True

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._waiter:
            self._waiter.cancel()

        # wait for handlers
        with suppress(asyncio.CancelledError, asyncio.TimeoutError):
            with CeilTimeout(timeout, loop=self._loop):
                if self._error_handler is not None and not self._error_handler.done():
                    await self._error_handler

                if self._current_request is not None:
                    self._current_request._cancel(asyncio.CancelledError())

                if self._task_handler is not None and not self._task_handler.done():
                    await self._task_handler

        # force-close non-idle handler
        if self._task_handler is not None:
            self._task_handler.cancel()

        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """asyncio callback: start the request-processing task for this
        connection and register it with the server manager."""
        super().connection_made(transport)

        real_transport = cast(asyncio.Transport, transport)
        if self._tcp_keepalive:
            tcp_keepalive(real_transport)

        self._task_handler = self._loop.create_task(self.start())
        assert self._manager is not None
        self._manager.connection_made(self, real_transport)

    def connection_lost(self, exc: Optional[BaseException]) -> None:
        """asyncio callback: tear down all per-connection state and cancel
        any in-flight handler tasks."""
        if self._manager is None:
            # connection_lost may be invoked more than once; second call
            # finds the manager already detached and is a no-op.
            return
        self._manager.connection_lost(self, exc)

        super().connection_lost(exc)

        self._manager = None
        self._force_close = True
        self._request_factory = None
        self._request_handler = None
        self._request_parser = None

        if self._keepalive_handle is not None:
            self._keepalive_handle.cancel()

        if self._current_request is not None:
            if exc is None:
                exc = ConnectionResetError("Connection lost")
            self._current_request._cancel(exc)

        if self._error_handler is not None:
            self._error_handler.cancel()
        if self._task_handler is not None:
            self._task_handler.cancel()
        if self._waiter is not None:
            self._waiter.cancel()

        self._task_handler = None

        if self._payload_parser is not None:
            self._payload_parser.feed_eof()
            self._payload_parser = None

    def set_parser(self, parser: Any) -> None:
        """Install a payload parser; flush any buffered tail bytes into it."""
        # Actual type is WebReader
        assert self._payload_parser is None

        self._payload_parser = parser

        if self._message_tail:
            self._payload_parser.feed_data(self._message_tail)
            self._message_tail = b""

    def eof_received(self) -> None:
        # Returning None lets asyncio close the transport on EOF.
        pass

    def data_received(self, data: bytes) -> None:
        """asyncio callback: feed raw bytes into the request parser, the
        upgrade tail buffer, or the active payload parser."""
        if self._force_close or self._close:
            return
        # parse http messages
        if self._payload_parser is None and not self._upgrade:
            assert self._request_parser is not None
            try:
                messages, upgraded, tail = self._request_parser.feed_data(data)
            except HttpProcessingError as exc:
                # something happened during parsing
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(
                        StreamWriter(self, self._loop), 400, exc, exc.message
                    )
                )
                self.close()
            except Exception as exc:
                # 500: internal error
                self._error_handler = self._loop.create_task(
                    self.handle_parse_error(StreamWriter(self, self._loop), 500, exc)
                )
                self.close()
            else:
                if messages:
                    # sometimes the parser returns no messages
                    for (msg, payload) in messages:
                        self._request_count += 1
                        self._messages.append((msg, payload))

                    waiter = self._waiter
                    if waiter is not None:
                        if not waiter.done():
                            # don't set result twice
                            waiter.set_result(None)

                self._upgrade = upgraded
                if upgraded and tail:
                    self._message_tail = tail

        # no parser, just store
        elif self._payload_parser is None and self._upgrade and data:
            self._message_tail += data

        # feed payload
        elif data:
            eof, tail = self._payload_parser.feed_data(data)
            if eof:
                self.close()

    def keep_alive(self, val: bool) -> None:
        """Set keep-alive connection mode.

        :param bool val: new state.
        """
        self._keepalive = val
        if self._keepalive_handle:
            self._keepalive_handle.cancel()
            self._keepalive_handle = None

    def close(self) -> None:
        """Stop accepting new pipelinig messages and close
        connection when handlers done processing messages"""
        self._close = True
        if self._waiter:
            self._waiter.cancel()

    def force_close(self) -> None:
        """Force close connection"""
        self._force_close = True
        if self._waiter:
            self._waiter.cancel()
        if self.transport is not None:
            self.transport.close()
            self.transport = None

    def log_access(
        self, request: BaseRequest, response: StreamResponse, time: float
    ) -> None:
        """Emit an access-log record; `time` is the request start timestamp
        (loop time), so elapsed time is computed here."""
        if self.access_logger is not None:
            self.access_logger.log(request, response, self._loop.time() - time)

    def log_debug(self, *args: Any, **kw: Any) -> None:
        # Debug-level logging is gated on the protocol's debug flag.
        if self.debug:
            self.logger.debug(*args, **kw)

    def log_exception(self, *args: Any, **kw: Any) -> None:
        self.logger.exception(*args, **kw)

    def _process_keepalive(self) -> None:
        """Timer callback: force-close an idle keep-alive connection whose
        timeout expired, otherwise reschedule the check."""
        if self._force_close or not self._keepalive:
            return

        # NOTE(review): `next` shadows the builtin; it is the loop-time
        # deadline for this keep-alive connection.
        next = self._keepalive_time + self._keepalive_timeout

        # handler in idle state
        if self._waiter:
            if self._loop.time() > next:
                self.force_close()
                return

        # not all request handlers are done,
        # reschedule itself to next second
        self._keepalive_handle = self._loop.call_later(
            self.KEEPALIVE_RESCHEDULE_DELAY, self._process_keepalive
        )

    async def _handle_request(
        self,
        request: BaseRequest,
        start_time: float,
    ) -> Tuple[StreamResponse, bool]:
        """Run the user handler for one request and finish the response.

        Returns (response, reset) where `reset` is True if the client
        disconnected before the response could be fully written.
        """
        assert self._request_handler is not None
        try:
            try:
                self._current_request = request
                resp = await self._request_handler(request)
            finally:
                # Clear regardless of outcome so shutdown()/connection_lost()
                # don't try to cancel a finished request.
                self._current_request = None
        except HTTPException as exc:
            # HTTP exceptions are a sanctioned way to return an error response.
            resp = Response(
                status=exc.status, reason=exc.reason, text=exc.text, headers=exc.headers
            )
            reset = await self.finish_response(request, resp, start_time)
        except asyncio.CancelledError:
            raise
        except asyncio.TimeoutError as exc:
            self.log_debug("Request handler timed out.", exc_info=exc)
            resp = self.handle_error(request, 504)
            reset = await self.finish_response(request, resp, start_time)
        except Exception as exc:
            resp = self.handle_error(request, 500, exc)
            reset = await self.finish_response(request, resp, start_time)
        else:
            reset = await self.finish_response(request, resp, start_time)

        return resp, reset

    async def start(self) -> None:
        """Process incoming request.

        It reads request line, request headers and request payload, then
        calls handle_request() method. Subclass has to override
        handle_request(). start() handles various exceptions in request
        or response handling. Connection is being closed always unless
        keep_alive(True) specified.
        """
        loop = self._loop
        handler = self._task_handler
        assert handler is not None
        manager = self._manager
        assert manager is not None
        keepalive_timeout = self._keepalive_timeout
        resp = None
        assert self._request_factory is not None
        assert self._request_handler is not None

        while not self._force_close:
            if not self._messages:
                try:
                    # wait for next request
                    self._waiter = loop.create_future()
                    await self._waiter
                except asyncio.CancelledError:
                    break
                finally:
                    self._waiter = None

            message, payload = self._messages.popleft()

            start = loop.time()

            manager.requests_count += 1
            writer = StreamWriter(self, loop)
            request = self._request_factory(message, payload, self, writer, handler)
            try:
                # a new task is used for copy context vars (#3406)
                task = self._loop.create_task(self._handle_request(request, start))
                try:
                    resp, reset = await task
                except (asyncio.CancelledError, ConnectionError):
                    self.log_debug("Ignored premature client disconnection")
                    break

                # Deprecation warning (See #2415)
                if getattr(resp, "__http_exception__", False):
                    warnings.warn(
                        "returning HTTPException object is deprecated "
                        "(#2415) and will be removed, "
                        "please raise the exception instead",
                        DeprecationWarning,
                    )

                # Drop the processed task from asyncio.Task.all_tasks() early
                del task
                if reset:
                    self.log_debug("Ignored premature client disconnection 2")
                    break

                # notify server about keep-alive
                self._keepalive = bool(resp.keep_alive)

                # check payload
                if not payload.is_eof():
                    lingering_time = self._lingering_time
                    if not self._force_close and lingering_time:
                        # Drain the unread request body for a bounded time so
                        # the connection can be reused (lingering close).
                        self.log_debug(
                            "Start lingering close timer for %s sec.", lingering_time
                        )

                        now = loop.time()
                        end_t = now + lingering_time

                        with suppress(asyncio.TimeoutError, asyncio.CancelledError):
                            while not payload.is_eof() and now < end_t:
                                with CeilTimeout(end_t - now, loop=loop):
                                    # read and ignore
                                    await payload.readany()
                                now = loop.time()

                    # if payload still uncompleted
                    if not payload.is_eof() and not self._force_close:
                        self.log_debug("Uncompleted request.")
                        self.close()

                    # Any later access to this request's payload is an error.
                    payload.set_exception(PayloadAccessError())

            except asyncio.CancelledError:
                self.log_debug("Ignored premature client disconnection ")
                break
            except RuntimeError as exc:
                if self.debug:
                    self.log_exception("Unhandled runtime exception", exc_info=exc)
                self.force_close()
            except Exception as exc:
                self.log_exception("Unhandled exception", exc_info=exc)
                self.force_close()
            finally:
                if self.transport is None and resp is not None:
                    self.log_debug("Ignored premature client disconnection.")
                elif not self._force_close:
                    if self._keepalive and not self._close:
                        # start keep-alive timer
                        if keepalive_timeout is not None:
                            now = self._loop.time()
                            self._keepalive_time = now
                            if self._keepalive_handle is None:
                                self._keepalive_handle = loop.call_at(
                                    now + keepalive_timeout, self._process_keepalive
                                )
                    else:
                        break

        # remove handler, close transport if no handlers left
        if not self._force_close:
            self._task_handler = None
            if self.transport is not None and self._error_handler is None:
                self.transport.close()

    async def finish_response(
        self, request: BaseRequest, resp: StreamResponse, start_time: float
    ) -> bool:
        """
        Prepare the response and write_eof, then log access. This has to
        be called within the context of any exception so the access logger
        can get exception information. Returns True if the client disconnects
        prematurely.
        """
        if self._request_parser is not None:
            # Reset upgrade state and replay any bytes buffered during
            # a (not completed) protocol upgrade back into the parser.
            self._request_parser.set_upgraded(False)
            self._upgrade = False
            if self._message_tail:
                self._request_parser.feed_data(self._message_tail)
                self._message_tail = b""
        try:
            prepare_meth = resp.prepare
        except AttributeError:
            # Handler returned something that is not a response object.
            if resp is None:
                raise RuntimeError("Missing return " "statement on request handler")
            else:
                raise RuntimeError(
                    "Web-handler should return "
                    "a response instance, "
                    "got {!r}".format(resp)
                )
        try:
            await prepare_meth(request)
            await resp.write_eof()
        except ConnectionError:
            self.log_access(request, resp, start_time)
            return True
        else:
            self.log_access(request, resp, start_time)
            return False

    def handle_error(
        self,
        request: BaseRequest,
        status: int = 500,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> StreamResponse:
        """Handle errors.

        Returns HTTP response with specific status code. Logs additional
        information. It always closes current connection."""
        self.log_exception("Error handling request", exc_info=exc)

        ct = "text/plain"
        if status == HTTPStatus.INTERNAL_SERVER_ERROR:
            title = "{0.value} {0.phrase}".format(HTTPStatus.INTERNAL_SERVER_ERROR)
            msg = HTTPStatus.INTERNAL_SERVER_ERROR.description
            tb = None
            if self.debug:
                with suppress(Exception):
                    tb = traceback.format_exc()

            if "text/html" in request.headers.get("Accept", ""):
                if tb:
                    # Traceback is HTML-escaped before embedding in the page.
                    tb = html_escape(tb)
                    msg = f"<h2>Traceback:</h2>\n<pre>{tb}</pre>"
                message = (
                    "<html><head>"
                    "<title>{title}</title>"
                    "</head><body>\n<h1>{title}</h1>"
                    "\n{msg}\n</body></html>\n"
                ).format(title=title, msg=msg)
                ct = "text/html"
            else:
                if tb:
                    msg = tb
                message = title + "\n\n" + msg

        resp = Response(status=status, text=message, content_type=ct)
        resp.force_close()

        # some data already got sent, connection is broken
        if request.writer.output_size > 0 or self.transport is None:
            self.force_close()

        return resp

    async def handle_parse_error(
        self,
        writer: AbstractStreamWriter,
        status: int,
        exc: Optional[BaseException] = None,
        message: Optional[str] = None,
    ) -> None:
        """Send an error response for a request that failed to parse.

        Uses the ERROR sentinel message since no real request exists.
        """
        task = current_task()
        assert task is not None
        request = BaseRequest(
            ERROR, EMPTY_PAYLOAD, self, writer, task, self._loop  # type: ignore
        )

        resp = self.handle_error(request, status, exc, message)
        await resp.prepare(request)
        await resp.write_eof()

        if self.transport is not None:
            self.transport.close()

        self._error_handler = None