Pipeline.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369
  1. from __future__ import absolute_import
  2. import itertools
  3. from time import time
  4. from . import Errors
  5. from . import DebugFlags
  6. from . import Options
  7. from .Errors import CompileError, InternalError, AbortError
  8. from . import Naming
  9. #
  10. # Really small pipeline stages
  11. #
  12. def dumptree(t):
  13. # For quick debugging in pipelines
  14. print(t.dump())
  15. return t
  16. def abort_on_errors(node):
  17. # Stop the pipeline if there are any errors.
  18. if Errors.num_errors != 0:
  19. raise AbortError("pipeline break")
  20. return node
  21. def parse_stage_factory(context):
  22. def parse(compsrc):
  23. source_desc = compsrc.source_desc
  24. full_module_name = compsrc.full_module_name
  25. initial_pos = (source_desc, 1, 0)
  26. saved_cimport_from_pyx, Options.cimport_from_pyx = Options.cimport_from_pyx, False
  27. scope = context.find_module(full_module_name, pos = initial_pos, need_pxd = 0)
  28. Options.cimport_from_pyx = saved_cimport_from_pyx
  29. tree = context.parse(source_desc, scope, pxd = 0, full_module_name = full_module_name)
  30. tree.compilation_source = compsrc
  31. tree.scope = scope
  32. tree.is_pxd = False
  33. return tree
  34. return parse
  35. def parse_pxd_stage_factory(context, scope, module_name):
  36. def parse(source_desc):
  37. tree = context.parse(source_desc, scope, pxd=True,
  38. full_module_name=module_name)
  39. tree.scope = scope
  40. tree.is_pxd = True
  41. return tree
  42. return parse
  43. def generate_pyx_code_stage_factory(options, result):
  44. def generate_pyx_code_stage(module_node):
  45. module_node.process_implementation(options, result)
  46. result.compilation_source = module_node.compilation_source
  47. return result
  48. return generate_pyx_code_stage
  49. def inject_pxd_code_stage_factory(context):
  50. def inject_pxd_code_stage(module_node):
  51. for name, (statlistnode, scope) in context.pxds.items():
  52. module_node.merge_in(statlistnode, scope)
  53. return module_node
  54. return inject_pxd_code_stage
  55. def use_utility_code_definitions(scope, target, seen=None):
  56. if seen is None:
  57. seen = set()
  58. for entry in scope.entries.values():
  59. if entry in seen:
  60. continue
  61. seen.add(entry)
  62. if entry.used and entry.utility_code_definition:
  63. target.use_utility_code(entry.utility_code_definition)
  64. for required_utility in entry.utility_code_definition.requires:
  65. target.use_utility_code(required_utility)
  66. elif entry.as_module:
  67. use_utility_code_definitions(entry.as_module, target, seen)
  68. def sort_utility_codes(utilcodes):
  69. ranks = {}
  70. def get_rank(utilcode):
  71. if utilcode not in ranks:
  72. ranks[utilcode] = 0 # prevent infinite recursion on circular dependencies
  73. original_order = len(ranks)
  74. ranks[utilcode] = 1 + min([get_rank(dep) for dep in utilcode.requires or ()] or [-1]) + original_order * 1e-8
  75. return ranks[utilcode]
  76. for utilcode in utilcodes:
  77. get_rank(utilcode)
  78. return [utilcode for utilcode, _ in sorted(ranks.items(), key=lambda kv: kv[1])]
  79. def normalize_deps(utilcodes):
  80. deps = {}
  81. for utilcode in utilcodes:
  82. deps[utilcode] = utilcode
  83. def unify_dep(dep):
  84. if dep in deps:
  85. return deps[dep]
  86. else:
  87. deps[dep] = dep
  88. return dep
  89. for utilcode in utilcodes:
  90. utilcode.requires = [unify_dep(dep) for dep in utilcode.requires or ()]
  91. def inject_utility_code_stage_factory(context):
  92. def inject_utility_code_stage(module_node):
  93. module_node.prepare_utility_code()
  94. use_utility_code_definitions(context.cython_scope, module_node.scope)
  95. module_node.scope.utility_code_list = sort_utility_codes(module_node.scope.utility_code_list)
  96. normalize_deps(module_node.scope.utility_code_list)
  97. added = []
  98. # Note: the list might be extended inside the loop (if some utility code
  99. # pulls in other utility code, explicitly or implicitly)
  100. for utilcode in module_node.scope.utility_code_list:
  101. if utilcode in added:
  102. continue
  103. added.append(utilcode)
  104. if utilcode.requires:
  105. for dep in utilcode.requires:
  106. if dep not in added and dep not in module_node.scope.utility_code_list:
  107. module_node.scope.utility_code_list.append(dep)
  108. tree = utilcode.get_tree(cython_scope=context.cython_scope)
  109. if tree:
  110. module_node.merge_in(tree.body, tree.scope, merge_scope=True)
  111. return module_node
  112. return inject_utility_code_stage
  113. #
  114. # Pipeline factories
  115. #
  116. def create_pipeline(context, mode, exclude_classes=()):
  117. assert mode in ('pyx', 'py', 'pxd')
  118. from .Visitor import PrintTree
  119. from .ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
  120. from .ParseTreeTransforms import ForwardDeclareTypes, InjectGilHandling, AnalyseDeclarationsTransform
  121. from .ParseTreeTransforms import AnalyseExpressionsTransform, FindInvalidUseOfFusedTypes
  122. from .ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
  123. from .ParseTreeTransforms import TrackNumpyAttributes, InterpretCompilerDirectives, TransformBuiltinMethods
  124. from .ParseTreeTransforms import ExpandInplaceOperators, ParallelRangeTransform
  125. from .ParseTreeTransforms import CalculateQualifiedNamesTransform
  126. from .TypeInference import MarkParallelAssignments, MarkOverflowingArithmetic
  127. from .ParseTreeTransforms import AdjustDefByDirectives, AlignFunctionDefinitions
  128. from .ParseTreeTransforms import RemoveUnreachableCode, GilCheck
  129. from .FlowControl import ControlFlowAnalysis
  130. from .AnalysedTreeTransforms import AutoTestDictTransform
  131. from .AutoDocTransforms import EmbedSignature
  132. from .Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
  133. from .Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
  134. from .Optimize import InlineDefNodeCalls
  135. from .Optimize import ConstantFolding, FinalOptimizePhase
  136. from .Optimize import DropRefcountingTransform
  137. from .Optimize import ConsolidateOverflowCheck
  138. from .Buffer import IntroduceBufferAuxiliaryVars
  139. from .ModuleNode import check_c_declarations, check_c_declarations_pxd
  140. if mode == 'pxd':
  141. _check_c_declarations = check_c_declarations_pxd
  142. _specific_post_parse = PxdPostParse(context)
  143. else:
  144. _check_c_declarations = check_c_declarations
  145. _specific_post_parse = None
  146. if mode == 'py':
  147. _align_function_definitions = AlignFunctionDefinitions(context)
  148. else:
  149. _align_function_definitions = None
  150. # NOTE: This is the "common" parts of the pipeline, which is also
  151. # code in pxd files. So it will be run multiple times in a
  152. # compilation stage.
  153. stages = [
  154. NormalizeTree(context),
  155. PostParse(context),
  156. _specific_post_parse,
  157. TrackNumpyAttributes(),
  158. InterpretCompilerDirectives(context, context.compiler_directives),
  159. ParallelRangeTransform(context),
  160. AdjustDefByDirectives(context),
  161. WithTransform(context),
  162. MarkClosureVisitor(context),
  163. _align_function_definitions,
  164. RemoveUnreachableCode(context),
  165. ConstantFolding(),
  166. FlattenInListTransform(),
  167. DecoratorTransform(context),
  168. ForwardDeclareTypes(context),
  169. InjectGilHandling(),
  170. AnalyseDeclarationsTransform(context),
  171. AutoTestDictTransform(context),
  172. EmbedSignature(context),
  173. EarlyReplaceBuiltinCalls(context), ## Necessary?
  174. TransformBuiltinMethods(context),
  175. MarkParallelAssignments(context),
  176. ControlFlowAnalysis(context),
  177. RemoveUnreachableCode(context),
  178. # MarkParallelAssignments(context),
  179. MarkOverflowingArithmetic(context),
  180. IntroduceBufferAuxiliaryVars(context),
  181. _check_c_declarations,
  182. InlineDefNodeCalls(context),
  183. AnalyseExpressionsTransform(context),
  184. FindInvalidUseOfFusedTypes(context),
  185. ExpandInplaceOperators(context),
  186. IterationTransform(context),
  187. SwitchTransform(context),
  188. OptimizeBuiltinCalls(context), ## Necessary?
  189. CreateClosureClasses(context), ## After all lookups and type inference
  190. CalculateQualifiedNamesTransform(context),
  191. ConsolidateOverflowCheck(context),
  192. DropRefcountingTransform(),
  193. FinalOptimizePhase(context),
  194. GilCheck(),
  195. ]
  196. filtered_stages = []
  197. for s in stages:
  198. if s.__class__ not in exclude_classes:
  199. filtered_stages.append(s)
  200. return filtered_stages
  201. def create_pyx_pipeline(context, options, result, py=False, exclude_classes=()):
  202. if py:
  203. mode = 'py'
  204. else:
  205. mode = 'pyx'
  206. test_support = []
  207. if options.evaluate_tree_assertions:
  208. from ..TestUtils import TreeAssertVisitor
  209. test_support.append(TreeAssertVisitor())
  210. if options.gdb_debug:
  211. from ..Debugger import DebugWriter # requires Py2.5+
  212. from .ParseTreeTransforms import DebugTransform
  213. context.gdb_debug_outputwriter = DebugWriter.CythonDebugWriter(
  214. options.output_dir)
  215. debug_transform = [DebugTransform(context, options, result)]
  216. else:
  217. debug_transform = []
  218. return list(itertools.chain(
  219. [parse_stage_factory(context)],
  220. create_pipeline(context, mode, exclude_classes=exclude_classes),
  221. test_support,
  222. [inject_pxd_code_stage_factory(context),
  223. inject_utility_code_stage_factory(context),
  224. abort_on_errors],
  225. debug_transform,
  226. [generate_pyx_code_stage_factory(options, result)]))
  227. def create_pxd_pipeline(context, scope, module_name):
  228. from .CodeGeneration import ExtractPxdCode
  229. # The pxd pipeline ends up with a CCodeWriter containing the
  230. # code of the pxd, as well as a pxd scope.
  231. return [
  232. parse_pxd_stage_factory(context, scope, module_name)
  233. ] + create_pipeline(context, 'pxd') + [
  234. ExtractPxdCode()
  235. ]
  236. def create_py_pipeline(context, options, result):
  237. return create_pyx_pipeline(context, options, result, py=True)
  238. def create_pyx_as_pxd_pipeline(context, result):
  239. from .ParseTreeTransforms import AlignFunctionDefinitions, \
  240. MarkClosureVisitor, WithTransform, AnalyseDeclarationsTransform
  241. from .Optimize import ConstantFolding, FlattenInListTransform
  242. from .Nodes import StatListNode
  243. pipeline = []
  244. pyx_pipeline = create_pyx_pipeline(context, context.options, result,
  245. exclude_classes=[
  246. AlignFunctionDefinitions,
  247. MarkClosureVisitor,
  248. ConstantFolding,
  249. FlattenInListTransform,
  250. WithTransform
  251. ])
  252. for stage in pyx_pipeline:
  253. pipeline.append(stage)
  254. if isinstance(stage, AnalyseDeclarationsTransform):
  255. # This is the last stage we need.
  256. break
  257. def fake_pxd(root):
  258. for entry in root.scope.entries.values():
  259. if not entry.in_cinclude:
  260. entry.defined_in_pxd = 1
  261. if entry.name == entry.cname and entry.visibility != 'extern':
  262. # Always mangle non-extern cimported entries.
  263. entry.cname = entry.scope.mangle(Naming.func_prefix, entry.name)
  264. return StatListNode(root.pos, stats=[]), root.scope
  265. pipeline.append(fake_pxd)
  266. return pipeline
  267. def insert_into_pipeline(pipeline, transform, before=None, after=None):
  268. """
  269. Insert a new transform into the pipeline after or before an instance of
  270. the given class. e.g.
  271. pipeline = insert_into_pipeline(pipeline, transform,
  272. after=AnalyseDeclarationsTransform)
  273. """
  274. assert before or after
  275. cls = before or after
  276. for i, t in enumerate(pipeline):
  277. if isinstance(t, cls):
  278. break
  279. if after:
  280. i += 1
  281. return pipeline[:i] + [transform] + pipeline[i:]
  282. #
  283. # Running a pipeline
  284. #
  285. _pipeline_entry_points = {}
  286. def run_pipeline(pipeline, source, printtree=True):
  287. from .Visitor import PrintTree
  288. exec_ns = globals().copy() if DebugFlags.debug_verbose_pipeline else None
  289. def run(phase, data):
  290. return phase(data)
  291. error = None
  292. data = source
  293. try:
  294. try:
  295. for phase in pipeline:
  296. if phase is not None:
  297. if not printtree and isinstance(phase, PrintTree):
  298. continue
  299. if DebugFlags.debug_verbose_pipeline:
  300. t = time()
  301. print("Entering pipeline phase %r" % phase)
  302. # create a new wrapper for each step to show the name in profiles
  303. phase_name = getattr(phase, '__name__', type(phase).__name__)
  304. try:
  305. run = _pipeline_entry_points[phase_name]
  306. except KeyError:
  307. exec("def %s(phase, data): return phase(data)" % phase_name, exec_ns)
  308. run = _pipeline_entry_points[phase_name] = exec_ns[phase_name]
  309. data = run(phase, data)
  310. if DebugFlags.debug_verbose_pipeline:
  311. print(" %.3f seconds" % (time() - t))
  312. except CompileError as err:
  313. # err is set
  314. Errors.report_error(err, use_stack=False)
  315. error = err
  316. except InternalError as err:
  317. # Only raise if there was not an earlier error
  318. if Errors.num_errors == 0:
  319. raise
  320. error = err
  321. except AbortError as err:
  322. error = err
  323. return (error, data)