# Copyright 2006 Google, Inc. All Rights Reserved.
# Licensed to PSF under a Contributor Agreement.

"""Refactoring framework.

Used as a main program, this can refactor any number of files and/or
recursively descend down directories.  Imported as a module, this
provides infrastructure to write your own refactoring tool.
"""

__author__ = "Guido van Rossum <guido@python.org>"


# Python imports
import io
import os
import pkgutil
import sys
import logging
import operator
import collections
from itertools import chain

# Local imports
from .pgen2 import driver, tokenize, token
from .fixer_util import find_root
from . import pytree, pygram
from . import btm_matcher as bm


def get_all_fix_names(fixer_pkg, remove_prefix=True):
    """Return a sorted list of all available fix names in the given package."""
    pkg = __import__(fixer_pkg, [], [], ["*"])
    fix_names = []
    for finder, name, ispkg in pkgutil.iter_modules(pkg.__path__):
        if name.startswith("fix_"):
            if remove_prefix:
                name = name[4:]
            fix_names.append(name)
    return fix_names
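
# Illustrative sketch (not part of the module): with the standard fixer
# package "lib2to3.fixes", this returns names such as "print" and "imports",
# with the "fix_" module prefix stripped by default:
#
#     names = get_all_fix_names("lib2to3.fixes")
#     assert all(not n.startswith("fix_") for n in names)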


class _EveryNode(Exception):
    pass


def _get_head_types(pat):
    """ Accepts a pytree Pattern Node and returns a set
        of the pattern types which will match first. """

    if isinstance(pat, (pytree.NodePattern, pytree.LeafPattern)):
        # NodePatterns must either have no type and no content
        # or a type and content -- so they don't get any farther.
        # Always return leaves.
        if pat.type is None:
            raise _EveryNode
        return {pat.type}

    if isinstance(pat, pytree.NegatedPattern):
        if pat.content:
            return _get_head_types(pat.content)
        raise _EveryNode  # Negated patterns don't have a type

    if isinstance(pat, pytree.WildcardPattern):
        # Recurse on each node in content
        r = set()
        for p in pat.content:
            for x in p:
                r.update(_get_head_types(x))
        return r

    raise Exception("Oh no! I don't understand pattern %s" % (pat,))


def _get_headnode_dict(fixer_list):
    """ Accepts a list of fixers and returns a dictionary
        of head node type --> fixer list.  """
    head_nodes = collections.defaultdict(list)
    every = []
    for fixer in fixer_list:
        if fixer.pattern:
            try:
                heads = _get_head_types(fixer.pattern)
            except _EveryNode:
                every.append(fixer)
            else:
                for node_type in heads:
                    head_nodes[node_type].append(fixer)
        else:
            if fixer._accept_type is not None:
                head_nodes[fixer._accept_type].append(fixer)
            else:
                every.append(fixer)
    for node_type in chain(pygram.python_grammar.symbol2number.values(),
                           pygram.python_grammar.tokens):
        head_nodes[node_type].extend(every)
    return dict(head_nodes)


def get_fixers_from_package(pkg_name):
    """
    Return the fully qualified names for fixers in the package pkg_name.
    """
    return [pkg_name + "." + fix_name
            for fix_name in get_all_fix_names(pkg_name, False)]


def _identity(obj):
    return obj


def _detect_future_features(source):
    have_docstring = False
    gen = tokenize.generate_tokens(io.StringIO(source).readline)
    def advance():
        tok = next(gen)
        return tok[0], tok[1]
    ignore = frozenset({token.NEWLINE, tokenize.NL, token.COMMENT})
    features = set()
    try:
        while True:
            tp, value = advance()
            if tp in ignore:
                continue
            elif tp == token.STRING:
                if have_docstring:
                    break
                have_docstring = True
            elif tp == token.NAME and value == "from":
                tp, value = advance()
                if tp != token.NAME or value != "__future__":
                    break
                tp, value = advance()
                if tp != token.NAME or value != "import":
                    break
                tp, value = advance()
                if tp == token.OP and value == "(":
                    tp, value = advance()
                while tp == token.NAME:
                    features.add(value)
                    tp, value = advance()
                    if tp != token.OP or value != ",":
                        break
                    tp, value = advance()
            else:
                break
    except StopIteration:
        pass
    return frozenset(features)
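
# Illustrative sketch (not part of the module): the scanner stops at the
# first statement that is neither a docstring nor a __future__ import, so
# only a leading block of future imports is detected:
#
#     src = "from __future__ import print_function, division\nx = 1\n"
#     assert _detect_future_features(src) == frozenset(
#         {"print_function", "division"})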


class FixerError(Exception):
    """A fixer could not be loaded."""


class RefactoringTool(object):

    _default_options = {"print_function": False,
                        "exec_function": False,
                        "write_unchanged_files": False}

    CLASS_PREFIX = "Fix"  # The prefix for fixer classes
    FILE_PREFIX = "fix_"  # The prefix for modules with a fixer within
    def __init__(self, fixer_names, options=None, explicit=None):
        """Initializer.

        Args:
            fixer_names: a list of fixers to import
            options: a dict with configuration.
            explicit: a list of fixers to run even if they are marked as
                explicit (i.e. opt-in only).
        """
        self.fixers = fixer_names
        self.explicit = explicit or []
        self.options = self._default_options.copy()
        if options is not None:
            self.options.update(options)
        self.grammar = pygram.python_grammar.copy()

        if self.options['print_function']:
            del self.grammar.keywords["print"]
        elif self.options['exec_function']:
            del self.grammar.keywords["exec"]

        # When this is True, the refactor*() methods will call write_file()
        # even for files that were not changed during refactoring, provided
        # the refactor method's write parameter was True.
        self.write_unchanged_files = self.options.get("write_unchanged_files")
        self.errors = []
        self.logger = logging.getLogger("RefactoringTool")
        self.fixer_log = []
        self.wrote = False
        self.driver = driver.Driver(self.grammar,
                                    convert=pytree.convert,
                                    logger=self.logger)
        self.pre_order, self.post_order = self.get_fixers()

        self.files = []  # List of files that were or should be modified

        self.BM = bm.BottomMatcher()
        self.bmi_pre_order = []  # Bottom Matcher incompatible fixers
        self.bmi_post_order = []

        for fixer in chain(self.post_order, self.pre_order):
            if fixer.BM_compatible:
                self.BM.add_fixer(fixer)
                # remove fixers that will be handled by the bottom-up
                # matcher
            elif fixer in self.pre_order:
                self.bmi_pre_order.append(fixer)
            elif fixer in self.post_order:
                self.bmi_post_order.append(fixer)

        self.bmi_pre_order_heads = _get_headnode_dict(self.bmi_pre_order)
        self.bmi_post_order_heads = _get_headnode_dict(self.bmi_post_order)
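
    # Illustrative sketch (not part of the module): a minimal driver,
    # assuming the standard "lib2to3.fixes" fixer package is available:
    #
    #     from lib2to3.refactor import (RefactoringTool,
    #                                   get_fixers_from_package)
    #     rt = RefactoringTool(get_fixers_from_package("lib2to3.fixes"))
    #     tree = rt.refactor_string("print 'hello'\n", "<example>")
    #     print(str(tree))  # expected: print('hello')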
    def get_fixers(self):
        """Inspects the options to load the requested patterns and handlers.

        Returns:
          (pre_order, post_order), where pre_order is the list of fixers that
          want a pre-order AST traversal, and post_order is the list of fixers
          that want post-order traversal.
        """
        pre_order_fixers = []
        post_order_fixers = []
        for fix_mod_path in self.fixers:
            mod = __import__(fix_mod_path, {}, {}, ["*"])
            fix_name = fix_mod_path.rsplit(".", 1)[-1]
            if fix_name.startswith(self.FILE_PREFIX):
                fix_name = fix_name[len(self.FILE_PREFIX):]
            parts = fix_name.split("_")
            class_name = self.CLASS_PREFIX + "".join([p.title() for p in parts])
            try:
                fix_class = getattr(mod, class_name)
            except AttributeError:
                raise FixerError("Can't find %s.%s" % (fix_name, class_name)) from None
            fixer = fix_class(self.options, self.fixer_log)
            if fixer.explicit and self.explicit is not True and \
                    fix_mod_path not in self.explicit:
                self.log_message("Skipping optional fixer: %s", fix_name)
                continue

            self.log_debug("Adding transformation: %s", fix_name)
            if fixer.order == "pre":
                pre_order_fixers.append(fixer)
            elif fixer.order == "post":
                post_order_fixers.append(fixer)
            else:
                raise FixerError("Illegal fixer order: %r" % fixer.order)

        key_func = operator.attrgetter("run_order")
        pre_order_fixers.sort(key=key_func)
        post_order_fixers.sort(key=key_func)
        return (pre_order_fixers, post_order_fixers)
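
    # Illustrative sketch (not part of the module): the naming convention
    # applied above maps a fixer module path to its class name, e.g.
    # "lib2to3.fixes.fix_has_key" -> fix_name "has_key" -> class "FixHasKey".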
    def log_error(self, msg, *args, **kwds):
        """Called when an error occurs."""
        # A bare raise re-raises the exception currently being handled;
        # subclasses may override this hook to collect errors instead.
        raise

    def log_message(self, msg, *args):
        """Hook to log a message."""
        if args:
            msg = msg % args
        self.logger.info(msg)

    def log_debug(self, msg, *args):
        if args:
            msg = msg % args
        self.logger.debug(msg)

    def print_output(self, old_text, new_text, filename, equal):
        """Called with the old version, new version, and filename of a
        refactored file."""
        pass
    def refactor(self, items, write=False, doctests_only=False):
        """Refactor a list of files and directories."""

        for dir_or_file in items:
            if os.path.isdir(dir_or_file):
                self.refactor_dir(dir_or_file, write, doctests_only)
            else:
                self.refactor_file(dir_or_file, write, doctests_only)

    def refactor_dir(self, dir_name, write=False, doctests_only=False):
        """Descends into a directory and refactors every Python file found.

        Python files are assumed to have a .py extension.

        Files and subdirectories starting with '.' are skipped.
        """
        py_ext = os.extsep + "py"
        for dirpath, dirnames, filenames in os.walk(dir_name):
            self.log_debug("Descending into %s", dirpath)
            dirnames.sort()
            filenames.sort()
            for name in filenames:
                if (not name.startswith(".") and
                        os.path.splitext(name)[1] == py_ext):
                    fullname = os.path.join(dirpath, name)
                    self.refactor_file(fullname, write, doctests_only)
            # Modify dirnames in-place to remove subdirs with leading dots
            dirnames[:] = [dn for dn in dirnames if not dn.startswith(".")]
    def _read_python_source(self, filename):
        """
        Do our best to decode a Python source file correctly.
        """
        try:
            f = open(filename, "rb")
        except OSError as err:
            self.log_error("Can't open %s: %s", filename, err)
            return None, None
        try:
            encoding = tokenize.detect_encoding(f.readline)[0]
        finally:
            f.close()
        with io.open(filename, "r", encoding=encoding, newline='') as f:
            return f.read(), encoding
    def refactor_file(self, filename, write=False, doctests_only=False):
        """Refactors a file."""
        input, encoding = self._read_python_source(filename)
        if input is None:
            # Reading the file failed.
            return
        input += "\n"  # Silence certain parse errors
        if doctests_only:
            self.log_debug("Refactoring doctests in %s", filename)
            output = self.refactor_docstring(input, filename)
            if self.write_unchanged_files or output != input:
                self.processed_file(output, filename, input, write, encoding)
            else:
                self.log_debug("No doctest changes in %s", filename)
        else:
            tree = self.refactor_string(input, filename)
            if self.write_unchanged_files or (tree and tree.was_changed):
                # The [:-1] is to take off the \n we added earlier
                self.processed_file(str(tree)[:-1], filename,
                                    write=write, encoding=encoding)
            else:
                self.log_debug("No changes in %s", filename)
    def refactor_string(self, data, name):
        """Refactor a given input string.

        Args:
            data: a string holding the code to be refactored.
            name: a human-readable name for use in error/log messages.

        Returns:
            An AST corresponding to the refactored input stream; None if
            there were errors during the parse.
        """
        features = _detect_future_features(data)
        if "print_function" in features:
            self.driver.grammar = pygram.python_grammar_no_print_statement
        try:
            tree = self.driver.parse_string(data)
        except Exception as err:
            self.log_error("Can't parse %s: %s: %s",
                           name, err.__class__.__name__, err)
            return
        finally:
            self.driver.grammar = self.grammar
        tree.future_features = features
        self.log_debug("Refactoring %s", name)
        self.refactor_tree(tree, name)
        return tree
    def refactor_stdin(self, doctests_only=False):
        input = sys.stdin.read()
        if doctests_only:
            self.log_debug("Refactoring doctests in stdin")
            output = self.refactor_docstring(input, "<stdin>")
            if self.write_unchanged_files or output != input:
                self.processed_file(output, "<stdin>", input)
            else:
                self.log_debug("No doctest changes in stdin")
        else:
            tree = self.refactor_string(input, "<stdin>")
            if self.write_unchanged_files or (tree and tree.was_changed):
                self.processed_file(str(tree), "<stdin>", input)
            else:
                self.log_debug("No changes in stdin")
    def refactor_tree(self, tree, name):
        """Refactors a parse tree (modifying the tree in place).

        For compatible patterns the bottom matcher module is
        used. Otherwise the tree is traversed node-to-node for
        matches.

        Args:
            tree: a pytree.Node instance representing the root of the tree
                to be refactored.
            name: a human-readable name for this tree.

        Returns:
            True if the tree was modified, False otherwise.
        """

        for fixer in chain(self.pre_order, self.post_order):
            fixer.start_tree(tree, name)

        # Use traditional matching for the incompatible fixers.
        self.traverse_by(self.bmi_pre_order_heads, tree.pre_order())
        self.traverse_by(self.bmi_post_order_heads, tree.post_order())

        # Obtain a set of candidate nodes.
        match_set = self.BM.run(tree.leaves())

        while any(match_set.values()):
            for fixer in self.BM.fixers:
                if fixer in match_set and match_set[fixer]:
                    # Sort by depth; apply fixers from the bottom of the AST
                    # to the top.
                    match_set[fixer].sort(key=pytree.Base.depth, reverse=True)

                    if fixer.keep_line_order:
                        # Some fixers (e.g. fix_imports) must be applied in
                        # the original file's line order.
                        match_set[fixer].sort(key=pytree.Base.get_lineno)

                    for node in list(match_set[fixer]):
                        if node in match_set[fixer]:
                            match_set[fixer].remove(node)

                        try:
                            find_root(node)
                        except ValueError:
                            # This node has been cut off by a previous
                            # transformation; skip it.
                            continue

                        if node.fixers_applied and fixer in node.fixers_applied:
                            # Do not apply the same fixer again.
                            continue

                        results = fixer.match(node)

                        if results:
                            new = fixer.transform(node, results)
                            if new is not None:
                                node.replace(new)
                                # new.fixers_applied.append(fixer)
                                for node in new.post_order():
                                    # Do not apply the fixer again to
                                    # this or any subnode.
                                    if not node.fixers_applied:
                                        node.fixers_applied = []
                                    node.fixers_applied.append(fixer)

                                # Update the original match set for
                                # the added code.
                                new_matches = self.BM.run(new.leaves())
                                for fxr in new_matches:
                                    if fxr not in match_set:
                                        match_set[fxr] = []
                                    match_set[fxr].extend(new_matches[fxr])

        for fixer in chain(self.pre_order, self.post_order):
            fixer.finish_tree(tree, name)
        return tree.was_changed
    def traverse_by(self, fixers, traversal):
        """Traverse an AST, applying a set of fixers to each node.

        This is a helper method for refactor_tree().

        Args:
            fixers: a dict mapping node types to lists of fixer instances
                (as built by _get_headnode_dict()).
            traversal: a generator that yields AST nodes.

        Returns:
            None
        """
        if not fixers:
            return
        for node in traversal:
            for fixer in fixers[node.type]:
                results = fixer.match(node)
                if results:
                    new = fixer.transform(node, results)
                    if new is not None:
                        node.replace(new)
                        node = new
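
    # Illustrative sketch (not part of the module): because `fixers` is the
    # head-node dict from _get_headnode_dict(), dispatch is one dict lookup
    # per node instead of a scan over every fixer:
    #
    #     heads = _get_headnode_dict([my_fixer])  # my_fixer is hypothetical
    #     tool.traverse_by(heads, tree.pre_order())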
    def processed_file(self, new_text, filename, old_text=None, write=False,
                       encoding=None):
        """
        Called when a file has been refactored and there may be changes.
        """
        self.files.append(filename)
        if old_text is None:
            old_text = self._read_python_source(filename)[0]
            if old_text is None:
                return
        equal = old_text == new_text
        self.print_output(old_text, new_text, filename, equal)
        if equal:
            self.log_debug("No changes to %s", filename)
            if not self.write_unchanged_files:
                return
        if write:
            self.write_file(new_text, filename, old_text, encoding)
        else:
            self.log_debug("Not writing changes to %s", filename)
    def write_file(self, new_text, filename, old_text, encoding=None):
        """Writes a string to a file.

        It first shows a unified diff between the old text and the new text,
        and then rewrites the file; the latter is only done if the write
        option is set.
        """
        try:
            fp = io.open(filename, "w", encoding=encoding, newline='')
        except OSError as err:
            self.log_error("Can't create %s: %s", filename, err)
            return

        with fp:
            try:
                fp.write(new_text)
            except OSError as err:
                self.log_error("Can't write %s: %s", filename, err)
        self.log_debug("Wrote changes to %s", filename)
        self.wrote = True
    PS1 = ">>> "
    PS2 = "... "

    def refactor_docstring(self, input, filename):
        """Refactors a docstring, looking for doctests.

        This returns a modified version of the input string.  It looks
        for doctests, which start with a ">>>" prompt, and may be
        continued with "..." prompts, as long as the "..." is indented
        the same as the ">>>".

        (Unfortunately we can't use the doctest module's parser,
        since, like most parsers, it is not geared towards preserving
        the original source.)
        """
        result = []
        block = None
        block_lineno = None
        indent = None
        lineno = 0
        for line in input.splitlines(keepends=True):
            lineno += 1
            if line.lstrip().startswith(self.PS1):
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block_lineno = lineno
                block = [line]
                i = line.find(self.PS1)
                indent = line[:i]
            elif (indent is not None and
                  (line.startswith(indent + self.PS2) or
                   line == indent + self.PS2.rstrip() + "\n")):
                block.append(line)
            else:
                if block is not None:
                    result.extend(self.refactor_doctest(block, block_lineno,
                                                        indent, filename))
                block = None
                indent = None
                result.append(line)
        if block is not None:
            result.extend(self.refactor_doctest(block, block_lineno,
                                                indent, filename))
        return "".join(result)
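
    # Illustrative sketch (not part of the module): in a docstring such as
    #
    #     >>> x = (1 +
    #     ...      2)
    #
    # both lines share the same indent, so they are grouped into one block
    # and handed to refactor_doctest() with block_lineno pointing at the
    # ">>>" line.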
    def refactor_doctest(self, block, lineno, indent, filename):
        """Refactors one doctest.

        A doctest is given as a block of lines, the first of which starts
        with ">>>" (possibly indented), while the remaining lines start
        with "..." (identically indented).
        """
        try:
            tree = self.parse_block(block, lineno, indent)
        except Exception as err:
            if self.logger.isEnabledFor(logging.DEBUG):
                for line in block:
                    self.log_debug("Source: %s", line.rstrip("\n"))
            self.log_error("Can't parse docstring in %s line %s: %s: %s",
                           filename, lineno, err.__class__.__name__, err)
            return block
        if self.refactor_tree(tree, filename):
            new = str(tree).splitlines(keepends=True)
            # Undo the adjustment of the line numbers in wrap_toks() below.
            clipped, new = new[:lineno-1], new[lineno-1:]
            assert clipped == ["\n"] * (lineno-1), clipped
            if not new[-1].endswith("\n"):
                new[-1] += "\n"
            block = [indent + self.PS1 + new.pop(0)]
            if new:
                block += [indent + self.PS2 + line for line in new]
        return block
    def summarize(self):
        if self.wrote:
            were = "were"
        else:
            were = "need to be"
        if not self.files:
            self.log_message("No files %s modified.", were)
        else:
            self.log_message("Files that %s modified:", were)
            for file in self.files:
                self.log_message(file)
        if self.fixer_log:
            self.log_message("Warnings/messages while refactoring:")
            for message in self.fixer_log:
                self.log_message(message)
        if self.errors:
            if len(self.errors) == 1:
                self.log_message("There was 1 error:")
            else:
                self.log_message("There were %d errors:", len(self.errors))
            for msg, args, kwds in self.errors:
                self.log_message(msg, *args, **kwds)
    def parse_block(self, block, lineno, indent):
        """Parses a block into a tree.

        This is necessary to get correct line number / offset information
        in the parser diagnostics and embedded into the parse tree.
        """
        tree = self.driver.parse_tokens(self.wrap_toks(block, lineno, indent))
        tree.future_features = frozenset()
        return tree
    def wrap_toks(self, block, lineno, indent):
        """Wraps a tokenize stream to systematically modify start/end."""
        tokens = tokenize.generate_tokens(self.gen_lines(block, indent).__next__)
        for type, value, (line0, col0), (line1, col1), line_text in tokens:
            line0 += lineno - 1
            line1 += lineno - 1
            # Don't bother updating the columns; this is too complicated,
            # since line_text would also have to be updated, and it would
            # still break for tokens spanning lines.  Let the user guess
            # that the column numbers for doctests are relative to the
            # end of the prompt string (PS1 or PS2).
            yield type, value, (line0, col0), (line1, col1), line_text
    def gen_lines(self, block, indent):
        """Generates lines as expected by tokenize from a list of lines.

        This strips the first len(indent + self.PS1) characters off each line.
        """
        prefix1 = indent + self.PS1
        prefix2 = indent + self.PS2
        prefix = prefix1
        for line in block:
            if line.startswith(prefix):
                yield line[len(prefix):]
            elif line == prefix.rstrip() + "\n":
                yield "\n"
            else:
                raise AssertionError("line=%r, prefix=%r" % (line, prefix))
            prefix = prefix2
        # Keep feeding empty strings after the block is exhausted so that
        # tokenize sees end-of-input.
        while True:
            yield ""


class MultiprocessingUnsupported(Exception):
    pass


class MultiprocessRefactoringTool(RefactoringTool):

    def __init__(self, *args, **kwargs):
        super(MultiprocessRefactoringTool, self).__init__(*args, **kwargs)
        self.queue = None
        self.output_lock = None

    def refactor(self, items, write=False, doctests_only=False,
                 num_processes=1):
        if num_processes == 1:
            return super(MultiprocessRefactoringTool, self).refactor(
                items, write, doctests_only)
        try:
            import multiprocessing
        except ImportError:
            raise MultiprocessingUnsupported
        if self.queue is not None:
            raise RuntimeError("already doing multiple processes")
        self.queue = multiprocessing.JoinableQueue()
        self.output_lock = multiprocessing.Lock()
        processes = [multiprocessing.Process(target=self._child)
                     for i in range(num_processes)]
        try:
            for p in processes:
                p.start()
            super(MultiprocessRefactoringTool, self).refactor(items, write,
                                                              doctests_only)
        finally:
            self.queue.join()
            for i in range(num_processes):
                self.queue.put(None)
            for p in processes:
                if p.is_alive():
                    p.join()
            self.queue = None

    def _child(self):
        task = self.queue.get()
        while task is not None:
            args, kwargs = task
            try:
                super(MultiprocessRefactoringTool, self).refactor_file(
                    *args, **kwargs)
            finally:
                self.queue.task_done()
            task = self.queue.get()

    def refactor_file(self, *args, **kwargs):
        if self.queue is not None:
            self.queue.put((args, kwargs))
        else:
            return super(MultiprocessRefactoringTool, self).refactor_file(
                *args, **kwargs)
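

# Illustrative sketch (not part of the module): parallel refactoring with
# four worker processes, assuming the standard "lib2to3.fixes" package and a
# hypothetical "src/" directory:
#
#     rt = MultiprocessRefactoringTool(
#         get_fixers_from_package("lib2to3.fixes"))
#     rt.refactor(["src/"], write=True, num_processes=4)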