123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275 |
- #
- # TreeFragments - parsing of strings to trees
- #
- """
- Support for parsing strings into code trees.
- """
- from __future__ import absolute_import
- import re
- from io import StringIO
- from .Scanning import PyrexScanner, StringSourceDescriptor
- from .Symtab import ModuleScope
- from . import PyrexTypes
- from .Visitor import VisitorTransform
- from .Nodes import Node, StatListNode
- from .ExprNodes import NameNode
- from .StringEncoding import _unicode
- from . import Parsing
- from . import Main
- from . import UtilNodes
class StringParseContext(Main.Context):
    """Compilation context used when parsing code held in an in-memory string.

    Only the module being parsed (``name``) and the ``cython`` module can be
    resolved through ``find_module``; any other lookup raises AssertionError.
    """

    def __init__(self, name, include_directories=None, compiler_directives=None, cpp=False):
        # Build fresh containers per instance rather than sharing defaults.
        include_dirs = [] if include_directories is None else include_directories
        directives = {} if compiler_directives is None else compiler_directives
        # TODO: see if "language_level=3" also works for our internal code here.
        Main.Context.__init__(self, include_dirs, directives, cpp=cpp, language_level=2)
        self.module_name = name

    def find_module(self, module_name, relative_to=None, pos=None, need_pxd=1, absolute_fallback=True):
        """Return a new ModuleScope for ``module_name``.

        Only ``self.module_name`` and ``'cython'`` are accepted; the remaining
        parameters exist for signature compatibility and are not used here.

        Raises AssertionError for any other module name.
        """
        allowed = (self.module_name, 'cython')
        if module_name in allowed:
            return ModuleScope(module_name, parent_module=None, context=self)
        raise AssertionError("Not yet supporting any cimports/includes from string code snippets")
def parse_from_strings(name, code, pxds=None, level=None, initial_pos=None,
                       context=None, allow_struct_enum_decorator=False):
    """
    Utility method to parse a (unicode) string of code. This is mostly
    used for internal Cython compiler purposes (creating code snippets
    that transforms should emit, as well as unit testing).

    code - a unicode string containing Cython (module-level) code
    name - a descriptive name for the code source (to use in error messages etc.);
           also used as the module name
    pxds - accepted for interface compatibility; not used by this function
    level - None to parse a whole module with Parsing.p_module; otherwise
            forwarded as ``level`` to Parsing.p_code
    initial_pos - position of the first character; defaults to (name, 1, 0)
    context - compilation context; a new StringParseContext(name) when omitted
    allow_struct_enum_decorator - forwarded to Parsing.Ctx

    RETURNS

    The tree (a ModuleNode when ``level`` is None). The tree's scope
    attribute is set to the scope used when parsing.
    """
    if context is None:
        context = StringParseContext(name)
    # Since source files carry an encoding, it makes sense in this context
    # to use a unicode string so that code fragments don't have to bother
    # with encoding. This means that test code passed in should not have an
    # encoding header.
    assert isinstance(code, _unicode), "unicode code snippets only please"

    if initial_pos is None:
        initial_pos = (name, 1, 0)
    source_desc = StringSourceDescriptor(name, code)
    scope = context.find_module(name, pos=initial_pos, need_pxd=False)
    scanner = PyrexScanner(
        StringIO(code), source_desc,
        source_encoding="UTF-8",
        scope=scope, context=context, initial_pos=initial_pos)
    parse_ctx = Parsing.Ctx(allow_struct_enum_decorator=allow_struct_enum_decorator)

    whole_module = level is None
    if whole_module:
        tree = Parsing.p_module(scanner, 0, name, ctx=parse_ctx)
    else:
        tree = Parsing.p_code(scanner, level=level, ctx=parse_ctx)
    tree.scope = scope
    if whole_module:
        tree.is_pxd = False
    return tree
class TreeCopier(VisitorTransform):
    """Transform that produces a copy of a node tree.

    Every visited node is replaced by its ``clone_node()`` result, and the
    clone's children are then visited by this same transform. ``None`` is
    passed through unchanged.
    """

    def visit_Node(self, node):
        if node is not None:
            clone = node.clone_node()
            self.visitchildren(clone)
            return clone
        return node
class ApplyPositionAndCopy(TreeCopier):
    """Copy a node tree, setting one fixed source position on every copied node.

    pos - the value assigned to the ``pos`` attribute of each copy.
    """

    def __init__(self, pos):
        super(ApplyPositionAndCopy, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        copy = super(ApplyPositionAndCopy, self).visit_Node(node)
        # TreeCopier.visit_Node returns None unchanged when given None; there
        # is no node to stamp then, so skip the assignment instead of raising
        # AttributeError on None.
        if copy is not None:
            copy.pos = self.pos
        return copy
class TemplateTransform(VisitorTransform):
    """
    Makes a copy of a template tree while doing substitutions.

    A dictionary "substitutions" should be passed in when calling
    the transform, mapping names to replacement nodes. Replacement
    happens like this:
     - If an ExprStatNode contains a single NameNode whose name is
       a key in the substitutions dictionary, the ExprStatNode is
       replaced with a copy of the tree given in the dictionary.
       It is the responsibility of the caller that the replacement
       node is a valid statement.
     - If a single NameNode is otherwise encountered, it is replaced
       if its name is listed in the substitutions dictionary in the
       same way. It is the responsibility of the caller to make sure
       that the replacement node is a valid expression.

    Also a list "temps" should be passed. Any names listed will
    be transformed into anonymous, temporary names, and the result
    is then wrapped in a UtilNodes.TempsBlockNode owning those temps.

    Currently supported for temp names are:
       NameNode
    (various function and class definition nodes etc. should be added to this)

    Each replacement node gets the position passed as ``pos`` (or, when
    that is None, the position of the substituted node) recursively
    applied to every member node.
    """

    # Incremented once per requested temp name; not read within this class.
    temp_name_counter = 0

    def __call__(self, node, substitutions, temps, pos):
        """Return a substituted copy of the tree rooted at ``node``.

        substitutions - dict mapping names to replacement nodes
        temps - iterable of names to replace with py_object temporaries
        pos - position to apply to every copied node, or None to keep the
              original positions
        """
        self.substitutions = substitutions
        self.pos = pos
        tempmap = {}
        temphandles = []
        for temp in temps:
            TemplateTransform.temp_name_counter += 1
            handle = UtilNodes.TempHandle(PyrexTypes.py_object_type)
            tempmap[temp] = handle
            temphandles.append(handle)
        self.tempmap = tempmap
        result = super(TemplateTransform, self).__call__(node)
        if temps:
            result = UtilNodes.TempsBlockNode(self.get_pos(node),
                                              temps=temphandles,
                                              body=result)
        return result

    def get_pos(self, node):
        """Return the override position if one was given, else ``node.pos``."""
        # Use an explicit None test, matching visit_Node, rather than truthiness.
        if self.pos is not None:
            return self.pos
        return node.pos

    def visit_Node(self, node):
        if node is None:
            return None
        c = node.clone_node()
        if self.pos is not None:
            c.pos = self.pos
        self.visitchildren(c)
        return c

    def try_substitution(self, node, key):
        """Return a positioned copy of the substitution for ``key``, or a plain copy of ``node``."""
        sub = self.substitutions.get(key)
        if sub is None:
            return self.visit_Node(node)  # make copy as usual
        return ApplyPositionAndCopy(self.get_pos(node))(sub)

    def visit_NameNode(self, node):
        temphandle = self.tempmap.get(node.name)
        if temphandle is not None:
            # Replace name with temporary
            return temphandle.ref(self.get_pos(node))
        return self.try_substitution(node, node.name)

    def visit_ExprStatNode(self, node):
        # If an expression-as-statement consists of only a replaceable
        # NameNode, we replace the entire statement, not only the NameNode
        if isinstance(node.expr, NameNode):
            return self.try_substitution(node, node.expr.name)
        return self.visit_Node(node)
def copy_code_tree(node):
    """Return a copy of the tree rooted at ``node``, made by a fresh TreeCopier."""
    copier = TreeCopier()
    return copier(node)
_match_indent = re.compile(u"^ *").match


def strip_common_indent(lines):
    """Drop blank lines, then remove the leading spaces shared by all remaining lines.

    Only space characters count as indentation (tabs are left untouched).
    A new list is returned; the input list is not modified.
    """
    # TODO: consider textwrap.dedent instead (note: it keeps blank lines
    # and also treats tabs as indentation).
    non_blank = []
    for line in lines:
        if line.strip():
            non_blank.append(line)
    if not non_blank:
        return non_blank
    common = min(len(_match_indent(line).group(0)) for line in non_blank)
    return [line[common:] for line in non_blank]
class TreeFragment(object):
    """A reusable code template.

    The template is built either from a unicode source string or from an
    existing Node. A string is dedented (see strip_common_indent), parsed
    with parse_from_strings and then run through ``pipeline``; a Node is
    used as-is. Call ``copy()`` for a plain copy of the root, or
    ``substitute()`` for a copy with name substitutions and temporaries.

    code - unicode source text or a Node
    name - descriptive source name; "(tree fragment)" when empty/None
    pxds - dict of extra unicode sources, dedented and handed to
           parse_from_strings (must be empty when ``code`` is a Node)
    temps - names always turned into temporaries by ``substitute()``
    pipeline - callables applied in order to the parsed root (None entries
               are skipped); only used for string input
    level - None to parse a whole module, otherwise a Parsing level
    initial_pos - starting position forwarded to parse_from_strings

    Raises NotImplementedError if ``code`` is a Node and ``pxds`` is
    non-empty, and ValueError if ``code`` is neither unicode nor a Node.
    """

    def __init__(self, code, name=None, pxds=None, temps=None, pipeline=None, level=None, initial_pos=None):
        pxds = {} if pxds is None else pxds
        temps = [] if temps is None else temps
        pipeline = [] if pipeline is None else pipeline
        name = name or "(tree fragment)"

        if isinstance(code, Node):
            if pxds:
                raise NotImplementedError()
            self.root = code
        elif isinstance(code, _unicode):
            self.root = self._parse_root(code, name, pxds, pipeline, level, initial_pos)
        else:
            raise ValueError("Unrecognized code format (accepts unicode and Node)")
        self.temps = temps

    @staticmethod
    def _parse_root(code, name, pxds, pipeline, level, initial_pos):
        """Dedent and parse ``code``, normalise the top node, then apply ``pipeline``."""
        def dedent(text):
            return u"\n".join(strip_common_indent(text.split(u"\n")))

        source = dedent(code)
        dedented_pxds = dict((key, dedent(value)) for key, value in pxds.items())
        module = parse_from_strings(name, source, dedented_pxds, level=level, initial_pos=initial_pos)
        root = module
        if level is None:
            # Make sure a StatListNode is at the top.
            root = module.body
            if not isinstance(root, StatListNode):
                root = StatListNode(pos=module.pos, stats=[root])
        for transform in pipeline:
            if transform is not None:
                root = transform(root)
        return root

    def copy(self):
        """Return a plain copy of the template tree."""
        return copy_code_tree(self.root)

    def substitute(self, nodes=None, temps=None, pos=None):
        """Return a copy of the template with substitutions applied.

        nodes - dict mapping names to replacement nodes
        temps - extra temp names, added to the ones given at construction
        pos - position applied to every copied node, or None to keep
              the template's positions
        """
        substitutions = {} if nodes is None else nodes
        extra_temps = [] if temps is None else temps
        transform = TemplateTransform()
        return transform(self.root,
                         substitutions=substitutions,
                         temps=self.temps + extra_temps,
                         pos=pos)
class SetPosTransform(VisitorTransform):
    """In-place transform that overwrites ``pos`` on every visited node.

    Unlike the copying transforms in this module, nodes are modified
    directly and returned; no clones are made.

    pos - the position assigned to each visited node.
    """

    def __init__(self, pos):
        super(SetPosTransform, self).__init__()
        self.pos = pos

    def visit_Node(self, node):
        # Stamp this node first, then recurse into its children.
        target_pos = self.pos
        node.pos = target_pos
        self.visitchildren(node)
        return node
|