visitor.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141
  1. """Generic visitor pattern implementation for Python objects."""
  2. import enum
  3. class Visitor(object):
  4. defaultStop = False
  5. @classmethod
  6. def _register(celf, clazzes_attrs):
  7. assert celf != Visitor, "Subclass Visitor instead."
  8. if "_visitors" not in celf.__dict__:
  9. celf._visitors = {}
  10. def wrapper(method):
  11. assert method.__name__ == "visit"
  12. for clazzes, attrs in clazzes_attrs:
  13. if type(clazzes) != tuple:
  14. clazzes = (clazzes,)
  15. if type(attrs) == str:
  16. attrs = (attrs,)
  17. for clazz in clazzes:
  18. _visitors = celf._visitors.setdefault(clazz, {})
  19. for attr in attrs:
  20. assert attr not in _visitors, (
  21. "Oops, class '%s' has visitor function for '%s' defined already."
  22. % (clazz.__name__, attr)
  23. )
  24. _visitors[attr] = method
  25. return None
  26. return wrapper
  27. @classmethod
  28. def register(celf, clazzes):
  29. if type(clazzes) != tuple:
  30. clazzes = (clazzes,)
  31. return celf._register([(clazzes, (None,))])
  32. @classmethod
  33. def register_attr(celf, clazzes, attrs):
  34. clazzes_attrs = []
  35. if type(clazzes) != tuple:
  36. clazzes = (clazzes,)
  37. if type(attrs) == str:
  38. attrs = (attrs,)
  39. for clazz in clazzes:
  40. clazzes_attrs.append((clazz, attrs))
  41. return celf._register(clazzes_attrs)
  42. @classmethod
  43. def register_attrs(celf, clazzes_attrs):
  44. return celf._register(clazzes_attrs)
  45. @classmethod
  46. def _visitorsFor(celf, thing, _default={}):
  47. typ = type(thing)
  48. for celf in celf.mro():
  49. _visitors = getattr(celf, "_visitors", None)
  50. if _visitors is None:
  51. break
  52. m = celf._visitors.get(typ, None)
  53. if m is not None:
  54. return m
  55. return _default
  56. def visitObject(self, obj, *args, **kwargs):
  57. """Called to visit an object. This function loops over all non-private
  58. attributes of the objects and calls any user-registered (via
  59. @register_attr() or @register_attrs()) visit() functions.
  60. If there is no user-registered visit function, of if there is and it
  61. returns True, or it returns None (or doesn't return anything) and
  62. visitor.defaultStop is False (default), then the visitor will proceed
  63. to call self.visitAttr()"""
  64. keys = sorted(vars(obj).keys())
  65. _visitors = self._visitorsFor(obj)
  66. defaultVisitor = _visitors.get("*", None)
  67. for key in keys:
  68. if key[0] == "_":
  69. continue
  70. value = getattr(obj, key)
  71. visitorFunc = _visitors.get(key, defaultVisitor)
  72. if visitorFunc is not None:
  73. ret = visitorFunc(self, obj, key, value, *args, **kwargs)
  74. if ret == False or (ret is None and self.defaultStop):
  75. continue
  76. self.visitAttr(obj, key, value, *args, **kwargs)
  77. def visitAttr(self, obj, attr, value, *args, **kwargs):
  78. """Called to visit an attribute of an object."""
  79. self.visit(value, *args, **kwargs)
  80. def visitList(self, obj, *args, **kwargs):
  81. """Called to visit any value that is a list."""
  82. for value in obj:
  83. self.visit(value, *args, **kwargs)
  84. def visitDict(self, obj, *args, **kwargs):
  85. """Called to visit any value that is a dictionary."""
  86. for value in obj.values():
  87. self.visit(value, *args, **kwargs)
  88. def visitLeaf(self, obj, *args, **kwargs):
  89. """Called to visit any value that is not an object, list,
  90. or dictionary."""
  91. pass
  92. def visit(self, obj, *args, **kwargs):
  93. """This is the main entry to the visitor. The visitor will visit object
  94. obj.
  95. The visitor will first determine if there is a registered (via
  96. @register()) visit function for the type of object. If there is, it
  97. will be called, and (visitor, obj, *args, **kwargs) will be passed to
  98. the user visit function.
  99. If there is no user-registered visit function, of if there is and it
  100. returns True, or it returns None (or doesn't return anything) and
  101. visitor.defaultStop is False (default), then the visitor will proceed
  102. to dispatch to one of self.visitObject(), self.visitList(),
  103. self.visitDict(), or self.visitLeaf() (any of which can be overriden in
  104. a subclass)."""
  105. visitorFunc = self._visitorsFor(obj).get(None, None)
  106. if visitorFunc is not None:
  107. ret = visitorFunc(self, obj, *args, **kwargs)
  108. if ret == False or (ret is None and self.defaultStop):
  109. return
  110. if hasattr(obj, "__dict__") and not isinstance(obj, enum.Enum):
  111. self.visitObject(obj, *args, **kwargs)
  112. elif isinstance(obj, list):
  113. self.visitList(obj, *args, **kwargs)
  114. elif isinstance(obj, dict):
  115. self.visitDict(obj, *args, **kwargs)
  116. else:
  117. self.visitLeaf(obj, *args, **kwargs)