test_refactor.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337
  1. """
  2. Unit tests for refactor.py.
  3. """
  4. import sys
  5. import os
  6. import codecs
  7. import io
  8. import re
  9. import tempfile
  10. import shutil
  11. import unittest
  12. from lib2to3 import refactor, pygram, fixer_base
  13. from lib2to3.pgen2 import token
  14. TEST_DATA_DIR = os.path.join(os.path.dirname(__file__), "data")
  15. FIXER_DIR = os.path.join(TEST_DATA_DIR, "fixers")
  16. sys.path.append(FIXER_DIR)
  17. try:
  18. _DEFAULT_FIXERS = refactor.get_fixers_from_package("myfixes")
  19. finally:
  20. sys.path.pop()
  21. _2TO3_FIXERS = refactor.get_fixers_from_package("lib2to3.fixes")
  22. class TestRefactoringTool(unittest.TestCase):
  23. def setUp(self):
  24. sys.path.append(FIXER_DIR)
  25. def tearDown(self):
  26. sys.path.pop()
  27. def check_instances(self, instances, classes):
  28. for inst, cls in zip(instances, classes):
  29. if not isinstance(inst, cls):
  30. self.fail("%s are not instances of %s" % instances, classes)
  31. def rt(self, options=None, fixers=_DEFAULT_FIXERS, explicit=None):
  32. return refactor.RefactoringTool(fixers, options, explicit)
  33. def test_print_function_option(self):
  34. rt = self.rt({"print_function" : True})
  35. self.assertNotIn("print", rt.grammar.keywords)
  36. self.assertNotIn("print", rt.driver.grammar.keywords)
  37. def test_exec_function_option(self):
  38. rt = self.rt({"exec_function" : True})
  39. self.assertNotIn("exec", rt.grammar.keywords)
  40. self.assertNotIn("exec", rt.driver.grammar.keywords)
  41. def test_write_unchanged_files_option(self):
  42. rt = self.rt()
  43. self.assertFalse(rt.write_unchanged_files)
  44. rt = self.rt({"write_unchanged_files" : True})
  45. self.assertTrue(rt.write_unchanged_files)
  46. def test_fixer_loading_helpers(self):
  47. contents = ["explicit", "first", "last", "parrot", "preorder"]
  48. non_prefixed = refactor.get_all_fix_names("myfixes")
  49. prefixed = refactor.get_all_fix_names("myfixes", False)
  50. full_names = refactor.get_fixers_from_package("myfixes")
  51. self.assertEqual(prefixed, ["fix_" + name for name in contents])
  52. self.assertEqual(non_prefixed, contents)
  53. self.assertEqual(full_names,
  54. ["myfixes.fix_" + name for name in contents])
  55. def test_detect_future_features(self):
  56. run = refactor._detect_future_features
  57. fs = frozenset
  58. empty = fs()
  59. self.assertEqual(run(""), empty)
  60. self.assertEqual(run("from __future__ import print_function"),
  61. fs(("print_function",)))
  62. self.assertEqual(run("from __future__ import generators"),
  63. fs(("generators",)))
  64. self.assertEqual(run("from __future__ import generators, feature"),
  65. fs(("generators", "feature")))
  66. inp = "from __future__ import generators, print_function"
  67. self.assertEqual(run(inp), fs(("generators", "print_function")))
  68. inp ="from __future__ import print_function, generators"
  69. self.assertEqual(run(inp), fs(("print_function", "generators")))
  70. inp = "from __future__ import (print_function,)"
  71. self.assertEqual(run(inp), fs(("print_function",)))
  72. inp = "from __future__ import (generators, print_function)"
  73. self.assertEqual(run(inp), fs(("generators", "print_function")))
  74. inp = "from __future__ import (generators, nested_scopes)"
  75. self.assertEqual(run(inp), fs(("generators", "nested_scopes")))
  76. inp = """from __future__ import generators
  77. from __future__ import print_function"""
  78. self.assertEqual(run(inp), fs(("generators", "print_function")))
  79. invalid = ("from",
  80. "from 4",
  81. "from x",
  82. "from x 5",
  83. "from x im",
  84. "from x import",
  85. "from x import 4",
  86. )
  87. for inp in invalid:
  88. self.assertEqual(run(inp), empty)
  89. inp = "'docstring'\nfrom __future__ import print_function"
  90. self.assertEqual(run(inp), fs(("print_function",)))
  91. inp = "'docstring'\n'somng'\nfrom __future__ import print_function"
  92. self.assertEqual(run(inp), empty)
  93. inp = "# comment\nfrom __future__ import print_function"
  94. self.assertEqual(run(inp), fs(("print_function",)))
  95. inp = "# comment\n'doc'\nfrom __future__ import print_function"
  96. self.assertEqual(run(inp), fs(("print_function",)))
  97. inp = "class x: pass\nfrom __future__ import print_function"
  98. self.assertEqual(run(inp), empty)
  99. def test_get_headnode_dict(self):
  100. class NoneFix(fixer_base.BaseFix):
  101. pass
  102. class FileInputFix(fixer_base.BaseFix):
  103. PATTERN = "file_input< any * >"
  104. class SimpleFix(fixer_base.BaseFix):
  105. PATTERN = "'name'"
  106. no_head = NoneFix({}, [])
  107. with_head = FileInputFix({}, [])
  108. simple = SimpleFix({}, [])
  109. d = refactor._get_headnode_dict([no_head, with_head, simple])
  110. top_fixes = d.pop(pygram.python_symbols.file_input)
  111. self.assertEqual(top_fixes, [with_head, no_head])
  112. name_fixes = d.pop(token.NAME)
  113. self.assertEqual(name_fixes, [simple, no_head])
  114. for fixes in d.values():
  115. self.assertEqual(fixes, [no_head])
  116. def test_fixer_loading(self):
  117. from myfixes.fix_first import FixFirst
  118. from myfixes.fix_last import FixLast
  119. from myfixes.fix_parrot import FixParrot
  120. from myfixes.fix_preorder import FixPreorder
  121. rt = self.rt()
  122. pre, post = rt.get_fixers()
  123. self.check_instances(pre, [FixPreorder])
  124. self.check_instances(post, [FixFirst, FixParrot, FixLast])
  125. def test_naughty_fixers(self):
  126. self.assertRaises(ImportError, self.rt, fixers=["not_here"])
  127. self.assertRaises(refactor.FixerError, self.rt, fixers=["no_fixer_cls"])
  128. self.assertRaises(refactor.FixerError, self.rt, fixers=["bad_order"])
  129. def test_refactor_string(self):
  130. rt = self.rt()
  131. input = "def parrot(): pass\n\n"
  132. tree = rt.refactor_string(input, "<test>")
  133. self.assertNotEqual(str(tree), input)
  134. input = "def f(): pass\n\n"
  135. tree = rt.refactor_string(input, "<test>")
  136. self.assertEqual(str(tree), input)
  137. def test_refactor_stdin(self):
  138. class MyRT(refactor.RefactoringTool):
  139. def print_output(self, old_text, new_text, filename, equal):
  140. results.extend([old_text, new_text, filename, equal])
  141. results = []
  142. rt = MyRT(_DEFAULT_FIXERS)
  143. save = sys.stdin
  144. sys.stdin = io.StringIO("def parrot(): pass\n\n")
  145. try:
  146. rt.refactor_stdin()
  147. finally:
  148. sys.stdin = save
  149. expected = ["def parrot(): pass\n\n",
  150. "def cheese(): pass\n\n",
  151. "<stdin>", False]
  152. self.assertEqual(results, expected)
  153. def check_file_refactoring(self, test_file, fixers=_2TO3_FIXERS,
  154. options=None, mock_log_debug=None,
  155. actually_write=True):
  156. test_file = self.init_test_file(test_file)
  157. old_contents = self.read_file(test_file)
  158. rt = self.rt(fixers=fixers, options=options)
  159. if mock_log_debug:
  160. rt.log_debug = mock_log_debug
  161. rt.refactor_file(test_file)
  162. self.assertEqual(old_contents, self.read_file(test_file))
  163. if not actually_write:
  164. return
  165. rt.refactor_file(test_file, True)
  166. new_contents = self.read_file(test_file)
  167. self.assertNotEqual(old_contents, new_contents)
  168. return new_contents
  169. def init_test_file(self, test_file):
  170. tmpdir = tempfile.mkdtemp(prefix="2to3-test_refactor")
  171. self.addCleanup(shutil.rmtree, tmpdir)
  172. shutil.copy(test_file, tmpdir)
  173. test_file = os.path.join(tmpdir, os.path.basename(test_file))
  174. os.chmod(test_file, 0o644)
  175. return test_file
  176. def read_file(self, test_file):
  177. with open(test_file, "rb") as fp:
  178. return fp.read()
  179. def refactor_file(self, test_file, fixers=_2TO3_FIXERS):
  180. test_file = self.init_test_file(test_file)
  181. old_contents = self.read_file(test_file)
  182. rt = self.rt(fixers=fixers)
  183. rt.refactor_file(test_file, True)
  184. new_contents = self.read_file(test_file)
  185. return old_contents, new_contents
  186. def test_refactor_file(self):
  187. test_file = os.path.join(FIXER_DIR, "parrot_example.py")
  188. self.check_file_refactoring(test_file, _DEFAULT_FIXERS)
  189. def test_refactor_file_write_unchanged_file(self):
  190. test_file = os.path.join(FIXER_DIR, "parrot_example.py")
  191. debug_messages = []
  192. def recording_log_debug(msg, *args):
  193. debug_messages.append(msg % args)
  194. self.check_file_refactoring(test_file, fixers=(),
  195. options={"write_unchanged_files": True},
  196. mock_log_debug=recording_log_debug,
  197. actually_write=False)
  198. # Testing that it logged this message when write=False was passed is
  199. # sufficient to see that it did not bail early after "No changes".
  200. message_regex = r"Not writing changes to .*%s" % \
  201. re.escape(os.sep + os.path.basename(test_file))
  202. for message in debug_messages:
  203. if "Not writing changes" in message:
  204. self.assertRegex(message, message_regex)
  205. break
  206. else:
  207. self.fail("%r not matched in %r" % (message_regex, debug_messages))
  208. def test_refactor_dir(self):
  209. def check(structure, expected):
  210. def mock_refactor_file(self, f, *args):
  211. got.append(f)
  212. save_func = refactor.RefactoringTool.refactor_file
  213. refactor.RefactoringTool.refactor_file = mock_refactor_file
  214. rt = self.rt()
  215. got = []
  216. dir = tempfile.mkdtemp(prefix="2to3-test_refactor")
  217. try:
  218. os.mkdir(os.path.join(dir, "a_dir"))
  219. for fn in structure:
  220. open(os.path.join(dir, fn), "wb").close()
  221. rt.refactor_dir(dir)
  222. finally:
  223. refactor.RefactoringTool.refactor_file = save_func
  224. shutil.rmtree(dir)
  225. self.assertEqual(got,
  226. [os.path.join(dir, path) for path in expected])
  227. check([], [])
  228. tree = ["nothing",
  229. "hi.py",
  230. ".dumb",
  231. ".after.py",
  232. "notpy.npy",
  233. "sappy"]
  234. expected = ["hi.py"]
  235. check(tree, expected)
  236. tree = ["hi.py",
  237. os.path.join("a_dir", "stuff.py")]
  238. check(tree, tree)
  239. def test_file_encoding(self):
  240. fn = os.path.join(TEST_DATA_DIR, "different_encoding.py")
  241. self.check_file_refactoring(fn)
  242. def test_false_file_encoding(self):
  243. fn = os.path.join(TEST_DATA_DIR, "false_encoding.py")
  244. data = self.check_file_refactoring(fn)
  245. def test_bom(self):
  246. fn = os.path.join(TEST_DATA_DIR, "bom.py")
  247. data = self.check_file_refactoring(fn)
  248. self.assertTrue(data.startswith(codecs.BOM_UTF8))
  249. def test_crlf_newlines(self):
  250. old_sep = os.linesep
  251. os.linesep = "\r\n"
  252. try:
  253. fn = os.path.join(TEST_DATA_DIR, "crlf.py")
  254. fixes = refactor.get_fixers_from_package("lib2to3.fixes")
  255. self.check_file_refactoring(fn, fixes)
  256. finally:
  257. os.linesep = old_sep
  258. def test_crlf_unchanged(self):
  259. fn = os.path.join(TEST_DATA_DIR, "crlf.py")
  260. old, new = self.refactor_file(fn)
  261. self.assertIn(b"\r\n", old)
  262. self.assertIn(b"\r\n", new)
  263. self.assertNotIn(b"\r\r\n", new)
  264. def test_refactor_docstring(self):
  265. rt = self.rt()
  266. doc = """
  267. >>> example()
  268. 42
  269. """
  270. out = rt.refactor_docstring(doc, "<test>")
  271. self.assertEqual(out, doc)
  272. doc = """
  273. >>> def parrot():
  274. ... return 43
  275. """
  276. out = rt.refactor_docstring(doc, "<test>")
  277. self.assertNotEqual(out, doc)
  278. def test_explicit(self):
  279. from myfixes.fix_explicit import FixExplicit
  280. rt = self.rt(fixers=["myfixes.fix_explicit"])
  281. self.assertEqual(len(rt.post_order), 0)
  282. rt = self.rt(explicit=["myfixes.fix_explicit"])
  283. for fix in rt.post_order:
  284. if isinstance(fix, FixExplicit):
  285. break
  286. else:
  287. self.fail("explicit fixer not loaded")