TreeFragment.py 9.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275
  1. #
  2. # TreeFragments - parsing of strings to trees
  3. #
  4. """
  5. Support for parsing strings into code trees.
  6. """
  7. from __future__ import absolute_import
  8. import re
  9. from io import StringIO
  10. from .Scanning import PyrexScanner, StringSourceDescriptor
  11. from .Symtab import ModuleScope
  12. from . import PyrexTypes
  13. from .Visitor import VisitorTransform
  14. from .Nodes import Node, StatListNode
  15. from .ExprNodes import NameNode
  16. from .StringEncoding import _unicode
  17. from . import Parsing
  18. from . import Main
  19. from . import UtilNodes
  20. class StringParseContext(Main.Context):
  21. def __init__(self, name, include_directories=None, compiler_directives=None, cpp=False):
  22. if include_directories is None:
  23. include_directories = []
  24. if compiler_directives is None:
  25. compiler_directives = {}
  26. # TODO: see if "language_level=3" also works for our internal code here.
  27. Main.Context.__init__(self, include_directories, compiler_directives, cpp=cpp, language_level=2)
  28. self.module_name = name
  29. def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, absolute_fallback=True):
  30. if module_name not in (self.module_name, 'cython'):
  31. raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
  32. return ModuleScope(module_name, parent_module=None, context=self)
  33. def parse_from_strings(name, code, pxds=None, level=None, initial_pos=None,
  34. context=None, allow_struct_enum_decorator=False):
  35. """
  36. Utility method to parse a (unicode) string of code. This is mostly
  37. used for internal Cython compiler purposes (creating code snippets
  38. that transforms should emit, as well as unit testing).
  39. code - a unicode string containing Cython (module-level) code
  40. name - a descriptive name for the code source (to use in error messages etc.)
  41. RETURNS
  42. The tree, i.e. a ModuleNode. The ModuleNode's scope attribute is
  43. set to the scope used when parsing.
  44. """
  45. if context is None:
  46. context = StringParseContext(name)
  47. # Since source files carry an encoding, it makes sense in this context
  48. # to use a unicode string so that code fragments don't have to bother
  49. # with encoding. This means that test code passed in should not have an
  50. # encoding header.
  51. assert isinstance(code, _unicode), "unicode code snippets only please"
  52. encoding = "UTF-8"
  53. module_name = name
  54. if initial_pos is None:
  55. initial_pos = (name, 1, 0)
  56. code_source = StringSourceDescriptor(name, code)
  57. scope = context.find_module(module_name, pos=initial_pos, need_pxd=False)
  58. buf = StringIO(code)
  59. scanner = PyrexScanner(buf, code_source, source_encoding = encoding,
  60. scope = scope, context = context, initial_pos = initial_pos)
  61. ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)
  62. if level is None:
  63. tree = Parsing.p_module(scanner, 0, module_name, ctx=ctx)
  64. tree.scope = scope
  65. tree.is_pxd = False
  66. else:
  67. tree = Parsing.p_code(scanner, level=level, ctx=ctx)
  68. tree.scope = scope
  69. return tree
  70. class TreeCopier(VisitorTransform):
  71. def visit_Node(self, node):
  72. if node is None:
  73. return node
  74. else:
  75. c = node.clone_node()
  76. self.visitchildren(c)
  77. return c
  78. class ApplyPositionAndCopy(TreeCopier):
  79. def __init__(self, pos):
  80. super(ApplyPositionAndCopy, self).__init__()
  81. self.pos = pos
  82. def visit_Node(self, node):
  83. copy = super(ApplyPositionAndCopy, self).visit_Node(node)
  84. copy.pos = self.pos
  85. return copy
  86. class TemplateTransform(VisitorTransform):
  87. """
  88. Makes a copy of a template tree while doing substitutions.
  89. A dictionary "substitutions" should be passed in when calling
  90. the transform; mapping names to replacement nodes. Then replacement
  91. happens like this:
  92. - If an ExprStatNode contains a single NameNode, whose name is
  93. a key in the substitutions dictionary, the ExprStatNode is
  94. replaced with a copy of the tree given in the dictionary.
  95. It is the responsibility of the caller that the replacement
  96. node is a valid statement.
  97. - If a single NameNode is otherwise encountered, it is replaced
  98. if its name is listed in the substitutions dictionary in the
  99. same way. It is the responsibility of the caller to make sure
  100. that the replacement nodes is a valid expression.
  101. Also a list "temps" should be passed. Any names listed will
  102. be transformed into anonymous, temporary names.
  103. Currently supported for tempnames is:
  104. NameNode
  105. (various function and class definition nodes etc. should be added to this)
  106. Each replacement node gets the position of the substituted node
  107. recursively applied to every member node.
  108. """
  109. temp_name_counter = 0
  110. def __call__(self, node, substitutions, temps, pos):
  111. self.substitutions = substitutions
  112. self.pos = pos
  113. tempmap = {}
  114. temphandles = []
  115. for temp in temps:
  116. TemplateTransform.temp_name_counter += 1
  117. handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
  118. tempmap[temp] = handle
  119. temphandles.append(handle)
  120. self.tempmap = tempmap
  121. result = super(TemplateTransform, self).__call__(node)
  122. if temps:
  123. result = UtilNodes.TempsBlockNode(self.get_pos(node),
  124. temps=temphandles,
  125. body=result)
  126. return result
  127. def get_pos(self, node):
  128. if self.pos:
  129. return self.pos
  130. else:
  131. return node.pos
  132. def visit_Node(self, node):
  133. if node is None:
  134. return None
  135. else:
  136. c = node.clone_node()
  137. if self.pos is not None:
  138. c.pos = self.pos
  139. self.visitchildren(c)
  140. return c
  141. def try_substitution(self, node, key):
  142. sub = self.substitutions.get(key)
  143. if sub is not None:
  144. pos = self.pos
  145. if pos is None: pos = node.pos
  146. return ApplyPositionAndCopy(pos)(sub)
  147. else:
  148. return self.visit_Node(node) # make copy as usual
  149. def visit_NameNode(self, node):
  150. temphandle = self.tempmap.get(node.name)
  151. if temphandle:
  152. # Replace name with temporary
  153. return temphandle.ref(self.get_pos(node))
  154. else:
  155. return self.try_substitution(node, node.name)
  156. def visit_ExprStatNode(self, node):
  157. # If an expression-as-statement consists of only a replaceable
  158. # NameNode, we replace the entire statement, not only the NameNode
  159. if isinstance(node.expr, NameNode):
  160. return self.try_substitution(node, node.expr.name)
  161. else:
  162. return self.visit_Node(node)
  163. def copy_code_tree(node):
  164. return TreeCopier()(node)
  165. _match_indent = re.compile(u"^ *").match
  166. def strip_common_indent(lines):
  167. """Strips empty lines and common indentation from the list of strings given in lines"""
  168. # TODO: Facilitate textwrap.indent instead
  169. lines = [x for x in lines if x.strip() != u""]
  170. if lines:
  171. minindent = min([len(_match_indent(x).group(0)) for x in lines])
  172. lines = [x[minindent:] for x in lines]
  173. return lines
  174. class TreeFragment(object):
  175. def __init__(self, code, name=None, pxds=None, temps=None, pipeline=None, level=None, initial_pos=None):
  176. if pxds is None:
  177. pxds = {}
  178. if temps is None:
  179. temps = []
  180. if pipeline is None:
  181. pipeline = []
  182. if not name:
  183. name = "(tree fragment)"
  184. if isinstance(code, _unicode):
  185. def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n")))
  186. fmt_code = fmt(code)
  187. fmt_pxds = {}
  188. for key, value in pxds.items():
  189. fmt_pxds[key] = fmt(value)
  190. mod = t = parse_from_strings(name, fmt_code, fmt_pxds, level=level, initial_pos=initial_pos)
  191. if level is None:
  192. t = t.body # Make sure a StatListNode is at the top
  193. if not isinstance(t, StatListNode):
  194. t = StatListNode(pos=mod.pos, stats=[t])
  195. for transform in pipeline:
  196. if transform is None:
  197. continue
  198. t = transform(t)
  199. self.root = t
  200. elif isinstance(code, Node):
  201. if pxds:
  202. raise NotImplementedError()
  203. self.root = code
  204. else:
  205. raise ValueError("Unrecognized code format (accepts unicode and Node)")
  206. self.temps = temps
  207. def copy(self):
  208. return copy_code_tree(self.root)
  209. def substitute(self, nodes=None, temps=None, pos = None):
  210. if nodes is None:
  211. nodes = {}
  212. if temps is None:
  213. temps = []
  214. return TemplateTransform()(self.root,
  215. substitutions = nodes,
  216. temps = self.temps + temps, pos = pos)
  217. class SetPosTransform(VisitorTransform):
  218. def __init__(self, pos):
  219. super(SetPosTransform, self).__init__()
  220. self.pos = pos
  221. def visit_Node(self, node):
  222. node.pos = self.pos
  223. self.visitchildren(node)
  224. return node