dispatcher.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413
  1. from typing import Set as tSet
  2. from warnings import warn
  3. import inspect
  4. from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
  5. from .utils import expand_tuples
  6. import itertools as itl
  7. class MDNotImplementedError(NotImplementedError):
  8. """ A NotImplementedError for multiple dispatch """
  9. ### Functions for on_ambiguity
  10. def ambiguity_warn(dispatcher, ambiguities):
  11. """ Raise warning when ambiguity is detected
  12. Parameters
  13. ----------
  14. dispatcher : Dispatcher
  15. The dispatcher on which the ambiguity was detected
  16. ambiguities : set
  17. Set of type signature pairs that are ambiguous within this dispatcher
  18. See Also:
  19. Dispatcher.add
  20. warning_text
  21. """
  22. warn(warning_text(dispatcher.name, ambiguities), AmbiguityWarning)
  23. class RaiseNotImplementedError:
  24. """Raise ``NotImplementedError`` when called."""
  25. def __init__(self, dispatcher):
  26. self.dispatcher = dispatcher
  27. def __call__(self, *args, **kwargs):
  28. types = tuple(type(a) for a in args)
  29. raise NotImplementedError(
  30. "Ambiguous signature for %s: <%s>" % (
  31. self.dispatcher.name, str_signature(types)
  32. ))
  33. def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
  34. """
  35. If super signature for ambiguous types is duplicate types, ignore it.
  36. Else, register instance of ``RaiseNotImplementedError`` for ambiguous types.
  37. Parameters
  38. ----------
  39. dispatcher : Dispatcher
  40. The dispatcher on which the ambiguity was detected
  41. ambiguities : set
  42. Set of type signature pairs that are ambiguous within this dispatcher
  43. See Also:
  44. Dispatcher.add
  45. ambiguity_warn
  46. """
  47. for amb in ambiguities:
  48. signature = tuple(super_signature(amb))
  49. if len(set(signature)) == 1:
  50. continue
  51. dispatcher.add(
  52. signature, RaiseNotImplementedError(dispatcher),
  53. on_ambiguity=ambiguity_register_error_ignore_dup
  54. )
### Global ordering state
# Dispatchers whose reorder() was requested while ordering was halted
# (``_resolve[0]`` False); ``restart_ordering`` pops and reorders them.
_unresolved_dispatchers = set() # type: tSet[Dispatcher]
# One-element list used as a mutable module-level flag, so functions can
# flip it without a ``global`` statement. ``True`` means reorder immediately.
_resolve = [True]
def halt_ordering():
    """Suspend signature ordering for all dispatchers.

    While halted, ``Dispatcher.reorder`` only records the dispatcher in
    ``_unresolved_dispatchers``; call ``restart_ordering`` to process them.
    """
    _resolve[0] = False
  60. def restart_ordering(on_ambiguity=ambiguity_warn):
  61. _resolve[0] = True
  62. while _unresolved_dispatchers:
  63. dispatcher = _unresolved_dispatchers.pop()
  64. dispatcher.reorder(on_ambiguity=on_ambiguity)
  65. class Dispatcher:
  66. """ Dispatch methods based on type signature
  67. Use ``dispatch`` to add implementations
  68. Examples
  69. --------
  70. >>> from sympy.multipledispatch import dispatch
  71. >>> @dispatch(int)
  72. ... def f(x):
  73. ... return x + 1
  74. >>> @dispatch(float)
  75. ... def f(x): # noqa: F811
  76. ... return x - 1
  77. >>> f(3)
  78. 4
  79. >>> f(3.0)
  80. 2.0
  81. """
  82. __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'
  83. def __init__(self, name, doc=None):
  84. self.name = self.__name__ = name
  85. self.funcs = dict()
  86. self._cache = dict()
  87. self.ordering = []
  88. self.doc = doc
  89. def register(self, *types, **kwargs):
  90. """ Register dispatcher with new implementation
  91. >>> from sympy.multipledispatch.dispatcher import Dispatcher
  92. >>> f = Dispatcher('f')
  93. >>> @f.register(int)
  94. ... def inc(x):
  95. ... return x + 1
  96. >>> @f.register(float)
  97. ... def dec(x):
  98. ... return x - 1
  99. >>> @f.register(list)
  100. ... @f.register(tuple)
  101. ... def reverse(x):
  102. ... return x[::-1]
  103. >>> f(1)
  104. 2
  105. >>> f(1.0)
  106. 0.0
  107. >>> f([1, 2, 3])
  108. [3, 2, 1]
  109. """
  110. def _(func):
  111. self.add(types, func, **kwargs)
  112. return func
  113. return _
  114. @classmethod
  115. def get_func_params(cls, func):
  116. if hasattr(inspect, "signature"):
  117. sig = inspect.signature(func)
  118. return sig.parameters.values()
  119. @classmethod
  120. def get_func_annotations(cls, func):
  121. """ Get annotations of function positional parameters
  122. """
  123. params = cls.get_func_params(func)
  124. if params:
  125. Parameter = inspect.Parameter
  126. params = (param for param in params
  127. if param.kind in
  128. (Parameter.POSITIONAL_ONLY,
  129. Parameter.POSITIONAL_OR_KEYWORD))
  130. annotations = tuple(
  131. param.annotation
  132. for param in params)
  133. if not any(ann is Parameter.empty for ann in annotations):
  134. return annotations
  135. def add(self, signature, func, on_ambiguity=ambiguity_warn):
  136. """ Add new types/method pair to dispatcher
  137. >>> from sympy.multipledispatch import Dispatcher
  138. >>> D = Dispatcher('add')
  139. >>> D.add((int, int), lambda x, y: x + y)
  140. >>> D.add((float, float), lambda x, y: x + y)
  141. >>> D(1, 2)
  142. 3
  143. >>> D(1, 2.0)
  144. Traceback (most recent call last):
  145. ...
  146. NotImplementedError: Could not find signature for add: <int, float>
  147. When ``add`` detects a warning it calls the ``on_ambiguity`` callback
  148. with a dispatcher/itself, and a set of ambiguous type signature pairs
  149. as inputs. See ``ambiguity_warn`` for an example.
  150. """
  151. # Handle annotations
  152. if not signature:
  153. annotations = self.get_func_annotations(func)
  154. if annotations:
  155. signature = annotations
  156. # Handle union types
  157. if any(isinstance(typ, tuple) for typ in signature):
  158. for typs in expand_tuples(signature):
  159. self.add(typs, func, on_ambiguity)
  160. return
  161. for typ in signature:
  162. if not isinstance(typ, type):
  163. str_sig = ', '.join(c.__name__ if isinstance(c, type)
  164. else str(c) for c in signature)
  165. raise TypeError("Tried to dispatch on non-type: %s\n"
  166. "In signature: <%s>\n"
  167. "In function: %s" %
  168. (typ, str_sig, self.name))
  169. self.funcs[signature] = func
  170. self.reorder(on_ambiguity=on_ambiguity)
  171. self._cache.clear()
  172. def reorder(self, on_ambiguity=ambiguity_warn):
  173. if _resolve[0]:
  174. self.ordering = ordering(self.funcs)
  175. amb = ambiguities(self.funcs)
  176. if amb:
  177. on_ambiguity(self, amb)
  178. else:
  179. _unresolved_dispatchers.add(self)
  180. def __call__(self, *args, **kwargs):
  181. types = tuple([type(arg) for arg in args])
  182. try:
  183. func = self._cache[types]
  184. except KeyError:
  185. func = self.dispatch(*types)
  186. if not func:
  187. raise NotImplementedError(
  188. 'Could not find signature for %s: <%s>' %
  189. (self.name, str_signature(types)))
  190. self._cache[types] = func
  191. try:
  192. return func(*args, **kwargs)
  193. except MDNotImplementedError:
  194. funcs = self.dispatch_iter(*types)
  195. next(funcs) # burn first
  196. for func in funcs:
  197. try:
  198. return func(*args, **kwargs)
  199. except MDNotImplementedError:
  200. pass
  201. raise NotImplementedError("Matching functions for "
  202. "%s: <%s> found, but none completed successfully"
  203. % (self.name, str_signature(types)))
  204. def __str__(self):
  205. return "<dispatched %s>" % self.name
  206. __repr__ = __str__
  207. def dispatch(self, *types):
  208. """ Deterimine appropriate implementation for this type signature
  209. This method is internal. Users should call this object as a function.
  210. Implementation resolution occurs within the ``__call__`` method.
  211. >>> from sympy.multipledispatch import dispatch
  212. >>> @dispatch(int)
  213. ... def inc(x):
  214. ... return x + 1
  215. >>> implementation = inc.dispatch(int)
  216. >>> implementation(3)
  217. 4
  218. >>> print(inc.dispatch(float))
  219. None
  220. See Also:
  221. ``sympy.multipledispatch.conflict`` - module to determine resolution order
  222. """
  223. if types in self.funcs:
  224. return self.funcs[types]
  225. try:
  226. return next(self.dispatch_iter(*types))
  227. except StopIteration:
  228. return None
  229. def dispatch_iter(self, *types):
  230. n = len(types)
  231. for signature in self.ordering:
  232. if len(signature) == n and all(map(issubclass, types, signature)):
  233. result = self.funcs[signature]
  234. yield result
  235. def resolve(self, types):
  236. """ Deterimine appropriate implementation for this type signature
  237. .. deprecated:: 0.4.4
  238. Use ``dispatch(*types)`` instead
  239. """
  240. warn("resolve() is deprecated, use dispatch(*types)",
  241. DeprecationWarning)
  242. return self.dispatch(*types)
  243. def __getstate__(self):
  244. return {'name': self.name,
  245. 'funcs': self.funcs}
  246. def __setstate__(self, d):
  247. self.name = d['name']
  248. self.funcs = d['funcs']
  249. self.ordering = ordering(self.funcs)
  250. self._cache = dict()
  251. @property
  252. def __doc__(self):
  253. docs = ["Multiply dispatched method: %s" % self.name]
  254. if self.doc:
  255. docs.append(self.doc)
  256. other = []
  257. for sig in self.ordering[::-1]:
  258. func = self.funcs[sig]
  259. if func.__doc__:
  260. s = 'Inputs: <%s>\n' % str_signature(sig)
  261. s += '-' * len(s) + '\n'
  262. s += func.__doc__.strip()
  263. docs.append(s)
  264. else:
  265. other.append(str_signature(sig))
  266. if other:
  267. docs.append('Other signatures:\n ' + '\n '.join(other))
  268. return '\n\n'.join(docs)
  269. def _help(self, *args):
  270. return self.dispatch(*map(type, args)).__doc__
  271. def help(self, *args, **kwargs):
  272. """ Print docstring for the function corresponding to inputs """
  273. print(self._help(*args))
  274. def _source(self, *args):
  275. func = self.dispatch(*map(type, args))
  276. if not func:
  277. raise TypeError("No function found")
  278. return source(func)
  279. def source(self, *args, **kwargs):
  280. """ Print source code for the function corresponding to inputs """
  281. print(self._source(*args))
  282. def source(func):
  283. s = 'File: %s\n\n' % inspect.getsourcefile(func)
  284. s = s + inspect.getsource(func)
  285. return s
  286. class MethodDispatcher(Dispatcher):
  287. """ Dispatch methods based on type signature
  288. See Also:
  289. Dispatcher
  290. """
  291. @classmethod
  292. def get_func_params(cls, func):
  293. if hasattr(inspect, "signature"):
  294. sig = inspect.signature(func)
  295. return itl.islice(sig.parameters.values(), 1, None)
  296. def __get__(self, instance, owner):
  297. self.obj = instance
  298. self.cls = owner
  299. return self
  300. def __call__(self, *args, **kwargs):
  301. types = tuple([type(arg) for arg in args])
  302. func = self.dispatch(*types)
  303. if not func:
  304. raise NotImplementedError('Could not find signature for %s: <%s>' %
  305. (self.name, str_signature(types)))
  306. return func(self.obj, *args, **kwargs)
  307. def str_signature(sig):
  308. """ String representation of type signature
  309. >>> from sympy.multipledispatch.dispatcher import str_signature
  310. >>> str_signature((int, float))
  311. 'int, float'
  312. """
  313. return ', '.join(cls.__name__ for cls in sig)
  314. def warning_text(name, amb):
  315. """ The text for ambiguity warnings """
  316. text = "\nAmbiguities exist in dispatched function %s\n\n" % (name)
  317. text += "The following signatures may result in ambiguous behavior:\n"
  318. for pair in amb:
  319. text += "\t" + \
  320. ', '.join('[' + str_signature(s) + ']' for s in pair) + "\n"
  321. text += "\n\nConsider making the following additions:\n\n"
  322. text += '\n\n'.join(['@dispatch(' + str_signature(super_signature(s))
  323. + ')\ndef %s(...)' % name for s in amb])
  324. return text