worker.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252
  1. """Async gunicorn worker for aiohttp.web"""
  2. import asyncio
  3. import os
  4. import re
  5. import signal
  6. import sys
  7. from types import FrameType
  8. from typing import Any, Awaitable, Callable, Optional, Union # noqa
  9. from gunicorn.config import AccessLogFormat as GunicornAccessLogFormat
  10. from gunicorn.workers import base
  11. from aiohttp import web
  12. from .helpers import set_result
  13. from .web_app import Application
  14. from .web_log import AccessLogger
  15. try:
  16. import ssl
  17. SSLContext = ssl.SSLContext
  18. except ImportError: # pragma: no cover
  19. ssl = None # type: ignore
  20. SSLContext = object # type: ignore
  21. __all__ = ("GunicornWebWorker", "GunicornUVLoopWebWorker", "GunicornTokioWebWorker")
  22. class GunicornWebWorker(base.Worker):
  23. DEFAULT_AIOHTTP_LOG_FORMAT = AccessLogger.LOG_FORMAT
  24. DEFAULT_GUNICORN_LOG_FORMAT = GunicornAccessLogFormat.default
  25. def __init__(self, *args: Any, **kw: Any) -> None: # pragma: no cover
  26. super().__init__(*args, **kw)
  27. self._task = None # type: Optional[asyncio.Task[None]]
  28. self.exit_code = 0
  29. self._notify_waiter = None # type: Optional[asyncio.Future[bool]]
  30. def init_process(self) -> None:
  31. # create new event_loop after fork
  32. asyncio.get_event_loop().close()
  33. self.loop = asyncio.new_event_loop()
  34. asyncio.set_event_loop(self.loop)
  35. super().init_process()
  36. def run(self) -> None:
  37. self._task = self.loop.create_task(self._run())
  38. try: # ignore all finalization problems
  39. self.loop.run_until_complete(self._task)
  40. except Exception:
  41. self.log.exception("Exception in gunicorn worker")
  42. if sys.version_info >= (3, 6):
  43. self.loop.run_until_complete(self.loop.shutdown_asyncgens())
  44. self.loop.close()
  45. sys.exit(self.exit_code)
  46. async def _run(self) -> None:
  47. if isinstance(self.wsgi, Application):
  48. app = self.wsgi
  49. elif asyncio.iscoroutinefunction(self.wsgi):
  50. app = await self.wsgi()
  51. else:
  52. raise RuntimeError(
  53. "wsgi app should be either Application or "
  54. "async function returning Application, got {}".format(self.wsgi)
  55. )
  56. access_log = self.log.access_log if self.cfg.accesslog else None
  57. runner = web.AppRunner(
  58. app,
  59. logger=self.log,
  60. keepalive_timeout=self.cfg.keepalive,
  61. access_log=access_log,
  62. access_log_format=self._get_valid_log_format(self.cfg.access_log_format),
  63. )
  64. await runner.setup()
  65. ctx = self._create_ssl_context(self.cfg) if self.cfg.is_ssl else None
  66. runner = runner
  67. assert runner is not None
  68. server = runner.server
  69. assert server is not None
  70. for sock in self.sockets:
  71. site = web.SockSite(
  72. runner,
  73. sock,
  74. ssl_context=ctx,
  75. shutdown_timeout=self.cfg.graceful_timeout / 100 * 95,
  76. )
  77. await site.start()
  78. # If our parent changed then we shut down.
  79. pid = os.getpid()
  80. try:
  81. while self.alive: # type: ignore
  82. self.notify()
  83. cnt = server.requests_count
  84. if self.cfg.max_requests and cnt > self.cfg.max_requests:
  85. self.alive = False
  86. self.log.info("Max requests, shutting down: %s", self)
  87. elif pid == os.getpid() and self.ppid != os.getppid():
  88. self.alive = False
  89. self.log.info("Parent changed, shutting down: %s", self)
  90. else:
  91. await self._wait_next_notify()
  92. except BaseException:
  93. pass
  94. await runner.cleanup()
  95. def _wait_next_notify(self) -> "asyncio.Future[bool]":
  96. self._notify_waiter_done()
  97. loop = self.loop
  98. assert loop is not None
  99. self._notify_waiter = waiter = loop.create_future()
  100. self.loop.call_later(1.0, self._notify_waiter_done, waiter)
  101. return waiter
  102. def _notify_waiter_done(
  103. self, waiter: Optional["asyncio.Future[bool]"] = None
  104. ) -> None:
  105. if waiter is None:
  106. waiter = self._notify_waiter
  107. if waiter is not None:
  108. set_result(waiter, True)
  109. if waiter is self._notify_waiter:
  110. self._notify_waiter = None
  111. def init_signals(self) -> None:
  112. # Set up signals through the event loop API.
  113. self.loop.add_signal_handler(
  114. signal.SIGQUIT, self.handle_quit, signal.SIGQUIT, None
  115. )
  116. self.loop.add_signal_handler(
  117. signal.SIGTERM, self.handle_exit, signal.SIGTERM, None
  118. )
  119. self.loop.add_signal_handler(
  120. signal.SIGINT, self.handle_quit, signal.SIGINT, None
  121. )
  122. self.loop.add_signal_handler(
  123. signal.SIGWINCH, self.handle_winch, signal.SIGWINCH, None
  124. )
  125. self.loop.add_signal_handler(
  126. signal.SIGUSR1, self.handle_usr1, signal.SIGUSR1, None
  127. )
  128. self.loop.add_signal_handler(
  129. signal.SIGABRT, self.handle_abort, signal.SIGABRT, None
  130. )
  131. # Don't let SIGTERM and SIGUSR1 disturb active requests
  132. # by interrupting system calls
  133. signal.siginterrupt(signal.SIGTERM, False)
  134. signal.siginterrupt(signal.SIGUSR1, False)
  135. def handle_quit(self, sig: int, frame: FrameType) -> None:
  136. self.alive = False
  137. # worker_int callback
  138. self.cfg.worker_int(self)
  139. # wakeup closing process
  140. self._notify_waiter_done()
  141. def handle_abort(self, sig: int, frame: FrameType) -> None:
  142. self.alive = False
  143. self.exit_code = 1
  144. self.cfg.worker_abort(self)
  145. sys.exit(1)
  146. @staticmethod
  147. def _create_ssl_context(cfg: Any) -> "SSLContext":
  148. """Creates SSLContext instance for usage in asyncio.create_server.
  149. See ssl.SSLSocket.__init__ for more details.
  150. """
  151. if ssl is None: # pragma: no cover
  152. raise RuntimeError("SSL is not supported.")
  153. ctx = ssl.SSLContext(cfg.ssl_version)
  154. ctx.load_cert_chain(cfg.certfile, cfg.keyfile)
  155. ctx.verify_mode = cfg.cert_reqs
  156. if cfg.ca_certs:
  157. ctx.load_verify_locations(cfg.ca_certs)
  158. if cfg.ciphers:
  159. ctx.set_ciphers(cfg.ciphers)
  160. return ctx
  161. def _get_valid_log_format(self, source_format: str) -> str:
  162. if source_format == self.DEFAULT_GUNICORN_LOG_FORMAT:
  163. return self.DEFAULT_AIOHTTP_LOG_FORMAT
  164. elif re.search(r"%\([^\)]+\)", source_format):
  165. raise ValueError(
  166. "Gunicorn's style options in form of `%(name)s` are not "
  167. "supported for the log formatting. Please use aiohttp's "
  168. "format specification to configure access log formatting: "
  169. "http://docs.aiohttp.org/en/stable/logging.html"
  170. "#format-specification"
  171. )
  172. else:
  173. return source_format
  174. class GunicornUVLoopWebWorker(GunicornWebWorker):
  175. def init_process(self) -> None:
  176. import uvloop
  177. # Close any existing event loop before setting a
  178. # new policy.
  179. asyncio.get_event_loop().close()
  180. # Setup uvloop policy, so that every
  181. # asyncio.get_event_loop() will create an instance
  182. # of uvloop event loop.
  183. asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
  184. super().init_process()
  185. class GunicornTokioWebWorker(GunicornWebWorker):
  186. def init_process(self) -> None: # pragma: no cover
  187. import tokio
  188. # Close any existing event loop before setting a
  189. # new policy.
  190. asyncio.get_event_loop().close()
  191. # Setup tokio policy, so that every
  192. # asyncio.get_event_loop() will create an instance
  193. # of tokio event loop.
  194. asyncio.set_event_loop_policy(tokio.EventLoopPolicy())
  195. super().init_process()