123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373 |
- from __future__ import absolute_import
- import sys, os, re, inspect
- import imp
- try:
- import hashlib
- except ImportError:
- import md5 as hashlib
- from distutils.core import Distribution, Extension
- from distutils.command.build_ext import build_ext
- import Cython
- from ..Compiler.Main import Context, CompilationOptions, default_options
- from ..Compiler.ParseTreeTransforms import (CythonTransform,
- SkipDeclarations, AnalyseDeclarationsTransform, EnvTransform)
- from ..Compiler.TreeFragment import parse_from_strings
- from ..Compiler.StringEncoding import _unicode
- from .Dependencies import strip_string_literals, cythonize, cached_function
- from ..Compiler import Pipeline, Nodes
- from ..Utils import get_cython_cache_dir
- import cython as cython_module
# True when running under Python 3 or newer.
IS_PY3 = sys.version_info[0] >= 3
# Convert user-supplied ASCII strings to unicode. Only Python 2 needs a real
# conversion (byte strings are decoded as ASCII); on Python 3 the argument is
# returned unchanged.
if sys.version_info[0] >= 3:
    to_unicode = lambda x: x
else:
    def to_unicode(s):
        if not isinstance(s, bytes):
            return s
        return s.decode('ascii')
if sys.version_info < (3, 5):
    import imp

    def load_dynamic(name, module_path):
        """Load the compiled extension module at *module_path* as *name*."""
        return imp.load_dynamic(name, module_path)
else:
    import importlib.util as _importlib_util

    def load_dynamic(name, module_path):
        """Load and execute the module file at *module_path* under the name *name*.

        Returns the freshly executed module object. The module is not
        registered in ``sys.modules``.
        NOTE(review): the original kept that registration commented out;
        confirm it is intentionally omitted.

        Raises:
            ImportError: if no import loader recognises *module_path*.
        """
        spec = _importlib_util.spec_from_file_location(name, module_path)
        # spec_from_file_location returns None when no loader handles the file
        # suffix; fail with a clear ImportError instead of an AttributeError
        # from module_from_spec(None).
        if spec is None or spec.loader is None:
            raise ImportError(
                "No loader found for module %r at %r" % (name, module_path))
        module = _importlib_util.module_from_spec(spec)
        spec.loader.exec_module(module)
        return module
class UnboundSymbols(EnvTransform, SkipDeclarations):
    """Tree visitor collecting names that have no entry in their scope.

    Calling an instance on a tree returns the set of such names.
    """

    def __init__(self):
        # Initialised through CythonTransform directly, with no context.
        CythonTransform.__init__(self, None)
        self.unbound = set()

    def visit_NameNode(self, node):
        name = node.name
        entry = self.current_env().lookup(name)
        if not entry:
            self.unbound.add(name)
        return node

    def __call__(self, node):
        super(UnboundSymbols, self).__call__(node)
        return self.unbound
@cached_function
def unbound_symbols(code, context=None):
    """Return a tuple of names used by *code* that it never defines.

    *code* is parsed as a tree fragment and run through the 'pyx' pipeline up
    to and including AnalyseDeclarationsTransform; later phases are skipped.
    Names that resolve to Python builtins are excluded. Results are memoised
    by ``cached_function``.
    """
    source = to_unicode(code)
    ctx = Context([], default_options) if context is None else context
    tree = parse_from_strings('(tree fragment)', source)
    for phase in Pipeline.create_pipeline(ctx, 'pyx'):
        if phase is not None:
            tree = phase(tree)
            if isinstance(phase, AnalyseDeclarationsTransform):
                # Declarations are analysed; that is all we need.
                break
    try:
        import builtins
    except ImportError:
        import __builtin__ as builtins
    builtin_names = set(dir(builtins))
    return tuple(UnboundSymbols()(tree) - builtin_names)
def unsafe_type(arg, context=None):
    """Return the Cython type name for *arg*, mapping exact ``int`` to ``long``.

    Every other value (including ``bool``) is delegated to ``safe_type``.
    NOTE(review): presumably "unsafe" because a Python int may not fit a C
    long — confirm.
    """
    if type(arg) is not int:
        return safe_type(arg, context)
    return 'long'
def safe_type(arg, context=None):
    """Return a Cython type declaration string for the runtime value *arg*.

    Mapping:
      * ``list``/``tuple``/``dict``/``str`` -> the type's own name
      * ``complex`` -> ``'double complex'``, ``float`` -> ``'double'``,
        ``bool`` -> ``'bint'``
      * numpy ndarray (only if numpy is already imported) -> a typed buffer
        declaration built from the array's dtype name and ndim
      * any other value -> ``'module.Name'`` for the first class in its MRO
        that *context* can find as a type in a cimportable module, or
        ``'object'`` as soon as a builtin base is reached.

    Args:
        arg: the value whose type is being declared.
        context: Cython compilation Context used for module lookup. When None,
            no lookup is possible and non-special values map to ``'object'``.
    """
    py_type = type(arg)
    if py_type in (list, tuple, dict, str):
        return py_type.__name__
    if py_type is complex:
        return 'double complex'
    if py_type is float:
        return 'double'
    if py_type is bool:
        return 'bint'
    if 'numpy' in sys.modules and isinstance(arg, sys.modules['numpy'].ndarray):
        return 'numpy.ndarray[numpy.%s_t, ndim=%s]' % (arg.dtype.name, arg.ndim)
    if context is None:
        # Previously this fell through to context.find_module() and raised
        # AttributeError for any user-defined type.
        return 'object'
    for base_type in py_type.__mro__:
        if base_type.__module__ in ('__builtin__', 'builtins'):
            return 'object'
        module = context.find_module(base_type.__module__, need_pxd=False)
        if module:
            entry = module.lookup(base_type.__name__)
            # lookup() may find nothing; treat that like "not a type".
            if entry is not None and entry.is_type:
                return '%s.%s' % (base_type.__module__, base_type.__name__)
    return 'object'
def _get_build_extension():
    """Create a finalized distutils ``build_ext`` command.

    The command belongs to a fresh ``Distribution`` that has parsed the
    configuration files reported by ``find_config_files()``, so the build
    respects the user's distutils configuration.
    """
    distribution = Distribution()
    distribution.parse_config_files(distribution.find_config_files())
    command = build_ext(distribution)
    command.finalize_options()
    return command
@cached_function
def _create_context(cython_include_dirs):
    """Return a Cython compilation Context searching *cython_include_dirs*.

    Callers pass a tuple; presumably so the argument can serve as the
    ``cached_function`` memo key — confirm.
    """
    include_dirs = list(cython_include_dirs)
    return Context(include_dirs, default_options)
# In-memory cache for cython_inline: maps source strings to their
# unbound-symbol tuples, and (code, arg_sigs, key_hash) tuples to the
# compiled ``__invoke`` functions.
_cython_inline_cache = dict()
# Compilation context shared by cython_inline calls without include dirs.
_cython_inline_default_context = _create_context(('.',))
- def _populate_unbound(kwds, unbound_symbols, locals=None, globals=None):
- for symbol in unbound_symbols:
- if symbol not in kwds:
- if locals is None or globals is None:
- calling_frame = inspect.currentframe().f_back.f_back.f_back
- if locals is None:
- locals = calling_frame.f_locals
- if globals is None:
- globals = calling_frame.f_globals
- if symbol in locals:
- kwds[symbol] = locals[symbol]
- elif symbol in globals:
- kwds[symbol] = globals[symbol]
- else:
- print("Couldn't find %r" % symbol)
def _inline_key(orig_code, arg_sigs, language_level):
    """Return a hex SHA-1 digest identifying one compiled inline variant.

    The digest covers the source, the argument signatures, the running
    interpreter (version and executable), the language level and the Cython
    version, so a change in any of them yields a different module name.
    """
    key_parts = (orig_code, arg_sigs, sys.version_info, sys.executable,
                 language_level, Cython.__version__)
    encoded = _unicode(key_parts).encode('utf-8')
    return hashlib.sha1(encoded).hexdigest()
def cython_inline(code, get_type=unsafe_type,
                  lib_dir=os.path.join(get_cython_cache_dir(), 'inline'),
                  cython_include_dirs=None, cython_compiler_directives=None,
                  force=False, quiet=False, locals=None, globals=None, language_level=None, **kwds):
    """Compile *code* as the body of a Cython function and call it.

    Names used by *code* but not defined in it are taken from *kwds*, then
    from *locals*/*globals*. When those are not given, they default to the
    namespaces of the frame two levels above this function. Each distinct
    (code, argument types, interpreter, language level, Cython version)
    combination is built once into an extension module under *lib_dir*.
    Within a process, the compiled function is also cached in memory.

    Args:
        code: Cython source of the function body.
        get_type: callable ``(value, context) -> str`` returning the Cython
            type declared for each argument. ``None`` declares every argument
            as ``object``.
        lib_dir: directory for the generated ``.pyx`` files and built modules.
        cython_include_dirs: Cython include path (default ``['.']``).
        cython_compiler_directives: directives passed to ``cythonize``.
        force: rebuild even if the compiled module already exists on disk.
        quiet: suppress the parse-failure message and pass ``quiet`` to
            ``cythonize``.
        locals, globals: namespaces used to resolve unbound names.
        language_level: Cython language level; ``'3str'`` unless given here or
            in the directives.
        **kwds: explicit values for names used by *code*.

    Returns:
        Whatever the generated function returns: the value of a ``return`` in
        *code*, or otherwise its ``locals()`` dict.
    """
    if get_type is None:
        # get_type is always called as get_type(value, ctx) below, so the
        # fallback must accept two arguments (a one-arg lambda raised TypeError).
        get_type = lambda value, context: 'object'
    ctx = _create_context(tuple(cython_include_dirs)) if cython_include_dirs else _cython_inline_default_context

    cython_compiler_directives = dict(cython_compiler_directives or {})
    if language_level is None and 'language_level' not in cython_compiler_directives:
        language_level = '3str'
    if language_level is not None:
        cython_compiler_directives['language_level'] = language_level

    # Fast path if this has been called in this session.
    _unbound_symbols = _cython_inline_cache.get(code)
    if _unbound_symbols is not None:
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
        args = sorted(kwds.items())
        arg_sigs = tuple([(get_type(value, ctx), arg) for arg, value in args])
        key_hash = _inline_key(code, arg_sigs, language_level)
        invoke = _cython_inline_cache.get((code, arg_sigs, key_hash))
        if invoke is not None:
            arg_list = [arg[1] for arg in args]
            return invoke(*arg_list)

    orig_code = code
    code = to_unicode(code)
    code, literals = strip_string_literals(code)
    code = strip_common_indent(code)
    if locals is None:
        locals = inspect.currentframe().f_back.f_back.f_locals
    if globals is None:
        globals = inspect.currentframe().f_back.f_back.f_globals
    try:
        _cython_inline_cache[orig_code] = _unbound_symbols = unbound_symbols(code)
        _populate_unbound(kwds, _unbound_symbols, locals, globals)
    except AssertionError:
        if not quiet:
            # Parsing from strings not fully supported (e.g. cimports).
            print("Could not parse code as a string (to extract unbound symbols).")

    # Arguments bound to the cython module itself become cimports instead.
    cimports = []
    for name, arg in list(kwds.items()):
        if arg is cython_module:
            cimports.append('\ncimport cython as %s' % name)
            del kwds[name]
    arg_names = sorted(kwds)
    arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
    key_hash = _inline_key(orig_code, arg_sigs, language_level)
    module_name = "_cython_inline_" + key_hash

    if module_name in sys.modules:
        module = sys.modules[module_name]
    else:
        build_extension = None
        if cython_inline.so_ext is None:
            # Figure out and cache current extension suffix
            build_extension = _get_build_extension()
            cython_inline.so_ext = build_extension.get_ext_filename('')
        module_path = os.path.join(lib_dir, module_name + cython_inline.so_ext)

        if not os.path.exists(lib_dir):
            os.makedirs(lib_dir)
        if force or not os.path.isfile(module_path):
            cflags = []
            c_include_dirs = []
            qualified = re.compile(r'([.\w]+)[.]')
            for type_name, _ in arg_sigs:
                m = qualified.match(type_name)
                if m:
                    # A qualified type name ("mod.Type") needs its module cimported.
                    cimports.append('\ncimport %s' % m.groups()[0])
                    # one special case
                    if m.groups()[0] == 'numpy':
                        import numpy
                        c_include_dirs.append(numpy.get_include())
            module_body, func_body = extract_func_code(code)
            params = ', '.join(['%s %s' % a for a in arg_sigs])
            module_code = """
%(module_body)s
%(cimports)s
def __invoke(%(params)s):
%(func_body)s
    return locals()
""" % {'cimports': '\n'.join(cimports),
       'module_body': module_body,
       'params': params,
       'func_body': func_body }
            # Put back the string literals removed by strip_string_literals.
            for key, value in literals.items():
                module_code = module_code.replace(key, value)
            pyx_file = os.path.join(lib_dir, module_name + '.pyx')
            # Write with an explicit UTF-8 encoding rather than the locale
            # default, which may not represent non-ASCII source text; the
            # context manager guarantees the file is closed.
            import io
            with io.open(pyx_file, 'w', encoding='utf-8') as fh:
                fh.write(module_code)
            extension = Extension(
                name = module_name,
                sources = [pyx_file],
                include_dirs = c_include_dirs,
                extra_compile_args = cflags)
            if build_extension is None:
                build_extension = _get_build_extension()
            build_extension.extensions = cythonize(
                [extension],
                include_path=cython_include_dirs or ['.'],
                compiler_directives=cython_compiler_directives,
                quiet=quiet)
            build_extension.build_temp = os.path.dirname(pyx_file)
            build_extension.build_lib = lib_dir
            build_extension.run()

        module = load_dynamic(module_name, module_path)

    _cython_inline_cache[orig_code, arg_sigs, key_hash] = module.__invoke
    arg_list = [kwds[arg] for arg in arg_names]
    return module.__invoke(*arg_list)

# Cached suffix used by cython_inline above. None should get
# overridden with actual value upon the first cython_inline invocation
cython_inline.so_ext = None
_find_non_space = re.compile('[^ ]').search


def strip_common_indent(code):
    """Remove the smallest leading indentation shared by the code lines of *code*.

    Only space characters count as indentation. Blank lines and comment lines
    (first non-space character is '#') are ignored when computing the common
    indent and are returned unchanged. Lines are re-joined with '\\n', so
    line endings are normalised and a trailing newline is dropped.
    """
    lines = code.splitlines()
    min_indent = None
    for line in lines:
        match = _find_non_space(line)
        if not match:
            continue  # blank
        indent = match.start()
        if line[indent] == '#':
            continue  # comment
        if min_indent is None or indent < min_indent:
            min_indent = indent
    for ix, line in enumerate(lines):
        match = _find_non_space(line)
        # Check this line's own first non-space character. The old code reused
        # a stale `indent` from the loop above and could chop into comments.
        if not match or line[match.start()] == '#':
            continue
        lines[ix] = line[min_indent:]
    return '\n'.join(lines)
module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimport)|(from .+ import +[*]))')


def extract_func_code(code):
    """Split *code* into (module-level source, indented function body).

    A top-level line matching ``module_statement`` (cdef extern/class,
    cimport, from-cimport, star import) starts a module-level chunk; any
    other top-level line starts a function-body chunk. Indented lines stay
    with the preceding chunk. Tabs become single spaces first. The function
    part is indented one level so it can follow a ``def`` line.
    """
    module_lines = []
    function_lines = []
    target = function_lines
    for line in code.replace('\t', ' ').split('\n'):
        if not line.startswith(' '):
            target = module_lines if module_statement.match(line) else function_lines
        target.append(line)
    indent = '    '
    return '\n'.join(module_lines), indent + ('\n' + indent).join(function_lines)
try:
    from inspect import getcallargs
except ImportError:
    def getcallargs(func, *arg_values, **kwd_values):
        """Fallback for ``inspect.getcallargs`` on Pythons that lack it.

        Returns a dict mapping each parameter name of *func* (including the
        ``*args``/``**kwargs`` names, if any) to the value it would receive.
        Raises TypeError for duplicate, unexpected or missing arguments.
        """
        bound = {}
        args, varargs, kwds, defaults = inspect.getargspec(func)
        if varargs is not None:
            bound[varargs] = arg_values[len(args):]
        bound.update(zip(args, arg_values))
        for name in [n for n in kwd_values if n in args]:
            if name in bound:
                raise TypeError("Duplicate argument %s" % name)
            bound[name] = kwd_values.pop(name)
        if kwds is not None:
            bound[kwds] = kwd_values
        elif kwd_values:
            raise TypeError("Unexpected keyword arguments: %s" % list(kwd_values))
        defaults = defaults or ()
        first_default = len(args) - len(defaults)
        for ix, name in enumerate(args):
            if name in bound:
                continue
            if ix < first_default:
                raise TypeError("Missing argument: %s" % name)
            bound[name] = defaults[ix - first_default]
        return bound
def get_body(source):
    """Return the body of the function whose source text is *source*.

    Everything after the first ':' is returned. For a lambda (source starting
    with ``lambda`` after optional leading whitespace), that text is an
    expression, so it is prefixed with ``return`` to form a statement body.

    Raises:
        ValueError: if *source* contains no ':'.
    """
    ix = source.index(':')
    # 'lambda' has six characters; the old `source[:5] == 'lambda'` test
    # compared a 5-character slice and therefore could never match.
    if source.lstrip().startswith('lambda'):
        return "return %s" % source[ix+1:]
    return source[ix+1:]
# Still very rough; in particular it would be nice if compiled functions
# could call each other quickly.
class RuntimeCompiledFunction(object):
    """Wrap a Python function so each call runs its body through cython_inline.

    The wrapped function's own globals are used as both the locals and the
    globals namespace for resolving unbound names; the call's arguments are
    bound by name and passed as keyword values.
    """

    def __init__(self, f):
        self._f = f
        source = inspect.getsource(f)
        self._body = get_body(source)

    def __call__(self, *args, **kwds):
        call_args = getcallargs(self._f, *args, **kwds)
        f_globals = self._f.__globals__ if IS_PY3 else self._f.func_globals
        return cython_inline(self._body, locals=f_globals, globals=f_globals, **call_args)
|