# async_case.py (5.3 KB)

  1. import asyncio
  2. import contextvars
  3. import inspect
  4. import warnings
  5. from .case import TestCase
  6. class IsolatedAsyncioTestCase(TestCase):
  7. # Names intentionally have a long prefix
  8. # to reduce a chance of clashing with user-defined attributes
  9. # from inherited test case
  10. #
  11. # The class doesn't call loop.run_until_complete(self.setUp()) and family
  12. # but uses a different approach:
  13. # 1. create a long-running task that reads self.setUp()
  14. # awaitable from queue along with a future
  15. # 2. await the awaitable object passing in and set the result
  16. # into the future object
  17. # 3. Outer code puts the awaitable and the future object into a queue
  18. # with waiting for the future
  19. # The trick is necessary because every run_until_complete() call
  20. # creates a new task with embedded ContextVar context.
  21. # To share contextvars between setUp(), test and tearDown() we need to execute
  22. # them inside the same task.
  23. # Note: the test case modifies event loop policy if the policy was not instantiated
  24. # yet.
  25. # asyncio.get_event_loop_policy() creates a default policy on demand but never
  26. # returns None
  27. # I believe this is not an issue in user level tests but python itself for testing
  28. # should reset a policy in every test module
  29. # by calling asyncio.set_event_loop_policy(None) in tearDownModule()
  30. def __init__(self, methodName='runTest'):
  31. super().__init__(methodName)
  32. self._asyncioRunner = None
  33. self._asyncioTestContext = contextvars.copy_context()
  34. async def asyncSetUp(self):
  35. pass
  36. async def asyncTearDown(self):
  37. pass
  38. def addAsyncCleanup(self, func, /, *args, **kwargs):
  39. # A trivial trampoline to addCleanup()
  40. # the function exists because it has a different semantics
  41. # and signature:
  42. # addCleanup() accepts regular functions
  43. # but addAsyncCleanup() accepts coroutines
  44. #
  45. # We intentionally don't add inspect.iscoroutinefunction() check
  46. # for func argument because there is no way
  47. # to check for async function reliably:
  48. # 1. It can be "async def func()" itself
  49. # 2. Class can implement "async def __call__()" method
  50. # 3. Regular "def func()" that returns awaitable object
  51. self.addCleanup(*(func, *args), **kwargs)
  52. async def enterAsyncContext(self, cm):
  53. """Enters the supplied asynchronous context manager.
  54. If successful, also adds its __aexit__ method as a cleanup
  55. function and returns the result of the __aenter__ method.
  56. """
  57. # We look up the special methods on the type to match the with
  58. # statement.
  59. cls = type(cm)
  60. try:
  61. enter = cls.__aenter__
  62. exit = cls.__aexit__
  63. except AttributeError:
  64. raise TypeError(f"'{cls.__module__}.{cls.__qualname__}' object does "
  65. f"not support the asynchronous context manager protocol"
  66. ) from None
  67. result = await enter(cm)
  68. self.addAsyncCleanup(exit, cm, None, None, None)
  69. return result
  70. def _callSetUp(self):
  71. # Force loop to be initialized and set as the current loop
  72. # so that setUp functions can use get_event_loop() and get the
  73. # correct loop instance.
  74. self._asyncioRunner.get_loop()
  75. self._asyncioTestContext.run(self.setUp)
  76. self._callAsync(self.asyncSetUp)
  77. def _callTestMethod(self, method):
  78. if self._callMaybeAsync(method) is not None:
  79. warnings.warn(f'It is deprecated to return a value that is not None from a '
  80. f'test case ({method})', DeprecationWarning, stacklevel=4)
  81. def _callTearDown(self):
  82. self._callAsync(self.asyncTearDown)
  83. self._asyncioTestContext.run(self.tearDown)
  84. def _callCleanup(self, function, *args, **kwargs):
  85. self._callMaybeAsync(function, *args, **kwargs)
  86. def _callAsync(self, func, /, *args, **kwargs):
  87. assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
  88. assert inspect.iscoroutinefunction(func), f'{func!r} is not an async function'
  89. return self._asyncioRunner.run(
  90. func(*args, **kwargs),
  91. context=self._asyncioTestContext
  92. )
  93. def _callMaybeAsync(self, func, /, *args, **kwargs):
  94. assert self._asyncioRunner is not None, 'asyncio runner is not initialized'
  95. if inspect.iscoroutinefunction(func):
  96. return self._asyncioRunner.run(
  97. func(*args, **kwargs),
  98. context=self._asyncioTestContext,
  99. )
  100. else:
  101. return self._asyncioTestContext.run(func, *args, **kwargs)
  102. def _setupAsyncioRunner(self):
  103. assert self._asyncioRunner is None, 'asyncio runner is already initialized'
  104. runner = asyncio.Runner(debug=True)
  105. self._asyncioRunner = runner
  106. def _tearDownAsyncioRunner(self):
  107. runner = self._asyncioRunner
  108. runner.close()
  109. def run(self, result=None):
  110. self._setupAsyncioRunner()
  111. try:
  112. return super().run(result)
  113. finally:
  114. self._tearDownAsyncioRunner()
  115. def debug(self):
  116. self._setupAsyncioRunner()
  117. super().debug()
  118. self._tearDownAsyncioRunner()
  119. def __del__(self):
  120. if self._asyncioRunner is not None:
  121. self._tearDownAsyncioRunner()