web_runner.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381
  1. import asyncio
  2. import signal
  3. import socket
  4. from abc import ABC, abstractmethod
  5. from typing import Any, List, Optional, Set
  6. from yarl import URL
  7. from .web_app import Application
  8. from .web_server import Server
  9. try:
  10. from ssl import SSLContext
  11. except ImportError:
  12. SSLContext = object # type: ignore
  13. __all__ = (
  14. "BaseSite",
  15. "TCPSite",
  16. "UnixSite",
  17. "NamedPipeSite",
  18. "SockSite",
  19. "BaseRunner",
  20. "AppRunner",
  21. "ServerRunner",
  22. "GracefulExit",
  23. )
  24. class GracefulExit(SystemExit):
  25. code = 1
  26. def _raise_graceful_exit() -> None:
  27. raise GracefulExit()
  28. class BaseSite(ABC):
  29. __slots__ = ("_runner", "_shutdown_timeout", "_ssl_context", "_backlog", "_server")
  30. def __init__(
  31. self,
  32. runner: "BaseRunner",
  33. *,
  34. shutdown_timeout: float = 60.0,
  35. ssl_context: Optional[SSLContext] = None,
  36. backlog: int = 128,
  37. ) -> None:
  38. if runner.server is None:
  39. raise RuntimeError("Call runner.setup() before making a site")
  40. self._runner = runner
  41. self._shutdown_timeout = shutdown_timeout
  42. self._ssl_context = ssl_context
  43. self._backlog = backlog
  44. self._server = None # type: Optional[asyncio.AbstractServer]
  45. @property
  46. @abstractmethod
  47. def name(self) -> str:
  48. pass # pragma: no cover
  49. @abstractmethod
  50. async def start(self) -> None:
  51. self._runner._reg_site(self)
  52. async def stop(self) -> None:
  53. self._runner._check_site(self)
  54. if self._server is None:
  55. self._runner._unreg_site(self)
  56. return # not started yet
  57. self._server.close()
  58. # named pipes do not have wait_closed property
  59. if hasattr(self._server, "wait_closed"):
  60. await self._server.wait_closed()
  61. await self._runner.shutdown()
  62. assert self._runner.server
  63. await self._runner.server.shutdown(self._shutdown_timeout)
  64. self._runner._unreg_site(self)
  65. class TCPSite(BaseSite):
  66. __slots__ = ("_host", "_port", "_reuse_address", "_reuse_port")
  67. def __init__(
  68. self,
  69. runner: "BaseRunner",
  70. host: Optional[str] = None,
  71. port: Optional[int] = None,
  72. *,
  73. shutdown_timeout: float = 60.0,
  74. ssl_context: Optional[SSLContext] = None,
  75. backlog: int = 128,
  76. reuse_address: Optional[bool] = None,
  77. reuse_port: Optional[bool] = None,
  78. ) -> None:
  79. super().__init__(
  80. runner,
  81. shutdown_timeout=shutdown_timeout,
  82. ssl_context=ssl_context,
  83. backlog=backlog,
  84. )
  85. self._host = host
  86. if port is None:
  87. port = 8443 if self._ssl_context else 8080
  88. self._port = port
  89. self._reuse_address = reuse_address
  90. self._reuse_port = reuse_port
  91. @property
  92. def name(self) -> str:
  93. scheme = "https" if self._ssl_context else "http"
  94. host = "0.0.0.0" if self._host is None else self._host
  95. return str(URL.build(scheme=scheme, host=host, port=self._port))
  96. async def start(self) -> None:
  97. await super().start()
  98. loop = asyncio.get_event_loop()
  99. server = self._runner.server
  100. assert server is not None
  101. self._server = await loop.create_server(
  102. server,
  103. self._host,
  104. self._port,
  105. ssl=self._ssl_context,
  106. backlog=self._backlog,
  107. reuse_address=self._reuse_address,
  108. reuse_port=self._reuse_port,
  109. )
  110. class UnixSite(BaseSite):
  111. __slots__ = ("_path",)
  112. def __init__(
  113. self,
  114. runner: "BaseRunner",
  115. path: str,
  116. *,
  117. shutdown_timeout: float = 60.0,
  118. ssl_context: Optional[SSLContext] = None,
  119. backlog: int = 128,
  120. ) -> None:
  121. super().__init__(
  122. runner,
  123. shutdown_timeout=shutdown_timeout,
  124. ssl_context=ssl_context,
  125. backlog=backlog,
  126. )
  127. self._path = path
  128. @property
  129. def name(self) -> str:
  130. scheme = "https" if self._ssl_context else "http"
  131. return f"{scheme}://unix:{self._path}:"
  132. async def start(self) -> None:
  133. await super().start()
  134. loop = asyncio.get_event_loop()
  135. server = self._runner.server
  136. assert server is not None
  137. self._server = await loop.create_unix_server(
  138. server, self._path, ssl=self._ssl_context, backlog=self._backlog
  139. )
  140. class NamedPipeSite(BaseSite):
  141. __slots__ = ("_path",)
  142. def __init__(
  143. self, runner: "BaseRunner", path: str, *, shutdown_timeout: float = 60.0
  144. ) -> None:
  145. loop = asyncio.get_event_loop()
  146. if not isinstance(loop, asyncio.ProactorEventLoop): # type: ignore
  147. raise RuntimeError(
  148. "Named Pipes only available in proactor" "loop under windows"
  149. )
  150. super().__init__(runner, shutdown_timeout=shutdown_timeout)
  151. self._path = path
  152. @property
  153. def name(self) -> str:
  154. return self._path
  155. async def start(self) -> None:
  156. await super().start()
  157. loop = asyncio.get_event_loop()
  158. server = self._runner.server
  159. assert server is not None
  160. _server = await loop.start_serving_pipe(server, self._path) # type: ignore
  161. self._server = _server[0]
  162. class SockSite(BaseSite):
  163. __slots__ = ("_sock", "_name")
  164. def __init__(
  165. self,
  166. runner: "BaseRunner",
  167. sock: socket.socket,
  168. *,
  169. shutdown_timeout: float = 60.0,
  170. ssl_context: Optional[SSLContext] = None,
  171. backlog: int = 128,
  172. ) -> None:
  173. super().__init__(
  174. runner,
  175. shutdown_timeout=shutdown_timeout,
  176. ssl_context=ssl_context,
  177. backlog=backlog,
  178. )
  179. self._sock = sock
  180. scheme = "https" if self._ssl_context else "http"
  181. if hasattr(socket, "AF_UNIX") and sock.family == socket.AF_UNIX:
  182. name = f"{scheme}://unix:{sock.getsockname()}:"
  183. else:
  184. host, port = sock.getsockname()[:2]
  185. name = str(URL.build(scheme=scheme, host=host, port=port))
  186. self._name = name
  187. @property
  188. def name(self) -> str:
  189. return self._name
  190. async def start(self) -> None:
  191. await super().start()
  192. loop = asyncio.get_event_loop()
  193. server = self._runner.server
  194. assert server is not None
  195. self._server = await loop.create_server(
  196. server, sock=self._sock, ssl=self._ssl_context, backlog=self._backlog
  197. )
  198. class BaseRunner(ABC):
  199. __slots__ = ("_handle_signals", "_kwargs", "_server", "_sites")
  200. def __init__(self, *, handle_signals: bool = False, **kwargs: Any) -> None:
  201. self._handle_signals = handle_signals
  202. self._kwargs = kwargs
  203. self._server = None # type: Optional[Server]
  204. self._sites = [] # type: List[BaseSite]
  205. @property
  206. def server(self) -> Optional[Server]:
  207. return self._server
  208. @property
  209. def addresses(self) -> List[Any]:
  210. ret = [] # type: List[Any]
  211. for site in self._sites:
  212. server = site._server
  213. if server is not None:
  214. sockets = server.sockets
  215. if sockets is not None:
  216. for sock in sockets:
  217. ret.append(sock.getsockname())
  218. return ret
  219. @property
  220. def sites(self) -> Set[BaseSite]:
  221. return set(self._sites)
  222. async def setup(self) -> None:
  223. loop = asyncio.get_event_loop()
  224. if self._handle_signals:
  225. try:
  226. loop.add_signal_handler(signal.SIGINT, _raise_graceful_exit)
  227. loop.add_signal_handler(signal.SIGTERM, _raise_graceful_exit)
  228. except NotImplementedError: # pragma: no cover
  229. # add_signal_handler is not implemented on Windows
  230. pass
  231. self._server = await self._make_server()
  232. @abstractmethod
  233. async def shutdown(self) -> None:
  234. pass # pragma: no cover
  235. async def cleanup(self) -> None:
  236. loop = asyncio.get_event_loop()
  237. if self._server is None:
  238. # no started yet, do nothing
  239. return
  240. # The loop over sites is intentional, an exception on gather()
  241. # leaves self._sites in unpredictable state.
  242. # The loop guaranties that a site is either deleted on success or
  243. # still present on failure
  244. for site in list(self._sites):
  245. await site.stop()
  246. await self._cleanup_server()
  247. self._server = None
  248. if self._handle_signals:
  249. try:
  250. loop.remove_signal_handler(signal.SIGINT)
  251. loop.remove_signal_handler(signal.SIGTERM)
  252. except NotImplementedError: # pragma: no cover
  253. # remove_signal_handler is not implemented on Windows
  254. pass
  255. @abstractmethod
  256. async def _make_server(self) -> Server:
  257. pass # pragma: no cover
  258. @abstractmethod
  259. async def _cleanup_server(self) -> None:
  260. pass # pragma: no cover
  261. def _reg_site(self, site: BaseSite) -> None:
  262. if site in self._sites:
  263. raise RuntimeError(f"Site {site} is already registered in runner {self}")
  264. self._sites.append(site)
  265. def _check_site(self, site: BaseSite) -> None:
  266. if site not in self._sites:
  267. raise RuntimeError(f"Site {site} is not registered in runner {self}")
  268. def _unreg_site(self, site: BaseSite) -> None:
  269. if site not in self._sites:
  270. raise RuntimeError(f"Site {site} is not registered in runner {self}")
  271. self._sites.remove(site)
  272. class ServerRunner(BaseRunner):
  273. """Low-level web server runner"""
  274. __slots__ = ("_web_server",)
  275. def __init__(
  276. self, web_server: Server, *, handle_signals: bool = False, **kwargs: Any
  277. ) -> None:
  278. super().__init__(handle_signals=handle_signals, **kwargs)
  279. self._web_server = web_server
  280. async def shutdown(self) -> None:
  281. pass
  282. async def _make_server(self) -> Server:
  283. return self._web_server
  284. async def _cleanup_server(self) -> None:
  285. pass
  286. class AppRunner(BaseRunner):
  287. """Web Application runner"""
  288. __slots__ = ("_app",)
  289. def __init__(
  290. self, app: Application, *, handle_signals: bool = False, **kwargs: Any
  291. ) -> None:
  292. super().__init__(handle_signals=handle_signals, **kwargs)
  293. if not isinstance(app, Application):
  294. raise TypeError(
  295. "The first argument should be web.Application "
  296. "instance, got {!r}".format(app)
  297. )
  298. self._app = app
  299. @property
  300. def app(self) -> Application:
  301. return self._app
  302. async def shutdown(self) -> None:
  303. await self._app.shutdown()
  304. async def _make_server(self) -> Server:
  305. loop = asyncio.get_event_loop()
  306. self._app._set_loop(loop)
  307. self._app.on_startup.freeze()
  308. await self._app.startup()
  309. self._app.freeze()
  310. return self._app._make_handler(loop=loop, **self._kwargs)
  311. async def _cleanup_server(self) -> None:
  312. await self._app.cleanup()