utils.py 3.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. from collections import OrderedDict
  def expand_tuples(L):
      """
      Expand the tuples in ``L`` into every combination of their members.

      Each tuple element of ``L`` contributes one of its members per result;
      any non-tuple element is used unchanged in every result. An empty ``L``
      yields a single empty combination.

      >>> from sympy.multipledispatch.utils import expand_tuples
      >>> expand_tuples([1, (2, 3)])
      [(1, 2), (1, 3)]
      >>> expand_tuples([1, 2])
      [(1, 2)]
      """
      # Build combinations from the last position back to the first. At each
      # step the already-built suffixes form the outer loop and the choices
      # for the current position form the inner loop; this fixes the order
      # of the returned list.
      combos = [()]
      for element in reversed(L):
          choices = element if isinstance(element, tuple) else (element,)
          combos = [(choice,) + suffix for suffix in combos for choice in choices]
      return combos
  18. # Taken from theano/theano/gof/sched.py
  19. # Avoids licensing issues because this was written by Matthew Rocklin
  20. def _toposort(edges):
  21. """ Topological sort algorithm by Kahn [1] - O(nodes + vertices)
  22. inputs:
  23. edges - a dict of the form {a: {b, c}} where b and c depend on a
  24. outputs:
  25. L - an ordered list of nodes that satisfy the dependencies of edges
  26. >>> from sympy.multipledispatch.utils import _toposort
  27. >>> _toposort({1: (2, 3), 2: (3, )})
  28. [1, 2, 3]
  29. Closely follows the wikipedia page [2]
  30. [1] Kahn, Arthur B. (1962), "Topological sorting of large networks",
  31. Communications of the ACM
  32. [2] https://en.wikipedia.org/wiki/Toposort#Algorithms
  33. """
  34. incoming_edges = reverse_dict(edges)
  35. incoming_edges = {k: set(val) for k, val in incoming_edges.items()}
  36. S = OrderedDict.fromkeys(v for v in edges if v not in incoming_edges)
  37. L = []
  38. while S:
  39. n, _ = S.popitem()
  40. L.append(n)
  41. for m in edges.get(n, ()):
  42. assert n in incoming_edges[m]
  43. incoming_edges[m].remove(n)
  44. if not incoming_edges[m]:
  45. S[m] = None
  46. if any(incoming_edges.get(v, None) for v in edges):
  47. raise ValueError("Input has cycles")
  48. return L
  def reverse_dict(d):
      """Reverse the direction of a dependence dict.

      Every element ``v`` of a collection ``d[k]`` becomes a key of the
      result, mapped to the tuple of all keys ``k`` whose collection
      contained it.

      >>> from sympy.multipledispatch.utils import reverse_dict
      >>> d = {'a': (1, 2), 'b': (2, 3), 'c':()}
      >>> reverse_dict(d)
      {1: ('a',), 2: ('a', 'b'), 3: ('b',)}

      :note: The output order follows the iteration order of ``d`` and of
          each of its value collections. A new key appears when it is first
          encountered, and each tuple lists the source keys in the order they
          were visited. An element repeated within one collection is recorded
          once per occurrence. Keys whose collection is empty (``'c'`` above)
          do not appear in the result.
      """
      # Accumulate into lists and convert to tuples once at the end.
      # Repeatedly concatenating tuples is quadratic in the number of sources.
      collected = {}
      for key in d:
          for val in d[key]:
              collected.setdefault(val, []).append(key)
      return {val: tuple(keys) for val, keys in collected.items()}
  64. # Taken from toolz
  65. # Avoids licensing issues because this version was authored by Matthew Rocklin
  def groupby(func, seq):
      """ Group the items of ``seq`` by the value ``func`` returns for each.

      Returns a plain dict mapping each key to the list of items that produced
      it. Items keep their original relative order within a group, and keys
      appear in the order they were first produced. ``func`` is called exactly
      once per item.

      >>> from sympy.multipledispatch.utils import groupby
      >>> names = ['Alice', 'Bob', 'Charlie', 'Dan', 'Edith', 'Frank']
      >>> groupby(len, names) # doctest: +SKIP
      {3: ['Bob', 'Dan'], 5: ['Alice', 'Edith', 'Frank'], 7: ['Charlie']}
      >>> iseven = lambda x: x % 2 == 0
      >>> groupby(iseven, [1, 2, 3, 4, 5, 6, 7, 8]) # doctest: +SKIP
      {False: [1, 3, 5, 7], True: [2, 4, 6, 8]}

      See Also:
          ``countby`` (from toolz)
      """
      groups = {}
      for element in seq:
          # setdefault creates the group's list the first time its key is seen.
          groups.setdefault(func(element), []).append(element)
      return groups