123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413 |
- from typing import Set as tSet
- from warnings import warn
- import inspect
- from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
- from .utils import expand_tuples
- import itertools as itl
class MDNotImplementedError(NotImplementedError):
    """Signal from an implementation that the next candidate should be tried.

    ``Dispatcher.__call__`` catches this exception and falls through to the
    next matching signature instead of propagating it immediately.
    """
- ### Functions for on_ambiguity
def ambiguity_warn(dispatcher, ambiguities):
    """Emit an ``AmbiguityWarning`` describing ambiguous signatures.

    This is the default ``on_ambiguity`` handler of ``Dispatcher.add`` and
    ``Dispatcher.reorder``.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher in which the ambiguity was found; only its ``name`` is
        read, to label the warning.
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``.

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
class RaiseNotImplementedError:
    """Placeholder implementation that raises ``NotImplementedError``.

    ``ambiguity_register_error_ignore_dup`` registers instances of this class
    for ambiguous signatures, so calling such a signature fails loudly.
    """

    def __init__(self, dispatcher):
        # Stored only so the error message can name the dispatcher.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types))
        raise NotImplementedError(message)
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """Register a raising placeholder for each ambiguous super signature.

    For every ambiguous pair, the super signature is computed.  If it
    consists of a single distinct type (e.g. ``(A, A)``) it is skipped;
    otherwise a ``RaiseNotImplementedError`` is added for it, with this same
    function passed on as the ambiguity handler.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected.
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher.

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        super_sig = tuple(super_signature(pair))
        # Only one distinct type in the super signature -> nothing to register.
        if len(set(super_sig)) != 1:
            placeholder = RaiseNotImplementedError(dispatcher)
            dispatcher.add(
                super_sig, placeholder,
                on_ambiguity=ambiguity_register_error_ignore_dup,
            )
- ###
# Dispatchers whose ``reorder`` was deferred while ordering is halted;
# ``restart_ordering`` drains this set.
_unresolved_dispatchers: "tSet[Dispatcher]" = set()
# One-element list used as a mutable module-level flag: ``reorder`` only
# recomputes orderings while ``_resolve[0]`` is True.
_resolve = [True]
def halt_ordering():
    """Suspend signature reordering until ``restart_ordering`` is called.

    While halted, ``Dispatcher.reorder`` records the dispatcher in
    ``_unresolved_dispatchers`` instead of recomputing its ordering.
    """
    flag = _resolve
    flag[0] = False
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Resume signature reordering and reorder every deferred dispatcher.

    Parameters
    ----------
    on_ambiguity : callable
        Handler forwarded to each deferred dispatcher's ``reorder``;
        defaults to ``ambiguity_warn``.
    """
    _resolve[0] = True
    pending = _unresolved_dispatchers
    while pending:
        pending.pop().reorder(on_ambiguity=on_ambiguity)
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        # Maps a signature (tuple of types) to its implementation.
        self.funcs = dict()
        # Maps the exact argument-type tuple of a call to the resolved func.
        self._cache = dict()
        # Signatures in resolution order, as computed by ``ordering``.
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func`` from ``inspect.signature``."""
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple of the annotations of the positional parameters, or
        ``None`` when any of them is unannotated.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs.  See ``ambiguity_warn`` for an example.
        """
        # Handle annotations: an empty signature falls back to the
        # positional-parameter annotations of ``func`` when all are present.
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types: a tuple entry expands into one registration
        # per combination.
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        # Resolutions cached before this registration may now be stale.
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """Recompute ``self.ordering`` and report any ambiguities.

        While ordering is halted (see ``halt_ordering``), the dispatcher is
        recorded in ``_unresolved_dispatchers`` instead.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        """Dispatch on the runtime types of ``args`` and call the match.

        Raises ``NotImplementedError`` when no signature matches, or when
        every matching implementation raises ``MDNotImplementedError``.
        """
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first: it is the implementation that just failed
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass

            raise NotImplementedError(
                "Matching functions for "
                "%s: <%s> found, but none completed successfully"
                % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """

        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """Yield every implementation whose signature accepts ``types``.

        Candidates are visited in ``self.ordering`` order; a signature
        matches when it has the same length and each type is a subclass of
        the corresponding signature entry.
        """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning)

        return self.dispatch(*types)

    def __getstate__(self):
        # ``doc`` is included so that it survives a pickle round-trip.
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        # Every slot must be restored here: the class uses ``__slots__``,
        # so a slot that is not assigned raises AttributeError on access
        # (``__doc__`` reads ``self.doc``; callers may read ``__name__``).
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        # ``.get`` keeps states pickled without a 'doc' entry loadable.
        self.doc = d.get('doc')
        self.ordering = ordering(self.funcs)
        self._cache = dict()

    @property
    def __doc__(self):
        # Assemble documentation from the dispatcher's own ``doc`` and the
        # docstrings of the registered implementations (most general first).
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
def source(func):
    """Return the source code of ``func``, preceded by its file name."""
    filename = inspect.getsourcefile(func)
    body = inspect.getsource(func)
    return 'File: %s\n\n%s' % (filename, body)
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Differs from ``Dispatcher`` in two ways:

    - the first parameter of a registered method is skipped when reading
      annotations;
    - the instance the dispatcher was accessed through is passed as the
      first argument on call.

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        # Skip the leading parameter (the method's receiver) so that only
        # the dispatched arguments are considered.
        if not hasattr(inspect, "signature"):
            return None
        parameters = inspect.signature(func).parameters.values()
        return itl.islice(parameters, 1, None)

    def __get__(self, instance, owner):
        # NOTE(review): the instance is stored on the dispatcher object
        # itself, so interleaved accesses through different instances may
        # overwrite each other's ``obj`` before ``__call__`` runs.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(type(value) for value in args)
        impl = self.dispatch(*arg_types)
        if not impl:
            raise NotImplementedError('Could not find signature for %s: <%s>' %
                                      (self.name, str_signature(arg_types)))
        return impl(self.obj, *args, **kwargs)
def str_signature(sig):
    """ Render a type signature as comma-separated class names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
def warning_text(name, amb):
    """ Build the message used for ambiguity warnings

    Parameters
    ----------
    name : str
        Dispatcher name; used in the header and in the suggested ``def``
        lines.
    amb : iterable
        Ambiguous signature pairs.  Each pair is listed, and the
        ``super_signature`` of each pair is suggested as an extra
        ``@dispatch`` registration.
    """
    pieces = ["\nAmbiguities exist in dispatched function %s\n\n" % (name),
              "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        listed = ', '.join('[' + str_signature(sig) + ']' for sig in pair)
        pieces.append("\t" + listed + "\n")
    pieces.append("\n\nConsider making the following additions:\n\n")
    suggestions = ['@dispatch(' + str_signature(super_signature(pair))
                   + ')\ndef %s(...)' % name for pair in amb]
    pieces.append('\n\n'.join(suggestions))
    return ''.join(pieces)
|