testwith.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
  1. import unittest
  2. from warnings import catch_warnings
  3. from unittest.test.testmock.support import is_instance
  4. from unittest.mock import MagicMock, Mock, patch, sentinel, mock_open, call
  5. something = sentinel.Something
  6. something_else = sentinel.SomethingElse
  7. class SampleException(Exception): pass
  8. class WithTest(unittest.TestCase):
  9. def test_with_statement(self):
  10. with patch('%s.something' % __name__, sentinel.Something2):
  11. self.assertEqual(something, sentinel.Something2, "unpatched")
  12. self.assertEqual(something, sentinel.Something)
  13. def test_with_statement_exception(self):
  14. with self.assertRaises(SampleException):
  15. with patch('%s.something' % __name__, sentinel.Something2):
  16. self.assertEqual(something, sentinel.Something2, "unpatched")
  17. raise SampleException()
  18. self.assertEqual(something, sentinel.Something)
  19. def test_with_statement_as(self):
  20. with patch('%s.something' % __name__) as mock_something:
  21. self.assertEqual(something, mock_something, "unpatched")
  22. self.assertTrue(is_instance(mock_something, MagicMock),
  23. "patching wrong type")
  24. self.assertEqual(something, sentinel.Something)
  25. def test_patch_object_with_statement(self):
  26. class Foo(object):
  27. something = 'foo'
  28. original = Foo.something
  29. with patch.object(Foo, 'something'):
  30. self.assertNotEqual(Foo.something, original, "unpatched")
  31. self.assertEqual(Foo.something, original)
  32. def test_with_statement_nested(self):
  33. with catch_warnings(record=True):
  34. with patch('%s.something' % __name__) as mock_something, patch('%s.something_else' % __name__) as mock_something_else:
  35. self.assertEqual(something, mock_something, "unpatched")
  36. self.assertEqual(something_else, mock_something_else,
  37. "unpatched")
  38. self.assertEqual(something, sentinel.Something)
  39. self.assertEqual(something_else, sentinel.SomethingElse)
  40. def test_with_statement_specified(self):
  41. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  42. self.assertEqual(something, mock_something, "unpatched")
  43. self.assertEqual(mock_something, sentinel.Patched, "wrong patch")
  44. self.assertEqual(something, sentinel.Something)
  45. def testContextManagerMocking(self):
  46. mock = Mock()
  47. mock.__enter__ = Mock()
  48. mock.__exit__ = Mock()
  49. mock.__exit__.return_value = False
  50. with mock as m:
  51. self.assertEqual(m, mock.__enter__.return_value)
  52. mock.__enter__.assert_called_with()
  53. mock.__exit__.assert_called_with(None, None, None)
  54. def test_context_manager_with_magic_mock(self):
  55. mock = MagicMock()
  56. with self.assertRaises(TypeError):
  57. with mock:
  58. 'foo' + 3
  59. mock.__enter__.assert_called_with()
  60. self.assertTrue(mock.__exit__.called)
  61. def test_with_statement_same_attribute(self):
  62. with patch('%s.something' % __name__, sentinel.Patched) as mock_something:
  63. self.assertEqual(something, mock_something, "unpatched")
  64. with patch('%s.something' % __name__) as mock_again:
  65. self.assertEqual(something, mock_again, "unpatched")
  66. self.assertEqual(something, mock_something,
  67. "restored with wrong instance")
  68. self.assertEqual(something, sentinel.Something, "not restored")
  69. def test_with_statement_imbricated(self):
  70. with patch('%s.something' % __name__) as mock_something:
  71. self.assertEqual(something, mock_something, "unpatched")
  72. with patch('%s.something_else' % __name__) as mock_something_else:
  73. self.assertEqual(something_else, mock_something_else,
  74. "unpatched")
  75. self.assertEqual(something, sentinel.Something)
  76. self.assertEqual(something_else, sentinel.SomethingElse)
  77. def test_dict_context_manager(self):
  78. foo = {}
  79. with patch.dict(foo, {'a': 'b'}):
  80. self.assertEqual(foo, {'a': 'b'})
  81. self.assertEqual(foo, {})
  82. with self.assertRaises(NameError):
  83. with patch.dict(foo, {'a': 'b'}):
  84. self.assertEqual(foo, {'a': 'b'})
  85. raise NameError('Konrad')
  86. self.assertEqual(foo, {})
  87. def test_double_patch_instance_method(self):
  88. class C:
  89. def f(self): pass
  90. c = C()
  91. with patch.object(c, 'f', autospec=True) as patch1:
  92. with patch.object(c, 'f', autospec=True) as patch2:
  93. c.f()
  94. self.assertEqual(patch2.call_count, 1)
  95. self.assertEqual(patch1.call_count, 0)
  96. c.f()
  97. self.assertEqual(patch1.call_count, 1)
  98. class TestMockOpen(unittest.TestCase):
  99. def test_mock_open(self):
  100. mock = mock_open()
  101. with patch('%s.open' % __name__, mock, create=True) as patched:
  102. self.assertIs(patched, mock)
  103. open('foo')
  104. mock.assert_called_once_with('foo')
  105. def test_mock_open_context_manager(self):
  106. mock = mock_open()
  107. handle = mock.return_value
  108. with patch('%s.open' % __name__, mock, create=True):
  109. with open('foo') as f:
  110. f.read()
  111. expected_calls = [call('foo'), call().__enter__(), call().read(),
  112. call().__exit__(None, None, None)]
  113. self.assertEqual(mock.mock_calls, expected_calls)
  114. self.assertIs(f, handle)
  115. def test_mock_open_context_manager_multiple_times(self):
  116. mock = mock_open()
  117. with patch('%s.open' % __name__, mock, create=True):
  118. with open('foo') as f:
  119. f.read()
  120. with open('bar') as f:
  121. f.read()
  122. expected_calls = [
  123. call('foo'), call().__enter__(), call().read(),
  124. call().__exit__(None, None, None),
  125. call('bar'), call().__enter__(), call().read(),
  126. call().__exit__(None, None, None)]
  127. self.assertEqual(mock.mock_calls, expected_calls)
  128. def test_explicit_mock(self):
  129. mock = MagicMock()
  130. mock_open(mock)
  131. with patch('%s.open' % __name__, mock, create=True) as patched:
  132. self.assertIs(patched, mock)
  133. open('foo')
  134. mock.assert_called_once_with('foo')
  135. def test_read_data(self):
  136. mock = mock_open(read_data='foo')
  137. with patch('%s.open' % __name__, mock, create=True):
  138. h = open('bar')
  139. result = h.read()
  140. self.assertEqual(result, 'foo')
  141. def test_readline_data(self):
  142. # Check that readline will return all the lines from the fake file
  143. # And that once fully consumed, readline will return an empty string.
  144. mock = mock_open(read_data='foo\nbar\nbaz\n')
  145. with patch('%s.open' % __name__, mock, create=True):
  146. h = open('bar')
  147. line1 = h.readline()
  148. line2 = h.readline()
  149. line3 = h.readline()
  150. self.assertEqual(line1, 'foo\n')
  151. self.assertEqual(line2, 'bar\n')
  152. self.assertEqual(line3, 'baz\n')
  153. self.assertEqual(h.readline(), '')
  154. # Check that we properly emulate a file that doesn't end in a newline
  155. mock = mock_open(read_data='foo')
  156. with patch('%s.open' % __name__, mock, create=True):
  157. h = open('bar')
  158. result = h.readline()
  159. self.assertEqual(result, 'foo')
  160. self.assertEqual(h.readline(), '')
  161. def test_dunder_iter_data(self):
  162. # Check that dunder_iter will return all the lines from the fake file.
  163. mock = mock_open(read_data='foo\nbar\nbaz\n')
  164. with patch('%s.open' % __name__, mock, create=True):
  165. h = open('bar')
  166. lines = [l for l in h]
  167. self.assertEqual(lines[0], 'foo\n')
  168. self.assertEqual(lines[1], 'bar\n')
  169. self.assertEqual(lines[2], 'baz\n')
  170. self.assertEqual(h.readline(), '')
  171. with self.assertRaises(StopIteration):
  172. next(h)
  173. def test_next_data(self):
  174. # Check that next will correctly return the next available
  175. # line and plays well with the dunder_iter part.
  176. mock = mock_open(read_data='foo\nbar\nbaz\n')
  177. with patch('%s.open' % __name__, mock, create=True):
  178. h = open('bar')
  179. line1 = next(h)
  180. line2 = next(h)
  181. lines = [l for l in h]
  182. self.assertEqual(line1, 'foo\n')
  183. self.assertEqual(line2, 'bar\n')
  184. self.assertEqual(lines[0], 'baz\n')
  185. self.assertEqual(h.readline(), '')
  186. def test_readlines_data(self):
  187. # Test that emulating a file that ends in a newline character works
  188. mock = mock_open(read_data='foo\nbar\nbaz\n')
  189. with patch('%s.open' % __name__, mock, create=True):
  190. h = open('bar')
  191. result = h.readlines()
  192. self.assertEqual(result, ['foo\n', 'bar\n', 'baz\n'])
  193. # Test that files without a final newline will also be correctly
  194. # emulated
  195. mock = mock_open(read_data='foo\nbar\nbaz')
  196. with patch('%s.open' % __name__, mock, create=True):
  197. h = open('bar')
  198. result = h.readlines()
  199. self.assertEqual(result, ['foo\n', 'bar\n', 'baz'])
  200. def test_read_bytes(self):
  201. mock = mock_open(read_data=b'\xc6')
  202. with patch('%s.open' % __name__, mock, create=True):
  203. with open('abc', 'rb') as f:
  204. result = f.read()
  205. self.assertEqual(result, b'\xc6')
  206. def test_readline_bytes(self):
  207. m = mock_open(read_data=b'abc\ndef\nghi\n')
  208. with patch('%s.open' % __name__, m, create=True):
  209. with open('abc', 'rb') as f:
  210. line1 = f.readline()
  211. line2 = f.readline()
  212. line3 = f.readline()
  213. self.assertEqual(line1, b'abc\n')
  214. self.assertEqual(line2, b'def\n')
  215. self.assertEqual(line3, b'ghi\n')
  216. def test_readlines_bytes(self):
  217. m = mock_open(read_data=b'abc\ndef\nghi\n')
  218. with patch('%s.open' % __name__, m, create=True):
  219. with open('abc', 'rb') as f:
  220. result = f.readlines()
  221. self.assertEqual(result, [b'abc\n', b'def\n', b'ghi\n'])
  222. def test_mock_open_read_with_argument(self):
  223. # At one point calling read with an argument was broken
  224. # for mocks returned by mock_open
  225. some_data = 'foo\nbar\nbaz'
  226. mock = mock_open(read_data=some_data)
  227. self.assertEqual(mock().read(10), some_data[:10])
  228. self.assertEqual(mock().read(10), some_data[:10])
  229. f = mock()
  230. self.assertEqual(f.read(10), some_data[:10])
  231. self.assertEqual(f.read(10), some_data[10:])
  232. def test_interleaved_reads(self):
  233. # Test that calling read, readline, and readlines pulls data
  234. # sequentially from the data we preload with
  235. mock = mock_open(read_data='foo\nbar\nbaz\n')
  236. with patch('%s.open' % __name__, mock, create=True):
  237. h = open('bar')
  238. line1 = h.readline()
  239. rest = h.readlines()
  240. self.assertEqual(line1, 'foo\n')
  241. self.assertEqual(rest, ['bar\n', 'baz\n'])
  242. mock = mock_open(read_data='foo\nbar\nbaz\n')
  243. with patch('%s.open' % __name__, mock, create=True):
  244. h = open('bar')
  245. line1 = h.readline()
  246. rest = h.read()
  247. self.assertEqual(line1, 'foo\n')
  248. self.assertEqual(rest, 'bar\nbaz\n')
  249. def test_overriding_return_values(self):
  250. mock = mock_open(read_data='foo')
  251. handle = mock()
  252. handle.read.return_value = 'bar'
  253. handle.readline.return_value = 'bar'
  254. handle.readlines.return_value = ['bar']
  255. self.assertEqual(handle.read(), 'bar')
  256. self.assertEqual(handle.readline(), 'bar')
  257. self.assertEqual(handle.readlines(), ['bar'])
  258. # call repeatedly to check that a StopIteration is not propagated
  259. self.assertEqual(handle.readline(), 'bar')
  260. self.assertEqual(handle.readline(), 'bar')
  261. if __name__ == '__main__':
  262. unittest.main()