hooks.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290
  1. #-*- coding: iso-8859-1 -*-
  2. # pysqlite2/test/hooks.py: tests for various SQLite-specific hooks
  3. #
  4. # Copyright (C) 2006-2007 Gerhard Häring <gh@ghaering.de>
  5. #
  6. # This file is part of pysqlite.
  7. #
  8. # This software is provided 'as-is', without any express or implied
  9. # warranty. In no event will the authors be held liable for any damages
  10. # arising from the use of this software.
  11. #
  12. # Permission is granted to anyone to use this software for any purpose,
  13. # including commercial applications, and to alter it and redistribute it
  14. # freely, subject to the following restrictions:
  15. #
  16. # 1. The origin of this software must not be misrepresented; you must not
  17. # claim that you wrote the original software. If you use this software
  18. # in a product, an acknowledgment in the product documentation would be
  19. # appreciated but is not required.
  20. # 2. Altered source versions must be plainly marked as such, and must not be
  21. # misrepresented as being the original software.
  22. # 3. This notice may not be removed or altered from any source distribution.
  23. import unittest
  24. import sqlite3 as sqlite
  25. from test.support import TESTFN, unlink
  26. class CollationTests(unittest.TestCase):
  27. def CheckCreateCollationNotString(self):
  28. con = sqlite.connect(":memory:")
  29. with self.assertRaises(TypeError):
  30. con.create_collation(None, lambda x, y: (x > y) - (x < y))
  31. def CheckCreateCollationNotCallable(self):
  32. con = sqlite.connect(":memory:")
  33. with self.assertRaises(TypeError) as cm:
  34. con.create_collation("X", 42)
  35. self.assertEqual(str(cm.exception), 'parameter must be callable')
  36. def CheckCreateCollationNotAscii(self):
  37. con = sqlite.connect(":memory:")
  38. with self.assertRaises(sqlite.ProgrammingError):
  39. con.create_collation("collä", lambda x, y: (x > y) - (x < y))
  40. def CheckCreateCollationBadUpper(self):
  41. class BadUpperStr(str):
  42. def upper(self):
  43. return None
  44. con = sqlite.connect(":memory:")
  45. mycoll = lambda x, y: -((x > y) - (x < y))
  46. con.create_collation(BadUpperStr("mycoll"), mycoll)
  47. result = con.execute("""
  48. select x from (
  49. select 'a' as x
  50. union
  51. select 'b' as x
  52. ) order by x collate mycoll
  53. """).fetchall()
  54. self.assertEqual(result[0][0], 'b')
  55. self.assertEqual(result[1][0], 'a')
  56. @unittest.skipIf(sqlite.sqlite_version_info < (3, 2, 1),
  57. 'old SQLite versions crash on this test')
  58. def CheckCollationIsUsed(self):
  59. def mycoll(x, y):
  60. # reverse order
  61. return -((x > y) - (x < y))
  62. con = sqlite.connect(":memory:")
  63. con.create_collation("mycoll", mycoll)
  64. sql = """
  65. select x from (
  66. select 'a' as x
  67. union
  68. select 'b' as x
  69. union
  70. select 'c' as x
  71. ) order by x collate mycoll
  72. """
  73. result = con.execute(sql).fetchall()
  74. self.assertEqual(result, [('c',), ('b',), ('a',)],
  75. msg='the expected order was not returned')
  76. con.create_collation("mycoll", None)
  77. with self.assertRaises(sqlite.OperationalError) as cm:
  78. result = con.execute(sql).fetchall()
  79. self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
  80. def CheckCollationReturnsLargeInteger(self):
  81. def mycoll(x, y):
  82. # reverse order
  83. return -((x > y) - (x < y)) * 2**32
  84. con = sqlite.connect(":memory:")
  85. con.create_collation("mycoll", mycoll)
  86. sql = """
  87. select x from (
  88. select 'a' as x
  89. union
  90. select 'b' as x
  91. union
  92. select 'c' as x
  93. ) order by x collate mycoll
  94. """
  95. result = con.execute(sql).fetchall()
  96. self.assertEqual(result, [('c',), ('b',), ('a',)],
  97. msg="the expected order was not returned")
  98. def CheckCollationRegisterTwice(self):
  99. """
  100. Register two different collation functions under the same name.
  101. Verify that the last one is actually used.
  102. """
  103. con = sqlite.connect(":memory:")
  104. con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
  105. con.create_collation("mycoll", lambda x, y: -((x > y) - (x < y)))
  106. result = con.execute("""
  107. select x from (select 'a' as x union select 'b' as x) order by x collate mycoll
  108. """).fetchall()
  109. self.assertEqual(result[0][0], 'b')
  110. self.assertEqual(result[1][0], 'a')
  111. def CheckDeregisterCollation(self):
  112. """
  113. Register a collation, then deregister it. Make sure an error is raised if we try
  114. to use it.
  115. """
  116. con = sqlite.connect(":memory:")
  117. con.create_collation("mycoll", lambda x, y: (x > y) - (x < y))
  118. con.create_collation("mycoll", None)
  119. with self.assertRaises(sqlite.OperationalError) as cm:
  120. con.execute("select 'a' as x union select 'b' as x order by x collate mycoll")
  121. self.assertEqual(str(cm.exception), 'no such collation sequence: mycoll')
  122. class ProgressTests(unittest.TestCase):
  123. def CheckProgressHandlerUsed(self):
  124. """
  125. Test that the progress handler is invoked once it is set.
  126. """
  127. con = sqlite.connect(":memory:")
  128. progress_calls = []
  129. def progress():
  130. progress_calls.append(None)
  131. return 0
  132. con.set_progress_handler(progress, 1)
  133. con.execute("""
  134. create table foo(a, b)
  135. """)
  136. self.assertTrue(progress_calls)
  137. def CheckOpcodeCount(self):
  138. """
  139. Test that the opcode argument is respected.
  140. """
  141. con = sqlite.connect(":memory:")
  142. progress_calls = []
  143. def progress():
  144. progress_calls.append(None)
  145. return 0
  146. con.set_progress_handler(progress, 1)
  147. curs = con.cursor()
  148. curs.execute("""
  149. create table foo (a, b)
  150. """)
  151. first_count = len(progress_calls)
  152. progress_calls = []
  153. con.set_progress_handler(progress, 2)
  154. curs.execute("""
  155. create table bar (a, b)
  156. """)
  157. second_count = len(progress_calls)
  158. self.assertGreaterEqual(first_count, second_count)
  159. def CheckCancelOperation(self):
  160. """
  161. Test that returning a non-zero value stops the operation in progress.
  162. """
  163. con = sqlite.connect(":memory:")
  164. def progress():
  165. return 1
  166. con.set_progress_handler(progress, 1)
  167. curs = con.cursor()
  168. self.assertRaises(
  169. sqlite.OperationalError,
  170. curs.execute,
  171. "create table bar (a, b)")
  172. def CheckClearHandler(self):
  173. """
  174. Test that setting the progress handler to None clears the previously set handler.
  175. """
  176. con = sqlite.connect(":memory:")
  177. action = 0
  178. def progress():
  179. nonlocal action
  180. action = 1
  181. return 0
  182. con.set_progress_handler(progress, 1)
  183. con.set_progress_handler(None, 1)
  184. con.execute("select 1 union select 2 union select 3").fetchall()
  185. self.assertEqual(action, 0, "progress handler was not cleared")
  186. class TraceCallbackTests(unittest.TestCase):
  187. def CheckTraceCallbackUsed(self):
  188. """
  189. Test that the trace callback is invoked once it is set.
  190. """
  191. con = sqlite.connect(":memory:")
  192. traced_statements = []
  193. def trace(statement):
  194. traced_statements.append(statement)
  195. con.set_trace_callback(trace)
  196. con.execute("create table foo(a, b)")
  197. self.assertTrue(traced_statements)
  198. self.assertTrue(any("create table foo" in stmt for stmt in traced_statements))
  199. def CheckClearTraceCallback(self):
  200. """
  201. Test that setting the trace callback to None clears the previously set callback.
  202. """
  203. con = sqlite.connect(":memory:")
  204. traced_statements = []
  205. def trace(statement):
  206. traced_statements.append(statement)
  207. con.set_trace_callback(trace)
  208. con.set_trace_callback(None)
  209. con.execute("create table foo(a, b)")
  210. self.assertFalse(traced_statements, "trace callback was not cleared")
  211. def CheckUnicodeContent(self):
  212. """
  213. Test that the statement can contain unicode literals.
  214. """
  215. unicode_value = '\xf6\xe4\xfc\xd6\xc4\xdc\xdf\u20ac'
  216. con = sqlite.connect(":memory:")
  217. traced_statements = []
  218. def trace(statement):
  219. traced_statements.append(statement)
  220. con.set_trace_callback(trace)
  221. con.execute("create table foo(x)")
  222. # Can't execute bound parameters as their values don't appear
  223. # in traced statements before SQLite 3.6.21
  224. # (cf. http://www.sqlite.org/draft/releaselog/3_6_21.html)
  225. con.execute("insert into foo(x) values ('%s')" % unicode_value)
  226. con.commit()
  227. self.assertTrue(any(unicode_value in stmt for stmt in traced_statements),
  228. "Unicode data %s garbled in trace callback: %s"
  229. % (ascii(unicode_value), ', '.join(map(ascii, traced_statements))))
  230. @unittest.skipIf(sqlite.sqlite_version_info < (3, 3, 9), "sqlite3_prepare_v2 is not available")
  231. def CheckTraceCallbackContent(self):
  232. # set_trace_callback() shouldn't produce duplicate content (bpo-26187)
  233. traced_statements = []
  234. def trace(statement):
  235. traced_statements.append(statement)
  236. queries = ["create table foo(x)",
  237. "insert into foo(x) values(1)"]
  238. self.addCleanup(unlink, TESTFN)
  239. con1 = sqlite.connect(TESTFN, isolation_level=None)
  240. con2 = sqlite.connect(TESTFN)
  241. con1.set_trace_callback(trace)
  242. cur = con1.cursor()
  243. cur.execute(queries[0])
  244. con2.execute("create table bar(x)")
  245. cur.execute(queries[1])
  246. # Extract from SQLite 3.7.15 changelog:
  247. # Avoid invoking the sqlite3_trace() callback multiple times when a
  248. # statement is automatically reprepared due to SQLITE_SCHEMA errors.
  249. #
  250. # See bpo-40810
  251. if sqlite.sqlite_version_info < (3, 7, 15):
  252. queries.append(queries[-1])
  253. self.assertEqual(traced_statements, queries)
  254. def suite():
  255. collation_suite = unittest.makeSuite(CollationTests, "Check")
  256. progress_suite = unittest.makeSuite(ProgressTests, "Check")
  257. trace_suite = unittest.makeSuite(TraceCallbackTests, "Check")
  258. return unittest.TestSuite((collation_suite, progress_suite, trace_suite))
  259. def test():
  260. runner = unittest.TextTestRunner()
  261. runner.run(suite())
  262. if __name__ == "__main__":
  263. test()