pytest_plugin.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380
  1. import asyncio
  2. import contextlib
  3. import warnings
  4. from collections.abc import Callable
  5. import pytest
  6. from aiohttp.helpers import PY_37, isasyncgenfunction
  7. from aiohttp.web import Application
  8. from .test_utils import (
  9. BaseTestServer,
  10. RawTestServer,
  11. TestClient,
  12. TestServer,
  13. loop_context,
  14. setup_test_loop,
  15. teardown_test_loop,
  16. unused_port as _unused_port,
  17. )
  18. try:
  19. import uvloop
  20. except ImportError: # pragma: no cover
  21. uvloop = None
  22. try:
  23. import tokio
  24. except ImportError: # pragma: no cover
  25. tokio = None
  26. def pytest_addoption(parser): # type: ignore
  27. parser.addoption(
  28. "--aiohttp-fast",
  29. action="store_true",
  30. default=False,
  31. help="run tests faster by disabling extra checks",
  32. )
  33. parser.addoption(
  34. "--aiohttp-loop",
  35. action="store",
  36. default="pyloop",
  37. help="run tests with specific loop: pyloop, uvloop, tokio or all",
  38. )
  39. parser.addoption(
  40. "--aiohttp-enable-loop-debug",
  41. action="store_true",
  42. default=False,
  43. help="enable event loop debug mode",
  44. )
  45. def pytest_fixture_setup(fixturedef): # type: ignore
  46. """
  47. Allow fixtures to be coroutines. Run coroutine fixtures in an event loop.
  48. """
  49. func = fixturedef.func
  50. if isasyncgenfunction(func):
  51. # async generator fixture
  52. is_async_gen = True
  53. elif asyncio.iscoroutinefunction(func):
  54. # regular async fixture
  55. is_async_gen = False
  56. else:
  57. # not an async fixture, nothing to do
  58. return
  59. strip_request = False
  60. if "request" not in fixturedef.argnames:
  61. fixturedef.argnames += ("request",)
  62. strip_request = True
  63. def wrapper(*args, **kwargs): # type: ignore
  64. request = kwargs["request"]
  65. if strip_request:
  66. del kwargs["request"]
  67. # if neither the fixture nor the test use the 'loop' fixture,
  68. # 'getfixturevalue' will fail because the test is not parameterized
  69. # (this can be removed someday if 'loop' is no longer parameterized)
  70. if "loop" not in request.fixturenames:
  71. raise Exception(
  72. "Asynchronous fixtures must depend on the 'loop' fixture or "
  73. "be used in tests depending from it."
  74. )
  75. _loop = request.getfixturevalue("loop")
  76. if is_async_gen:
  77. # for async generators, we need to advance the generator once,
  78. # then advance it again in a finalizer
  79. gen = func(*args, **kwargs)
  80. def finalizer(): # type: ignore
  81. try:
  82. return _loop.run_until_complete(gen.__anext__())
  83. except StopAsyncIteration:
  84. pass
  85. request.addfinalizer(finalizer)
  86. return _loop.run_until_complete(gen.__anext__())
  87. else:
  88. return _loop.run_until_complete(func(*args, **kwargs))
  89. fixturedef.func = wrapper
  90. @pytest.fixture
  91. def fast(request): # type: ignore
  92. """--fast config option"""
  93. return request.config.getoption("--aiohttp-fast")
  94. @pytest.fixture
  95. def loop_debug(request): # type: ignore
  96. """--enable-loop-debug config option"""
  97. return request.config.getoption("--aiohttp-enable-loop-debug")
  98. @contextlib.contextmanager
  99. def _runtime_warning_context(): # type: ignore
  100. """
  101. Context manager which checks for RuntimeWarnings, specifically to
  102. avoid "coroutine 'X' was never awaited" warnings being missed.
  103. If RuntimeWarnings occur in the context a RuntimeError is raised.
  104. """
  105. with warnings.catch_warnings(record=True) as _warnings:
  106. yield
  107. rw = [
  108. "{w.filename}:{w.lineno}:{w.message}".format(w=w)
  109. for w in _warnings
  110. if w.category == RuntimeWarning
  111. ]
  112. if rw:
  113. raise RuntimeError(
  114. "{} Runtime Warning{},\n{}".format(
  115. len(rw), "" if len(rw) == 1 else "s", "\n".join(rw)
  116. )
  117. )
  118. @contextlib.contextmanager
  119. def _passthrough_loop_context(loop, fast=False): # type: ignore
  120. """
  121. setups and tears down a loop unless one is passed in via the loop
  122. argument when it's passed straight through.
  123. """
  124. if loop:
  125. # loop already exists, pass it straight through
  126. yield loop
  127. else:
  128. # this shadows loop_context's standard behavior
  129. loop = setup_test_loop()
  130. yield loop
  131. teardown_test_loop(loop, fast=fast)
  132. def pytest_pycollect_makeitem(collector, name, obj): # type: ignore
  133. """
  134. Fix pytest collecting for coroutines.
  135. """
  136. if collector.funcnamefilter(name) and asyncio.iscoroutinefunction(obj):
  137. return list(collector._genfunctions(name, obj))
  138. def pytest_pyfunc_call(pyfuncitem): # type: ignore
  139. """
  140. Run coroutines in an event loop instead of a normal function call.
  141. """
  142. fast = pyfuncitem.config.getoption("--aiohttp-fast")
  143. if asyncio.iscoroutinefunction(pyfuncitem.function):
  144. existing_loop = pyfuncitem.funcargs.get(
  145. "proactor_loop"
  146. ) or pyfuncitem.funcargs.get("loop", None)
  147. with _runtime_warning_context():
  148. with _passthrough_loop_context(existing_loop, fast=fast) as _loop:
  149. testargs = {
  150. arg: pyfuncitem.funcargs[arg]
  151. for arg in pyfuncitem._fixtureinfo.argnames
  152. }
  153. _loop.run_until_complete(pyfuncitem.obj(**testargs))
  154. return True
  155. def pytest_generate_tests(metafunc): # type: ignore
  156. if "loop_factory" not in metafunc.fixturenames:
  157. return
  158. loops = metafunc.config.option.aiohttp_loop
  159. avail_factories = {"pyloop": asyncio.DefaultEventLoopPolicy}
  160. if uvloop is not None: # pragma: no cover
  161. avail_factories["uvloop"] = uvloop.EventLoopPolicy
  162. if tokio is not None: # pragma: no cover
  163. avail_factories["tokio"] = tokio.EventLoopPolicy
  164. if loops == "all":
  165. loops = "pyloop,uvloop?,tokio?"
  166. factories = {} # type: ignore
  167. for name in loops.split(","):
  168. required = not name.endswith("?")
  169. name = name.strip(" ?")
  170. if name not in avail_factories: # pragma: no cover
  171. if required:
  172. raise ValueError(
  173. "Unknown loop '%s', available loops: %s"
  174. % (name, list(factories.keys()))
  175. )
  176. else:
  177. continue
  178. factories[name] = avail_factories[name]
  179. metafunc.parametrize(
  180. "loop_factory", list(factories.values()), ids=list(factories.keys())
  181. )
  182. @pytest.fixture
  183. def loop(loop_factory, fast, loop_debug): # type: ignore
  184. """Return an instance of the event loop."""
  185. policy = loop_factory()
  186. asyncio.set_event_loop_policy(policy)
  187. with loop_context(fast=fast) as _loop:
  188. if loop_debug:
  189. _loop.set_debug(True) # pragma: no cover
  190. asyncio.set_event_loop(_loop)
  191. yield _loop
  192. @pytest.fixture
  193. def proactor_loop(): # type: ignore
  194. if not PY_37:
  195. policy = asyncio.get_event_loop_policy()
  196. policy._loop_factory = asyncio.ProactorEventLoop # type: ignore
  197. else:
  198. policy = asyncio.WindowsProactorEventLoopPolicy() # type: ignore
  199. asyncio.set_event_loop_policy(policy)
  200. with loop_context(policy.new_event_loop) as _loop:
  201. asyncio.set_event_loop(_loop)
  202. yield _loop
  203. @pytest.fixture
  204. def unused_port(aiohttp_unused_port): # type: ignore # pragma: no cover
  205. warnings.warn(
  206. "Deprecated, use aiohttp_unused_port fixture instead",
  207. DeprecationWarning,
  208. stacklevel=2,
  209. )
  210. return aiohttp_unused_port
  211. @pytest.fixture
  212. def aiohttp_unused_port(): # type: ignore
  213. """Return a port that is unused on the current host."""
  214. return _unused_port
  215. @pytest.fixture
  216. def aiohttp_server(loop): # type: ignore
  217. """Factory to create a TestServer instance, given an app.
  218. aiohttp_server(app, **kwargs)
  219. """
  220. servers = []
  221. async def go(app, *, port=None, **kwargs): # type: ignore
  222. server = TestServer(app, port=port)
  223. await server.start_server(loop=loop, **kwargs)
  224. servers.append(server)
  225. return server
  226. yield go
  227. async def finalize(): # type: ignore
  228. while servers:
  229. await servers.pop().close()
  230. loop.run_until_complete(finalize())
  231. @pytest.fixture
  232. def test_server(aiohttp_server): # type: ignore # pragma: no cover
  233. warnings.warn(
  234. "Deprecated, use aiohttp_server fixture instead",
  235. DeprecationWarning,
  236. stacklevel=2,
  237. )
  238. return aiohttp_server
  239. @pytest.fixture
  240. def aiohttp_raw_server(loop): # type: ignore
  241. """Factory to create a RawTestServer instance, given a web handler.
  242. aiohttp_raw_server(handler, **kwargs)
  243. """
  244. servers = []
  245. async def go(handler, *, port=None, **kwargs): # type: ignore
  246. server = RawTestServer(handler, port=port)
  247. await server.start_server(loop=loop, **kwargs)
  248. servers.append(server)
  249. return server
  250. yield go
  251. async def finalize(): # type: ignore
  252. while servers:
  253. await servers.pop().close()
  254. loop.run_until_complete(finalize())
  255. @pytest.fixture
  256. def raw_test_server(aiohttp_raw_server): # type: ignore # pragma: no cover
  257. warnings.warn(
  258. "Deprecated, use aiohttp_raw_server fixture instead",
  259. DeprecationWarning,
  260. stacklevel=2,
  261. )
  262. return aiohttp_raw_server
  263. @pytest.fixture
  264. def aiohttp_client(loop): # type: ignore
  265. """Factory to create a TestClient instance.
  266. aiohttp_client(app, **kwargs)
  267. aiohttp_client(server, **kwargs)
  268. aiohttp_client(raw_server, **kwargs)
  269. """
  270. clients = []
  271. async def go(__param, *args, server_kwargs=None, **kwargs): # type: ignore
  272. if isinstance(__param, Callable) and not isinstance( # type: ignore
  273. __param, (Application, BaseTestServer)
  274. ):
  275. __param = __param(loop, *args, **kwargs)
  276. kwargs = {}
  277. else:
  278. assert not args, "args should be empty"
  279. if isinstance(__param, Application):
  280. server_kwargs = server_kwargs or {}
  281. server = TestServer(__param, loop=loop, **server_kwargs)
  282. client = TestClient(server, loop=loop, **kwargs)
  283. elif isinstance(__param, BaseTestServer):
  284. client = TestClient(__param, loop=loop, **kwargs)
  285. else:
  286. raise ValueError("Unknown argument type: %r" % type(__param))
  287. await client.start_server()
  288. clients.append(client)
  289. return client
  290. yield go
  291. async def finalize(): # type: ignore
  292. while clients:
  293. await clients.pop().close()
  294. loop.run_until_complete(finalize())
  295. @pytest.fixture
  296. def test_client(aiohttp_client): # type: ignore # pragma: no cover
  297. warnings.warn(
  298. "Deprecated, use aiohttp_client fixture instead",
  299. DeprecationWarning,
  300. stacklevel=2,
  301. )
  302. return aiohttp_client