userfunctions.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548
  1. # pysqlite2/test/userfunctions.py: tests for user-defined functions and
  2. # aggregates.
  3. #
  4. # Copyright (C) 2005-2007 Gerhard Häring <gh@ghaering.de>
  5. #
  6. # This file is part of pysqlite.
  7. #
  8. # This software is provided 'as-is', without any express or implied
  9. # warranty. In no event will the authors be held liable for any damages
  10. # arising from the use of this software.
  11. #
  12. # Permission is granted to anyone to use this software for any purpose,
  13. # including commercial applications, and to alter it and redistribute it
  14. # freely, subject to the following restrictions:
  15. #
  16. # 1. The origin of this software must not be misrepresented; you must not
  17. # claim that you wrote the original software. If you use this software
  18. # in a product, an acknowledgment in the product documentation would be
  19. # appreciated but is not required.
  20. # 2. Altered source versions must be plainly marked as such, and must not be
  21. # misrepresented as being the original software.
  22. # 3. This notice may not be removed or altered from any source distribution.
  23. import unittest
  24. import unittest.mock
  25. import sqlite3 as sqlite
  26. def func_returntext():
  27. return "foo"
  28. def func_returntextwithnull():
  29. return "1\x002"
  30. def func_returnunicode():
  31. return "bar"
  32. def func_returnint():
  33. return 42
  34. def func_returnfloat():
  35. return 3.14
  36. def func_returnnull():
  37. return None
  38. def func_returnblob():
  39. return b"blob"
  40. def func_returnlonglong():
  41. return 1<<31
  42. def func_raiseexception():
  43. 5/0
  44. class AggrNoStep:
  45. def __init__(self):
  46. pass
  47. def finalize(self):
  48. return 1
  49. class AggrNoFinalize:
  50. def __init__(self):
  51. pass
  52. def step(self, x):
  53. pass
  54. class AggrExceptionInInit:
  55. def __init__(self):
  56. 5/0
  57. def step(self, x):
  58. pass
  59. def finalize(self):
  60. pass
  61. class AggrExceptionInStep:
  62. def __init__(self):
  63. pass
  64. def step(self, x):
  65. 5/0
  66. def finalize(self):
  67. return 42
  68. class AggrExceptionInFinalize:
  69. def __init__(self):
  70. pass
  71. def step(self, x):
  72. pass
  73. def finalize(self):
  74. 5/0
  75. class AggrCheckType:
  76. def __init__(self):
  77. self.val = None
  78. def step(self, whichType, val):
  79. theType = {"str": str, "int": int, "float": float, "None": type(None),
  80. "blob": bytes}
  81. self.val = int(theType[whichType] is type(val))
  82. def finalize(self):
  83. return self.val
  84. class AggrCheckTypes:
  85. def __init__(self):
  86. self.val = 0
  87. def step(self, whichType, *vals):
  88. theType = {"str": str, "int": int, "float": float, "None": type(None),
  89. "blob": bytes}
  90. for val in vals:
  91. self.val += int(theType[whichType] is type(val))
  92. def finalize(self):
  93. return self.val
  94. class AggrSum:
  95. def __init__(self):
  96. self.val = 0.0
  97. def step(self, val):
  98. self.val += val
  99. def finalize(self):
  100. return self.val
  101. class AggrText:
  102. def __init__(self):
  103. self.txt = ""
  104. def step(self, txt):
  105. self.txt = self.txt + txt
  106. def finalize(self):
  107. return self.txt
  108. class FunctionTests(unittest.TestCase):
  109. def setUp(self):
  110. self.con = sqlite.connect(":memory:")
  111. self.con.create_function("returntext", 0, func_returntext)
  112. self.con.create_function("returntextwithnull", 0, func_returntextwithnull)
  113. self.con.create_function("returnunicode", 0, func_returnunicode)
  114. self.con.create_function("returnint", 0, func_returnint)
  115. self.con.create_function("returnfloat", 0, func_returnfloat)
  116. self.con.create_function("returnnull", 0, func_returnnull)
  117. self.con.create_function("returnblob", 0, func_returnblob)
  118. self.con.create_function("returnlonglong", 0, func_returnlonglong)
  119. self.con.create_function("returnnan", 0, lambda: float("nan"))
  120. self.con.create_function("returntoolargeint", 0, lambda: 1 << 65)
  121. self.con.create_function("raiseexception", 0, func_raiseexception)
  122. self.con.create_function("isblob", 1, lambda x: isinstance(x, bytes))
  123. self.con.create_function("isnone", 1, lambda x: x is None)
  124. self.con.create_function("spam", -1, lambda *x: len(x))
  125. self.con.execute("create table test(t text)")
  126. def tearDown(self):
  127. self.con.close()
  128. def CheckFuncErrorOnCreate(self):
  129. with self.assertRaises(sqlite.OperationalError):
  130. self.con.create_function("bla", -100, lambda x: 2*x)
  131. def CheckFuncRefCount(self):
  132. def getfunc():
  133. def f():
  134. return 1
  135. return f
  136. f = getfunc()
  137. globals()["foo"] = f
  138. # self.con.create_function("reftest", 0, getfunc())
  139. self.con.create_function("reftest", 0, f)
  140. cur = self.con.cursor()
  141. cur.execute("select reftest()")
  142. def CheckFuncReturnText(self):
  143. cur = self.con.cursor()
  144. cur.execute("select returntext()")
  145. val = cur.fetchone()[0]
  146. self.assertEqual(type(val), str)
  147. self.assertEqual(val, "foo")
  148. def CheckFuncReturnTextWithNullChar(self):
  149. cur = self.con.cursor()
  150. res = cur.execute("select returntextwithnull()").fetchone()[0]
  151. self.assertEqual(type(res), str)
  152. self.assertEqual(res, "1\x002")
  153. def CheckFuncReturnUnicode(self):
  154. cur = self.con.cursor()
  155. cur.execute("select returnunicode()")
  156. val = cur.fetchone()[0]
  157. self.assertEqual(type(val), str)
  158. self.assertEqual(val, "bar")
  159. def CheckFuncReturnInt(self):
  160. cur = self.con.cursor()
  161. cur.execute("select returnint()")
  162. val = cur.fetchone()[0]
  163. self.assertEqual(type(val), int)
  164. self.assertEqual(val, 42)
  165. def CheckFuncReturnFloat(self):
  166. cur = self.con.cursor()
  167. cur.execute("select returnfloat()")
  168. val = cur.fetchone()[0]
  169. self.assertEqual(type(val), float)
  170. if val < 3.139 or val > 3.141:
  171. self.fail("wrong value")
  172. def CheckFuncReturnNull(self):
  173. cur = self.con.cursor()
  174. cur.execute("select returnnull()")
  175. val = cur.fetchone()[0]
  176. self.assertEqual(type(val), type(None))
  177. self.assertEqual(val, None)
  178. def CheckFuncReturnBlob(self):
  179. cur = self.con.cursor()
  180. cur.execute("select returnblob()")
  181. val = cur.fetchone()[0]
  182. self.assertEqual(type(val), bytes)
  183. self.assertEqual(val, b"blob")
  184. def CheckFuncReturnLongLong(self):
  185. cur = self.con.cursor()
  186. cur.execute("select returnlonglong()")
  187. val = cur.fetchone()[0]
  188. self.assertEqual(val, 1<<31)
  189. def CheckFuncReturnNaN(self):
  190. cur = self.con.cursor()
  191. cur.execute("select returnnan()")
  192. self.assertIsNone(cur.fetchone()[0])
  193. def CheckFuncReturnTooLargeInt(self):
  194. cur = self.con.cursor()
  195. with self.assertRaises(sqlite.OperationalError):
  196. self.con.execute("select returntoolargeint()")
  197. def CheckFuncException(self):
  198. cur = self.con.cursor()
  199. with self.assertRaises(sqlite.OperationalError) as cm:
  200. cur.execute("select raiseexception()")
  201. cur.fetchone()
  202. self.assertEqual(str(cm.exception), 'user-defined function raised exception')
  203. def CheckAnyArguments(self):
  204. cur = self.con.cursor()
  205. cur.execute("select spam(?, ?)", (1, 2))
  206. val = cur.fetchone()[0]
  207. self.assertEqual(val, 2)
  208. def CheckEmptyBlob(self):
  209. cur = self.con.execute("select isblob(x'')")
  210. self.assertTrue(cur.fetchone()[0])
  211. def CheckNaNFloat(self):
  212. cur = self.con.execute("select isnone(?)", (float("nan"),))
  213. # SQLite has no concept of nan; it is converted to NULL
  214. self.assertTrue(cur.fetchone()[0])
  215. def CheckTooLargeInt(self):
  216. err = "Python int too large to convert to SQLite INTEGER"
  217. self.assertRaisesRegex(OverflowError, err, self.con.execute,
  218. "select spam(?)", (1 << 65,))
  219. def CheckNonContiguousBlob(self):
  220. self.assertRaisesRegex(ValueError, "could not convert BLOB to buffer",
  221. self.con.execute, "select spam(?)",
  222. (memoryview(b"blob")[::2],))
  223. def CheckParamSurrogates(self):
  224. self.assertRaisesRegex(UnicodeEncodeError, "surrogates not allowed",
  225. self.con.execute, "select spam(?)",
  226. ("\ud803\ude6d",))
  227. def CheckFuncParams(self):
  228. results = []
  229. def append_result(arg):
  230. results.append((arg, type(arg)))
  231. self.con.create_function("test_params", 1, append_result)
  232. dataset = [
  233. (42, int),
  234. (-1, int),
  235. (1234567890123456789, int),
  236. (4611686018427387905, int), # 63-bit int with non-zero low bits
  237. (3.14, float),
  238. (float('inf'), float),
  239. ("text", str),
  240. ("1\x002", str),
  241. ("\u02e2q\u02e1\u2071\u1d57\u1d49", str),
  242. (b"blob", bytes),
  243. (bytearray(range(2)), bytes),
  244. (memoryview(b"blob"), bytes),
  245. (None, type(None)),
  246. ]
  247. for val, _ in dataset:
  248. cur = self.con.execute("select test_params(?)", (val,))
  249. cur.fetchone()
  250. self.assertEqual(dataset, results)
  251. # Regarding deterministic functions:
  252. #
  253. # Between 3.8.3 and 3.15.0, deterministic functions were only used to
  254. # optimize inner loops, so for those versions we can only test if the
  255. # sqlite machinery has factored out a call or not. From 3.15.0 and onward,
  256. # deterministic functions were permitted in WHERE clauses of partial
  257. # indices, which allows testing based on syntax, iso. the query optimizer.
  258. @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
  259. def CheckFuncNonDeterministic(self):
  260. mock = unittest.mock.Mock(return_value=None)
  261. self.con.create_function("nondeterministic", 0, mock, deterministic=False)
  262. if sqlite.sqlite_version_info < (3, 15, 0):
  263. self.con.execute("select nondeterministic() = nondeterministic()")
  264. self.assertEqual(mock.call_count, 2)
  265. else:
  266. with self.assertRaises(sqlite.OperationalError):
  267. self.con.execute("create index t on test(t) where nondeterministic() is not null")
  268. @unittest.skipIf(sqlite.sqlite_version_info < (3, 8, 3), "Requires SQLite 3.8.3 or higher")
  269. def CheckFuncDeterministic(self):
  270. mock = unittest.mock.Mock(return_value=None)
  271. self.con.create_function("deterministic", 0, mock, deterministic=True)
  272. if sqlite.sqlite_version_info < (3, 15, 0):
  273. self.con.execute("select deterministic() = deterministic()")
  274. self.assertEqual(mock.call_count, 1)
  275. else:
  276. try:
  277. self.con.execute("create index t on test(t) where deterministic() is not null")
  278. except sqlite.OperationalError:
  279. self.fail("Unexpected failure while creating partial index")
  280. @unittest.skipIf(sqlite.sqlite_version_info >= (3, 8, 3), "SQLite < 3.8.3 needed")
  281. def CheckFuncDeterministicNotSupported(self):
  282. with self.assertRaises(sqlite.NotSupportedError):
  283. self.con.create_function("deterministic", 0, int, deterministic=True)
  284. def CheckFuncDeterministicKeywordOnly(self):
  285. with self.assertRaises(TypeError):
  286. self.con.create_function("deterministic", 0, int, True)
  287. class AggregateTests(unittest.TestCase):
  288. def setUp(self):
  289. self.con = sqlite.connect(":memory:")
  290. cur = self.con.cursor()
  291. cur.execute("""
  292. create table test(
  293. t text,
  294. i integer,
  295. f float,
  296. n,
  297. b blob
  298. )
  299. """)
  300. cur.execute("insert into test(t, i, f, n, b) values (?, ?, ?, ?, ?)",
  301. ("foo", 5, 3.14, None, memoryview(b"blob"),))
  302. self.con.create_aggregate("nostep", 1, AggrNoStep)
  303. self.con.create_aggregate("nofinalize", 1, AggrNoFinalize)
  304. self.con.create_aggregate("excInit", 1, AggrExceptionInInit)
  305. self.con.create_aggregate("excStep", 1, AggrExceptionInStep)
  306. self.con.create_aggregate("excFinalize", 1, AggrExceptionInFinalize)
  307. self.con.create_aggregate("checkType", 2, AggrCheckType)
  308. self.con.create_aggregate("checkTypes", -1, AggrCheckTypes)
  309. self.con.create_aggregate("mysum", 1, AggrSum)
  310. self.con.create_aggregate("aggtxt", 1, AggrText)
  311. def tearDown(self):
  312. #self.cur.close()
  313. #self.con.close()
  314. pass
  315. def CheckAggrErrorOnCreate(self):
  316. with self.assertRaises(sqlite.OperationalError):
  317. self.con.create_function("bla", -100, AggrSum)
  318. def CheckAggrNoStep(self):
  319. cur = self.con.cursor()
  320. with self.assertRaises(AttributeError) as cm:
  321. cur.execute("select nostep(t) from test")
  322. self.assertEqual(str(cm.exception), "'AggrNoStep' object has no attribute 'step'")
  323. def CheckAggrNoFinalize(self):
  324. cur = self.con.cursor()
  325. with self.assertRaises(sqlite.OperationalError) as cm:
  326. cur.execute("select nofinalize(t) from test")
  327. val = cur.fetchone()[0]
  328. self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
  329. def CheckAggrExceptionInInit(self):
  330. cur = self.con.cursor()
  331. with self.assertRaises(sqlite.OperationalError) as cm:
  332. cur.execute("select excInit(t) from test")
  333. val = cur.fetchone()[0]
  334. self.assertEqual(str(cm.exception), "user-defined aggregate's '__init__' method raised error")
  335. def CheckAggrExceptionInStep(self):
  336. cur = self.con.cursor()
  337. with self.assertRaises(sqlite.OperationalError) as cm:
  338. cur.execute("select excStep(t) from test")
  339. val = cur.fetchone()[0]
  340. self.assertEqual(str(cm.exception), "user-defined aggregate's 'step' method raised error")
  341. def CheckAggrExceptionInFinalize(self):
  342. cur = self.con.cursor()
  343. with self.assertRaises(sqlite.OperationalError) as cm:
  344. cur.execute("select excFinalize(t) from test")
  345. val = cur.fetchone()[0]
  346. self.assertEqual(str(cm.exception), "user-defined aggregate's 'finalize' method raised error")
  347. def CheckAggrCheckParamStr(self):
  348. cur = self.con.cursor()
  349. cur.execute("select checkTypes('str', ?, ?)", ("foo", str()))
  350. val = cur.fetchone()[0]
  351. self.assertEqual(val, 2)
  352. def CheckAggrCheckParamInt(self):
  353. cur = self.con.cursor()
  354. cur.execute("select checkType('int', ?)", (42,))
  355. val = cur.fetchone()[0]
  356. self.assertEqual(val, 1)
  357. def CheckAggrCheckParamsInt(self):
  358. cur = self.con.cursor()
  359. cur.execute("select checkTypes('int', ?, ?)", (42, 24))
  360. val = cur.fetchone()[0]
  361. self.assertEqual(val, 2)
  362. def CheckAggrCheckParamFloat(self):
  363. cur = self.con.cursor()
  364. cur.execute("select checkType('float', ?)", (3.14,))
  365. val = cur.fetchone()[0]
  366. self.assertEqual(val, 1)
  367. def CheckAggrCheckParamNone(self):
  368. cur = self.con.cursor()
  369. cur.execute("select checkType('None', ?)", (None,))
  370. val = cur.fetchone()[0]
  371. self.assertEqual(val, 1)
  372. def CheckAggrCheckParamBlob(self):
  373. cur = self.con.cursor()
  374. cur.execute("select checkType('blob', ?)", (memoryview(b"blob"),))
  375. val = cur.fetchone()[0]
  376. self.assertEqual(val, 1)
  377. def CheckAggrCheckAggrSum(self):
  378. cur = self.con.cursor()
  379. cur.execute("delete from test")
  380. cur.executemany("insert into test(i) values (?)", [(10,), (20,), (30,)])
  381. cur.execute("select mysum(i) from test")
  382. val = cur.fetchone()[0]
  383. self.assertEqual(val, 60)
  384. def CheckAggrText(self):
  385. cur = self.con.cursor()
  386. for txt in ["foo", "1\x002"]:
  387. with self.subTest(txt=txt):
  388. cur.execute("select aggtxt(?) from test", (txt,))
  389. val = cur.fetchone()[0]
  390. self.assertEqual(val, txt)
  391. class AuthorizerTests(unittest.TestCase):
  392. @staticmethod
  393. def authorizer_cb(action, arg1, arg2, dbname, source):
  394. if action != sqlite.SQLITE_SELECT:
  395. return sqlite.SQLITE_DENY
  396. if arg2 == 'c2' or arg1 == 't2':
  397. return sqlite.SQLITE_DENY
  398. return sqlite.SQLITE_OK
  399. def setUp(self):
  400. self.con = sqlite.connect(":memory:")
  401. self.con.executescript("""
  402. create table t1 (c1, c2);
  403. create table t2 (c1, c2);
  404. insert into t1 (c1, c2) values (1, 2);
  405. insert into t2 (c1, c2) values (4, 5);
  406. """)
  407. # For our security test:
  408. self.con.execute("select c2 from t2")
  409. self.con.set_authorizer(self.authorizer_cb)
  410. def tearDown(self):
  411. pass
  412. def test_table_access(self):
  413. with self.assertRaises(sqlite.DatabaseError) as cm:
  414. self.con.execute("select * from t2")
  415. self.assertIn('prohibited', str(cm.exception))
  416. def test_column_access(self):
  417. with self.assertRaises(sqlite.DatabaseError) as cm:
  418. self.con.execute("select c2 from t1")
  419. self.assertIn('prohibited', str(cm.exception))
  420. class AuthorizerRaiseExceptionTests(AuthorizerTests):
  421. @staticmethod
  422. def authorizer_cb(action, arg1, arg2, dbname, source):
  423. if action != sqlite.SQLITE_SELECT:
  424. raise ValueError
  425. if arg2 == 'c2' or arg1 == 't2':
  426. raise ValueError
  427. return sqlite.SQLITE_OK
  428. class AuthorizerIllegalTypeTests(AuthorizerTests):
  429. @staticmethod
  430. def authorizer_cb(action, arg1, arg2, dbname, source):
  431. if action != sqlite.SQLITE_SELECT:
  432. return 0.0
  433. if arg2 == 'c2' or arg1 == 't2':
  434. return 0.0
  435. return sqlite.SQLITE_OK
  436. class AuthorizerLargeIntegerTests(AuthorizerTests):
  437. @staticmethod
  438. def authorizer_cb(action, arg1, arg2, dbname, source):
  439. if action != sqlite.SQLITE_SELECT:
  440. return 2**32
  441. if arg2 == 'c2' or arg1 == 't2':
  442. return 2**32
  443. return sqlite.SQLITE_OK
  444. def suite():
  445. function_suite = unittest.makeSuite(FunctionTests, "Check")
  446. aggregate_suite = unittest.makeSuite(AggregateTests, "Check")
  447. authorizer_suite = unittest.makeSuite(AuthorizerTests)
  448. return unittest.TestSuite((
  449. function_suite,
  450. aggregate_suite,
  451. authorizer_suite,
  452. unittest.makeSuite(AuthorizerRaiseExceptionTests),
  453. unittest.makeSuite(AuthorizerIllegalTypeTests),
  454. unittest.makeSuite(AuthorizerLargeIntegerTests),
  455. ))
  456. def test():
  457. runner = unittest.TextTestRunner()
  458. runner.run(suite())
  459. if __name__ == "__main__":
  460. test()