async_case.py 6.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170
  1. import asyncio
  2. import inspect
  3. from .case import TestCase
  4. class IsolatedAsyncioTestCase(TestCase):
  5. # Names intentionally have a long prefix
  6. # to reduce a chance of clashing with user-defined attributes
  7. # from inherited test case
  8. #
  9. # The class doesn't call loop.run_until_complete(self.setUp()) and family
  10. # but uses a different approach:
  11. # 1. create a long-running task that reads self.setUp()
  12. # awaitable from queue along with a future
  13. # 2. await the awaitable object passing in and set the result
  14. # into the future object
  15. # 3. Outer code puts the awaitable and the future object into a queue
  16. # with waiting for the future
  17. # The trick is necessary because every run_until_complete() call
  18. # creates a new task with embedded ContextVar context.
  19. # To share contextvars between setUp(), test and tearDown() we need to execute
  20. # them inside the same task.
  21. # Note: the test case modifies event loop policy if the policy was not instantiated
  22. # yet.
  23. # asyncio.get_event_loop_policy() creates a default policy on demand but never
  24. # returns None
  25. # I believe this is not an issue in user level tests but python itself for testing
  26. # should reset a policy in every test module
  27. # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
  28. def __init__(self, methodName='runTest'):
  29. super().__init__(methodName)
  30. self._asyncioTestLoop = None
  31. self._asyncioCallsQueue = None
  32. async def asyncSetUp(self):
  33. pass
  34. async def asyncTearDown(self):
  35. pass
  36. def addAsyncCleanup(self, func, /, *args, **kwargs):
  37. # A trivial trampoline to addCleanup()
  38. # the function exists because it has a different semantics
  39. # and signature:
  40. # addCleanup() accepts regular functions
  41. # but addAsyncCleanup() accepts coroutines
  42. #
  43. # We intentionally don't add inspect.iscoroutinefunction() check
  44. # for func argument because there is no way
  45. # to check for async function reliably:
  46. # 1. It can be "async def func()" itself
  47. # 2. Class can implement "async def __call__()" method
  48. # 3. Regular "def func()" that returns awaitable object
  49. self.addCleanup(*(func, *args), **kwargs)
  50. def _callSetUp(self):
  51. self.setUp()
  52. self._callAsync(self.asyncSetUp)
  53. def _callTestMethod(self, method):
  54. self._callMaybeAsync(method)
  55. def _callTearDown(self):
  56. self._callAsync(self.asyncTearDown)
  57. self.tearDown()
  58. def _callCleanup(self, function, *args, **kwargs):
  59. self._callMaybeAsync(function, *args, **kwargs)
  60. def _callAsync(self, func, /, *args, **kwargs):
  61. assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
  62. ret = func(*args, **kwargs)
  63. assert inspect.isawaitable(ret), f'{func!r} returned non-awaitable'
  64. fut = self._asyncioTestLoop.create_future()
  65. self._asyncioCallsQueue.put_nowait((fut, ret))
  66. return self._asyncioTestLoop.run_until_complete(fut)
  67. def _callMaybeAsync(self, func, /, *args, **kwargs):
  68. assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
  69. ret = func(*args, **kwargs)
  70. if inspect.isawaitable(ret):
  71. fut = self._asyncioTestLoop.create_future()
  72. self._asyncioCallsQueue.put_nowait((fut, ret))
  73. return self._asyncioTestLoop.run_until_complete(fut)
  74. else:
  75. return ret
  76. async def _asyncioLoopRunner(self, fut):
  77. self._asyncioCallsQueue = queue = asyncio.Queue()
  78. fut.set_result(None)
  79. while True:
  80. query = await queue.get()
  81. queue.task_done()
  82. if query is None:
  83. return
  84. fut, awaitable = query
  85. try:
  86. ret = await awaitable
  87. if not fut.cancelled():
  88. fut.set_result(ret)
  89. except (SystemExit, KeyboardInterrupt):
  90. raise
  91. except (BaseException, asyncio.CancelledError) as ex:
  92. if not fut.cancelled():
  93. fut.set_exception(ex)
  94. def _setupAsyncioLoop(self):
  95. assert self._asyncioTestLoop is None, 'asyncio test loop already initialized'
  96. loop = asyncio.new_event_loop()
  97. asyncio.set_event_loop(loop)
  98. loop.set_debug(True)
  99. self._asyncioTestLoop = loop
  100. fut = loop.create_future()
  101. self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))
  102. loop.run_until_complete(fut)
  103. def _tearDownAsyncioLoop(self):
  104. assert self._asyncioTestLoop is not None, 'asyncio test loop is not initialized'
  105. loop = self._asyncioTestLoop
  106. self._asyncioTestLoop = None
  107. self._asyncioCallsQueue.put_nowait(None)
  108. loop.run_until_complete(self._asyncioCallsQueue.join())
  109. try:
  110. # cancel all tasks
  111. to_cancel = asyncio.all_tasks(loop)
  112. if not to_cancel:
  113. return
  114. for task in to_cancel:
  115. task.cancel()
  116. loop.run_until_complete(
  117. asyncio.gather(*to_cancel, return_exceptions=True))
  118. for task in to_cancel:
  119. if task.cancelled():
  120. continue
  121. if task.exception() is not None:
  122. loop.call_exception_handler({
  123. 'message': 'unhandled exception during test shutdown',
  124. 'exception': task.exception(),
  125. 'task': task,
  126. })
  127. # shutdown asyncgens
  128. loop.run_until_complete(loop.shutdown_asyncgens())
  129. finally:
  130. # Prevent our executor environment from leaking to future tests.
  131. loop.run_until_complete(loop.shutdown_default_executor())
  132. asyncio.set_event_loop(None)
  133. loop.close()
  134. def run(self, result=None):
  135. self._setupAsyncioLoop()
  136. try:
  137. return super().run(result)
  138. finally:
  139. self._tearDownAsyncioLoop()
  140. def debug(self):
  141. self._setupAsyncioLoop()
  142. super().debug()
  143. self._tearDownAsyncioLoop()
  144. def __del__(self):
  145. if self._asyncioTestLoop is not None:
  146. self._tearDownAsyncioLoop()