usympy.py 3.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124
  1. """ SymPy interface to Unification engine
  2. See sympy.unify for module level docstring
  3. See sympy.unify.core for algorithmic docstring """
  4. from sympy.core import Basic, Add, Mul, Pow
  5. from sympy.core.operations import AssocOp, LatticeOp
  6. from sympy.matrices import MatAdd, MatMul, MatrixExpr
  7. from sympy.sets.sets import Union, Intersection, FiniteSet
  8. from sympy.unify.core import Compound, Variable, CondVariable
  9. from sympy.unify import core
  # Ops that `construct` rebuilds by calling ``op(*args, evaluate=False)``;
  # checked first, before ``basic_new_legal``.
  eval_false_legal = [AssocOp, Pow, FiniteSet]
  # Ops that `construct` rebuilds with ``Basic.__new__(op, *args)``, i.e.
  # bypassing the op's own constructor.
  basic_new_legal = [MatrixExpr]
  # NOTE(review): not referenced anywhere in this module; presumably consumed
  # by other code (e.g. tests or callers) -- confirm before changing.
  illegal = [LatticeOp]
  def sympy_associative(op):
      """Return True if the class ``op`` is one of SymPy's associative operators.

      ``op`` must be a class. It is tested with ``issubclass`` against AssocOp,
      MatAdd, MatMul, Union, Intersection and FiniteSet.
      """
      associative_types = (AssocOp, MatAdd, MatMul, Union, Intersection,
                           FiniteSet)
      # issubclass with a tuple is true if ``op`` subclasses any member.
      return issubclass(op, associative_types)
  def sympy_commutative(op):
      """Return True if the class ``op`` is always commutative.

      ``op`` must be a class. It is tested with ``issubclass`` against Add,
      MatAdd, Union, Intersection and FiniteSet. Mul is deliberately absent:
      whether a Mul commutes depends on its arguments (see ``is_commutative``).
      """
      commutative_types = (Add, MatAdd, Union, Intersection, FiniteSet)
      return issubclass(op, commutative_types)
  def is_associative(x):
      """Return True if ``x`` is a Compound whose op is associative in SymPy.

      Anything that is not a Compound (atoms, Variables, raw values) is never
      considered associative.
      """
      if not isinstance(x, Compound):
          return False
      return sympy_associative(x.op)
  def is_commutative(x):
      """Return True if the arguments of Compound ``x`` may be reordered.

      A Compound is commutative when its op is one of the always-commutative
      SymPy operators (see ``sympy_commutative``), or when its op is a Mul
      subclass and every argument, rebuilt into a SymPy object with
      ``construct``, reports ``is_commutative``. Every other case -- including
      non-Compound terms -- returns False.
      """
      if not isinstance(x, Compound):
          return False
      if sympy_commutative(x.op):
          return True
      if issubclass(x.op, Mul):
          # A Mul only commutes if no factor is non-commutative. ``all`` treats
          # an ``is_commutative`` of None (unknown) as not commutative.
          return all(construct(arg).is_commutative for arg in x.args)
      # Explicit False rather than an implicit None for all other ops.
      return False
  def mk_matchtype(typ):
      """Build a predicate that matches instances of ``typ``.

      The returned predicate accepts either a direct instance of ``typ``, or a
      Compound whose ``op`` is a subclass of ``typ``.
      """
      def matchtype(x):
          if isinstance(x, typ):
              return True
          return isinstance(x, Compound) and issubclass(x.op, typ)
      return matchtype
  def deconstruct(s, variables=()):
      """Turn a SymPy object into a Compound.

      Any ``s`` found in ``variables`` (via ``in``) is wrapped as
      ``Variable(s)``. Existing Variables/CondVariables, non-Basic values and
      SymPy atoms are returned as-is. Every other Basic becomes a Compound of
      its class and its recursively deconstructed args.
      """
      if s in variables:
          return Variable(s)
      # ``s.is_Atom`` is only reached when ``s`` is a Basic (short-circuit).
      if (isinstance(s, (Variable, CondVariable))
              or not isinstance(s, Basic)
              or s.is_Atom):
          return s
      children = tuple(deconstruct(child, variables) for child in s.args)
      return Compound(s.__class__, children)
  def construct(t):
      """Turn a Compound into a SymPy object.

      Variables/CondVariables unwrap to their ``arg``; non-Compound values pass
      through unchanged. A Compound is rebuilt recursively: ops subclassing an
      entry of ``eval_false_legal`` are called with ``evaluate=False``; ops
      subclassing an entry of ``basic_new_legal`` are created via
      ``Basic.__new__`` (bypassing the op's own constructor); any other op is
      simply called with the rebuilt args.
      """
      if isinstance(t, (Variable, CondVariable)):
          return t.arg
      if not isinstance(t, Compound):
          return t
      op = t.op
      if any(issubclass(op, legal) for legal in eval_false_legal):
          rebuilt = [construct(arg) for arg in t.args]
          return op(*rebuilt, evaluate=False)
      if any(issubclass(op, legal) for legal in basic_new_legal):
          rebuilt = [construct(arg) for arg in t.args]
          return Basic.__new__(op, *rebuilt)
      return op(*[construct(arg) for arg in t.args])
  def rebuild(s):
      """Rebuild a SymPy expression by round-tripping it through a Compound.

      The expression is deconstructed (with no variables) and then
      reconstructed. This removes harm caused by Expr-Rules interactions.
      """
      compound = deconstruct(s)
      return construct(compound)
  def unify(x, y, s=None, variables=(), **kwargs):
      """ Structural unification of two expressions/patterns.

      Parameters
      ==========

      x, y : SymPy expressions or patterns to unify.
      s : dict, optional
          Initial substitution. Its keys and values are deconstructed with the
          same ``variables``; a new dict is built, so the caller's dict is not
          mutated.
      variables : iterable, optional
          Terms to treat as unification (wild) variables. Any iterable is
          accepted; it is materialized once, so one-shot iterators such as
          generators work.
      **kwargs
          Forwarded unchanged to ``sympy.unify.core.unify``.

      Yields
      ======

      dict
          The substitutions produced by ``sympy.unify.core.unify``, with keys
          and values converted back into SymPy objects.

      Examples
      ========

      >>> from sympy.unify.usympy import unify
      >>> from sympy import Basic, S
      >>> from sympy.abc import x, y, z, p, q
      >>> next(unify(Basic(S(1), S(2)), Basic(S(1), x), variables=[x]))
      {x: 2}
      >>> expr = 2*x + y + z
      >>> pattern = 2*p + q
      >>> next(unify(expr, pattern, {}, variables=(p, q)))
      {p: x, q: y + z}

      Unification supports commutative and associative matching

      >>> expr = x + y + z
      >>> pattern = p + q
      >>> len(list(unify(expr, pattern, {}, variables=(p, q))))
      12

      Symbols not indicated to be variables are treated as literal,
      else they are wild-like and match anything in a sub-expression.

      >>> expr = x*y*z + 3
      >>> pattern = x*y + 3
      >>> next(unify(expr, pattern, {}, variables=[x, y]))
      {x: y, y: x*z}

      The x and y of the pattern above were in a Mul and matched factors
      in the Mul of expr. Here, a single symbol matches an entire term:

      >>> expr = x*y + 3
      >>> pattern = p + 3
      >>> next(unify(expr, pattern, {}, variables=[p]))
      {p: x*y}
      """
      # ``deconstruct`` runs ``term in variables`` at every node. A generator
      # would be exhausted by the first such test, so materialize it once.
      variables = tuple(variables)

      def decons(term):
          return deconstruct(term, variables)

      initial = {decons(k): decons(v) for k, v in (s or {}).items()}
      ds = core.unify(decons(x), decons(y), initial,
                      is_associative=is_associative,
                      is_commutative=is_commutative,
                      **kwargs)
      for d in ds:
          yield {construct(k): construct(v) for k, v in d.items()}