btm_matcher.py 6.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163
"""A bottom-up tree matching algorithm implementation meant to speed
up 2to3's matching process. After the tree patterns are reduced to
their rarest linear path, a linear Aho-Corasick automaton is
created. The linear automaton traverses the linear paths from the
leaves to the root of the AST and returns a set of nodes for further
matching. This significantly reduces the number of candidate nodes."""
  7. __author__ = "George Boutsioukis <gboutsioukis@gmail.com>"
  8. import logging
  9. import itertools
  10. from collections import defaultdict
  11. from . import pytree
  12. from .btm_utils import reduce_tree
  13. class BMNode(object):
  14. """Class for a node of the Aho-Corasick automaton used in matching"""
  15. count = itertools.count()
  16. def __init__(self):
  17. self.transition_table = {}
  18. self.fixers = []
  19. self.id = next(BMNode.count)
  20. self.content = ''
  21. class BottomMatcher(object):
  22. """The main matcher class. After instantiating the patterns should
  23. be added using the add_fixer method"""
  24. def __init__(self):
  25. self.match = set()
  26. self.root = BMNode()
  27. self.nodes = [self.root]
  28. self.fixers = []
  29. self.logger = logging.getLogger("RefactoringTool")
  30. def add_fixer(self, fixer):
  31. """Reduces a fixer's pattern tree to a linear path and adds it
  32. to the matcher(a common Aho-Corasick automaton). The fixer is
  33. appended on the matching states and called when they are
  34. reached"""
  35. self.fixers.append(fixer)
  36. tree = reduce_tree(fixer.pattern_tree)
  37. linear = tree.get_linear_subpattern()
  38. match_nodes = self.add(linear, start=self.root)
  39. for match_node in match_nodes:
  40. match_node.fixers.append(fixer)
  41. def add(self, pattern, start):
  42. "Recursively adds a linear pattern to the AC automaton"
  43. #print("adding pattern", pattern, "to", start)
  44. if not pattern:
  45. #print("empty pattern")
  46. return [start]
  47. if isinstance(pattern[0], tuple):
  48. #alternatives
  49. #print("alternatives")
  50. match_nodes = []
  51. for alternative in pattern[0]:
  52. #add all alternatives, and add the rest of the pattern
  53. #to each end node
  54. end_nodes = self.add(alternative, start=start)
  55. for end in end_nodes:
  56. match_nodes.extend(self.add(pattern[1:], end))
  57. return match_nodes
  58. else:
  59. #single token
  60. #not last
  61. if pattern[0] not in start.transition_table:
  62. #transition did not exist, create new
  63. next_node = BMNode()
  64. start.transition_table[pattern[0]] = next_node
  65. else:
  66. #transition exists already, follow
  67. next_node = start.transition_table[pattern[0]]
  68. if pattern[1:]:
  69. end_nodes = self.add(pattern[1:], start=next_node)
  70. else:
  71. end_nodes = [next_node]
  72. return end_nodes
  73. def run(self, leaves):
  74. """The main interface with the bottom matcher. The tree is
  75. traversed from the bottom using the constructed
  76. automaton. Nodes are only checked once as the tree is
  77. retraversed. When the automaton fails, we give it one more
  78. shot(in case the above tree matches as a whole with the
  79. rejected leaf), then we break for the next leaf. There is the
  80. special case of multiple arguments(see code comments) where we
  81. recheck the nodes
  82. Args:
  83. The leaves of the AST tree to be matched
  84. Returns:
  85. A dictionary of node matches with fixers as the keys
  86. """
  87. current_ac_node = self.root
  88. results = defaultdict(list)
  89. for leaf in leaves:
  90. current_ast_node = leaf
  91. while current_ast_node:
  92. current_ast_node.was_checked = True
  93. for child in current_ast_node.children:
  94. # multiple statements, recheck
  95. if isinstance(child, pytree.Leaf) and child.value == ";":
  96. current_ast_node.was_checked = False
  97. break
  98. if current_ast_node.type == 1:
  99. #name
  100. node_token = current_ast_node.value
  101. else:
  102. node_token = current_ast_node.type
  103. if node_token in current_ac_node.transition_table:
  104. #token matches
  105. current_ac_node = current_ac_node.transition_table[node_token]
  106. for fixer in current_ac_node.fixers:
  107. results[fixer].append(current_ast_node)
  108. else:
  109. #matching failed, reset automaton
  110. current_ac_node = self.root
  111. if (current_ast_node.parent is not None
  112. and current_ast_node.parent.was_checked):
  113. #the rest of the tree upwards has been checked, next leaf
  114. break
  115. #recheck the rejected node once from the root
  116. if node_token in current_ac_node.transition_table:
  117. #token matches
  118. current_ac_node = current_ac_node.transition_table[node_token]
  119. for fixer in current_ac_node.fixers:
  120. results[fixer].append(current_ast_node)
  121. current_ast_node = current_ast_node.parent
  122. return results
  123. def print_ac(self):
  124. "Prints a graphviz diagram of the BM automaton(for debugging)"
  125. print("digraph g{")
  126. def print_node(node):
  127. for subnode_key in node.transition_table.keys():
  128. subnode = node.transition_table[subnode_key]
  129. print("%d -> %d [label=%s] //%s" %
  130. (node.id, subnode.id, type_repr(subnode_key), str(subnode.fixers)))
  131. if subnode_key == 1:
  132. print(subnode.content)
  133. print_node(subnode)
  134. print_node(self.root)
  135. print("}")
  136. # taken from pytree.py for debugging; only used by print_ac
  137. _type_reprs = {}
  138. def type_repr(type_num):
  139. global _type_reprs
  140. if not _type_reprs:
  141. from .pygram import python_symbols
  142. # printing tokens is possible but not as useful
  143. # from .pgen2 import token // token.__dict__.items():
  144. for name, val in python_symbols.__dict__.items():
  145. if type(val) == int: _type_reprs[val] = name
  146. return _type_reprs.setdefault(type_num, type_num)