Inline.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373
  1. from __future__ import absolute_import
  2. import sys, os, re, inspect
  3. import imp
  4. try:
  5. import hashlib
  6. except ImportError:
  7. import md5 as hashlib
  8. from distutils.core import Distribution, Extension
  9. from distutils.command.build_ext import build_ext
  10. import Cython
  11. from ..Compiler.Main import Context, CompilationOptions, default_options
  12. from ..Compiler.ParseTreeTransforms import (CythonTransform,
  13. SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform)
  14. from ..Compiler.TreeFragment import parse_from_strings
  15. from ..Compiler.StringEncoding import _unicode
  16. from .Dependencies import strip_string_literals, cythonize, cached_function
  17. from ..Compiler import Pipeline, Nodes
  18. from ..Utils import get_cython_cache_dir
  19. import cython as cython_module
  20. IS_PY3 = sys.version_info >= (3, 0)
  21. # A utility function to convert user-supplied ASCII strings to unicode.
  22. if sys.version_info[0] < 3:
  23. def to_unicode(s):
  24. if isinstance(s, bytes):
  25. return s.decode('ascii')
  26. else:
  27. return s
  28. else:
  29. to_unicode = lambda x: x
  30. if sys.version_info < (3, 5):
  31. import imp
  32. def load_dynamic(name, module_path):
  33. return imp.load_dynamic(name, module_path)
  34. else:
  35. import importlib.util as _importlib_util
  36. def load_dynamic(name, module_path):
  37. spec = _importlib_util.spec_from_file_location(name, module_path)
  38. module = _importlib_util.module_from_spec(spec)
  39. # sys.modules[name] = module
  40. spec.loader.exec_module(module)
  41. return module
  42. class UnboundSymbols(EnvTransform, SkipDeclarations):
  43. def __init__(self):
  44. CythonTransform.__init__(self, None)
  45. self.unbound = set()
  46. def visit_NameNode(self, node):
  47. if not self.current_env().lookup(node.name):
  48. self.unbound.add(node.name)
  49. return node
  50. def __call__(self, node):
  51. super(UnboundSymbols, self).__call__(node)
  52. return self.unbound
  53. @cached_function
  54. def unbound_symbols(code, context=None):
  55. code = to_unicode(code)
  56. if context is None:
  57. context = Context([], default_options)
  58. from ..Compiler.ParseTreeTransforms import AnalyseDeclarationsTransform
  59. tree = parse_from_strings('(tree fragment)', code)
  60. for phase in Pipeline.create_pipeline(context, 'pyx'):
  61. if phase is None:
  62. continue
  63. tree = phase(tree)
  64. if isinstance(phase, AnalyseDeclarationsTransform):
  65. break
  66. try:
  67. import builtins
  68. except ImportError:
  69. import __builtin__ as builtins
  70. return tuple(UnboundSymbols()(tree) - set(dir(builtins)))
  71. def unsafe_type(arg, context=None):
  72. py_type = type(arg)
  73. if py_type is int:
  74. return 'long'
  75. else:
  76. return safe_type(arg, context)
  77. def safe_type(arg, context=None):
  78. py_type = type(arg)
  79. if py_type in (list, tuple, dict, str):
  80. return py_type.__name__
  81. elif py_type is complex:
  82. return 'double complex'
  83. elif py_type is float:
  84. return 'double'
  85. elif py_type is bool:
  86. return 'bint'
  87. elif 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
  88. return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
  89. else:
  90. for base_type in py_type.__mro__:
  91. if base_type.__module__ in ('__builtin__', 'builtins'):
  92. return 'object'
  93. module = context.find_module(base_type.__module__, need_pxd=False)
  94. if module:
  95. entry = module.lookup(base_type.__name__)
  96. if entry.is_type:
  97. return '%s.%s' % (base_type.__module__, base_type.__name__)
  98. return 'object'
  99. def _get_build_extension():
  100. dist = Distribution()
  101. # Ensure the build respects distutils configuration by parsing
  102. # the configuration files
  103. config_files = dist.find_config_files()
  104. dist.parse_config_files(config_files)
  105. build_extension = build_ext(dist)
  106. build_extension.finalize_options()
  107. return build_extension
  108. @cached_function
  109. def _create_context(cython_include_dirs):
  110. return Context(list(cython_include_dirs), default_options)
  111. _cython_inline_cache = {}
  112. _cython_inline_default_context = _create_context(('.',))
  113. def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
  114. for symbol in unbound_symbols:
  115. if symbol not in kwds:
  116. if locals is None or globals is None:
  117. calling_frame = inspect.currentframe().f_back.f_back.f_back
  118. if locals is None:
  119. locals = calling_frame.f_locals
  120. if globals is None:
  121. globals = calling_frame.f_globals
  122. if symbol in locals:
  123. kwds[symbol] = locals[symbol]
  124. elif symbol in globals:
  125. kwds[symbol] = globals[symbol]
  126. else:
  127. print("Couldn't find %r" % symbol)
  128. def _inline_key(orig_code, arg_sigs, language_level):
  129. key = orig_code, arg_sigs, sys.version_info, sys.executable, language_level, Cython.__version__
  130. return hashlib.sha1(_unicode(key).encode('utf-8')).hexdigest()
  131. def cython_inline(code, get_type=unsafe_type,
  132. lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
  133. cython_include_dirs=None, cython_compiler_directives=None,
  134. force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
  135. if get_type is None:
  136. get_type = lambda x: 'object'
  137. ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context
  138. cython_compiler_directives = dict(cython_compiler_directives or {})
  139. if language_level is None and 'language_level' not in cython_compiler_directives:
  140. language_level = '3str'
  141. if language_level is not None:
  142. cython_compiler_directives['language_level'] = language_level
  143. # Fast path if this has been called in this session.
  144. _unbound_symbols = _cython_inline_cache.get(code)
  145. if _unbound_symbols is not None:
  146. _populate_unbound(kwds, _unbound_symbols, locals, globals)
  147. args = sorted(kwds.items())
  148. arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
  149. key_hash = _inline_key(code, arg_sigs, language_level)
  150. invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
  151. if invoke is not None:
  152. arg_list = [arg[1] for arg in args]
  153. return invoke(*arg_list)
  154. orig_code = code
  155. code = to_unicode(code)
  156. code, literals = strip_string_literals(code)
  157. code = strip_common_indent(code)
  158. if locals is None:
  159. locals = inspect.currentframe().f_back.f_back.f_locals
  160. if globals is None:
  161. globals = inspect.currentframe().f_back.f_back.f_globals
  162. try:
  163. _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
  164. _populate_unbound(kwds, _unbound_symbols, locals, globals)
  165. except AssertionError:
  166. if not quiet:
  167. # Parsing from strings not fully supported (e.g. cimports).
  168. print("Could not parse code as a string (to extract unbound symbols).")
  169. cimports = []
  170. for name, arg in list(kwds.items()):
  171. if arg is cython_module:
  172. cimports.append('\ncimport cython as %s' % name)
  173. del kwds[name]
  174. arg_names = sorted(kwds)
  175. arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
  176. key_hash = _inline_key(orig_code, arg_sigs, language_level)
  177. module_name = "_cython_inline_" + key_hash
  178. if module_name in sys.modules:
  179. module = sys.modules[module_name]
  180. else:
  181. build_extension = None
  182. if cython_inline.so_ext is None:
  183. # Figure out and cache current extension suffix
  184. build_extension = _get_build_extension()
  185. cython_inline.so_ext = build_extension.get_ext_filename('')
  186. module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)
  187. if not os.path.exists(lib_dir):
  188. os.makedirs(lib_dir)
  189. if force or not os.path.isfile(module_path):
  190. cflags = []
  191. c_include_dirs = []
  192. qualified = re.compile(r'([.\w]+)[.]')
  193. for type, _ in arg_sigs:
  194. m = qualified.match(type)
  195. if m:
  196. cimports.append('\ncimport %s' % m.groups()[0])
  197. # one special case
  198. if m.groups()[0] == 'numpy':
  199. import numpy
  200. c_include_dirs.append(numpy.get_include())
  201. # cflags.append('-Wno-unused')
  202. module_body, func_body = extract_func_code(code)
  203. params = ', '.join(['%s %s' % a for a in arg_sigs])
  204. module_code = """
  205. %(module_body)s
  206. %(cimports)s
  207. def __invoke(%(params)s):
  208. %(func_body)s
  209. return locals()
  210. """ % {'cimports': '\n'.join(cimports),
  211. 'module_body': module_body,
  212. 'params': params,
  213. 'func_body': func_body }
  214. for key, value in literals.items():
  215. module_code = module_code.replace(key, value)
  216. pyx_file = os.path.join(lib_dir, module_name + '.pyx')
  217. fh = open(pyx_file, 'w')
  218. try:
  219. fh.write(module_code)
  220. finally:
  221. fh.close()
  222. extension = Extension(
  223. name = module_name,
  224. sources = [pyx_file],
  225. include_dirs = c_include_dirs,
  226. extra_compile_args = cflags)
  227. if build_extension is None:
  228. build_extension = _get_build_extension()
  229. build_extension.extensions = cythonize(
  230. [extension],
  231. include_path=cython_include_dirs or ['.'],
  232. compiler_directives=cython_compiler_directives,
  233. quiet=quiet)
  234. build_extension.build_temp = os.path.dirname(pyx_file)
  235. build_extension.build_lib = lib_dir
  236. build_extension.run()
  237. module = load_dynamic(module_name, module_path)
  238. _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
  239. arg_list = [kwds[arg] for arg in arg_names]
  240. return module.__invoke(*arg_list)
  241. # Cached suffix used by cython_inline above. None should get
  242. # overridden with actual value upon the first cython_inline invocation
  243. cython_inline.so_ext = None
  244. _find_non_space = re.compile('[^ ]').search
  245. def strip_common_indent(code):
  246. min_indent = None
  247. lines = code.splitlines()
  248. for line in lines:
  249. match = _find_non_space(line)
  250. if not match:
  251. continue # blank
  252. indent = match.start()
  253. if line[indent] == '#':
  254. continue # comment
  255. if min_indent is None or min_indent > indent:
  256. min_indent = indent
  257. for ix, line in enumerate(lines):
  258. match = _find_non_space(line)
  259. if not match or not line or line[indent:indent+1] == '#':
  260. continue
  261. lines[ix] = line[min_indent:]
  262. return '\n'.join(lines)
  263. module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')
  264. def extract_func_code(code):
  265. module = []
  266. function = []
  267. current = function
  268. code = code.replace('\t', ' ')
  269. lines = code.split('\n')
  270. for line in lines:
  271. if not line.startswith(' '):
  272. if module_statement.match(line):
  273. current = module
  274. else:
  275. current = function
  276. current.append(line)
  277. return '\n'.join(module), ' ' + '\n '.join(function)
  278. try:
  279. from inspect import getcallargs
  280. except ImportError:
  281. def getcallargs(func, *arg_values, **kwd_values):
  282. all = {}
  283. args, varargs, kwds, defaults = inspect.getargspec(func)
  284. if varargs is not None:
  285. all[varargs] = arg_values[len(args):]
  286. for name, value in zip(args, arg_values):
  287. all[name] = value
  288. for name, value in list(kwd_values.items()):
  289. if name in args:
  290. if name in all:
  291. raise TypeError("Duplicate argument %s" % name)
  292. all[name] = kwd_values.pop(name)
  293. if kwds is not None:
  294. all[kwds] = kwd_values
  295. elif kwd_values:
  296. raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
  297. if defaults is None:
  298. defaults = ()
  299. first_default = len(args) - len(defaults)
  300. for ix, name in enumerate(args):
  301. if name not in all:
  302. if ix >= first_default:
  303. all[name] = defaults[ix - first_default]
  304. else:
  305. raise TypeError("Missing argument: %s" % name)
  306. return all
  307. def get_body(source):
  308. ix = source.index(':')
  309. if source[:5] == 'lambda':
  310. return "return %s" % source[ix+1:]
  311. else:
  312. return source[ix+1:]
  313. # Lots to be done here... It would be especially cool if compiled functions
  314. # could invoke each other quickly.
  315. class RuntimeCompiledFunction(object):
  316. def __init__(self, f):
  317. self._f = f
  318. self._body = get_body(inspect.getsource(f))
  319. def __call__(self, *args, **kwds):
  320. all = getcallargs(self._f, *args, **kwds)
  321. if IS_PY3:
  322. return cython_inline(self._body, locals=self._f.__globals__, globals=self._f.__globals__, **all)
  323. else:
  324. return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)