fix_exitfunc.py 2.4 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172
  1. """
  2. Convert use of sys.exitfunc to use the atexit module.
  3. """
  4. # Author: Benjamin Peterson
  5. from lib2to3 import pytree, fixer_base
  6. from lib2to3.fixer_util import Name, Attr, Call, Comma, Newline, syms
  7. class FixExitfunc(fixer_base.BaseFix):
  8. keep_line_order = True
  9. BM_compatible = True
  10. PATTERN = """
  11. (
  12. sys_import=import_name<'import'
  13. ('sys'
  14. |
  15. dotted_as_names< (any ',')* 'sys' (',' any)* >
  16. )
  17. >
  18. |
  19. expr_stmt<
  20. power< 'sys' trailer< '.' 'exitfunc' > >
  21. '=' func=any >
  22. )
  23. """
  24. def __init__(self, *args):
  25. super(FixExitfunc, self).__init__(*args)
  26. def start_tree(self, tree, filename):
  27. super(FixExitfunc, self).start_tree(tree, filename)
  28. self.sys_import = None
  29. def transform(self, node, results):
  30. # First, find the sys import. We'll just hope it's global scope.
  31. if "sys_import" in results:
  32. if self.sys_import is None:
  33. self.sys_import = results["sys_import"]
  34. return
  35. func = results["func"].clone()
  36. func.prefix = ""
  37. register = pytree.Node(syms.power,
  38. Attr(Name("atexit"), Name("register"))
  39. )
  40. call = Call(register, [func], node.prefix)
  41. node.replace(call)
  42. if self.sys_import is None:
  43. # That's interesting.
  44. self.warning(node, "Can't find sys import; Please add an atexit "
  45. "import at the top of your file.")
  46. return
  47. # Now add an atexit import after the sys import.
  48. names = self.sys_import.children[1]
  49. if names.type == syms.dotted_as_names:
  50. names.append_child(Comma())
  51. names.append_child(Name("atexit", " "))
  52. else:
  53. containing_stmt = self.sys_import.parent
  54. position = containing_stmt.children.index(self.sys_import)
  55. stmt_container = containing_stmt.parent
  56. new_import = pytree.Node(syms.import_name,
  57. [Name("import"), Name("atexit", " ")]
  58. )
  59. new = pytree.Node(syms.simple_stmt, [new_import])
  60. containing_stmt.insert_child(position + 1, Newline())
  61. containing_stmt.insert_child(position + 2, new)