# TestInline.py
import os
import tempfile

from Cython.Build.Inline import safe_type
from Cython.Shadow import inline
from Cython.TestUtils import CythonTest
  5. try:
  6. import numpy
  7. has_numpy = True
  8. except:
  9. has_numpy = False
  10. test_kwds = dict(force=True, quiet=True)
  11. global_value = 100
  12. class TestInline(CythonTest):
  13. def setUp(self):
  14. CythonTest.setUp(self)
  15. self.test_kwds = dict(test_kwds)
  16. if os.path.isdir('TEST_TMP'):
  17. lib_dir = os.path.join('TEST_TMP','inline')
  18. else:
  19. lib_dir = tempfile.mkdtemp(prefix='cython_inline_')
  20. self.test_kwds['lib_dir'] = lib_dir
  21. def test_simple(self):
  22. self.assertEquals(inline("return 1+2", **self.test_kwds), 3)
  23. def test_types(self):
  24. self.assertEquals(inline("""
  25. cimport cython
  26. return cython.typeof(a), cython.typeof(b)
  27. """, a=1.0, b=[], **self.test_kwds), ('double', 'list object'))
  28. def test_locals(self):
  29. a = 1
  30. b = 2
  31. self.assertEquals(inline("return a+b", **self.test_kwds), 3)
  32. def test_globals(self):
  33. self.assertEquals(inline("return global_value + 1", **self.test_kwds), global_value + 1)
  34. def test_no_return(self):
  35. self.assertEquals(inline("""
  36. a = 1
  37. cdef double b = 2
  38. cdef c = []
  39. """, **self.test_kwds), dict(a=1, b=2.0, c=[]))
  40. def test_def_node(self):
  41. foo = inline("def foo(x): return x * x", **self.test_kwds)['foo']
  42. self.assertEquals(foo(7), 49)
  43. def test_class_ref(self):
  44. class Type(object):
  45. pass
  46. tp = inline("Type")['Type']
  47. self.assertEqual(tp, Type)
  48. def test_pure(self):
  49. import cython as cy
  50. b = inline("""
  51. b = cy.declare(float, a)
  52. c = cy.declare(cy.pointer(cy.float), &b)
  53. return b
  54. """, a=3, **self.test_kwds)
  55. self.assertEquals(type(b), float)
  56. def test_compiler_directives(self):
  57. self.assertEqual(
  58. inline('return sum(x)',
  59. x=[1, 2, 3],
  60. cython_compiler_directives={'boundscheck': False}),
  61. 6
  62. )
  63. def test_lang_version(self):
  64. # GH-3419. Caching for inline code didn't always respect compiler directives.
  65. inline_divcode = "def f(int a, int b): return a/b"
  66. self.assertEqual(
  67. inline(inline_divcode, language_level=2)['f'](5,2),
  68. 2
  69. )
  70. self.assertEqual(
  71. inline(inline_divcode, language_level=3)['f'](5,2),
  72. 2.5
  73. )
  74. if has_numpy:
  75. def test_numpy(self):
  76. import numpy
  77. a = numpy.ndarray((10, 20))
  78. a[0,0] = 10
  79. self.assertEquals(safe_type(a), 'numpy.ndarray[numpy.float64_t, ndim=2]')
  80. self.assertEquals(inline("return a[0,0]", a=a, **self.test_kwds), 10.0)