epathtools.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356
  1. """Tools for manipulation of expressions using paths. """
  2. from sympy.core import Basic
  3. class EPath:
  4. r"""
  5. Manipulate expressions using paths.
  6. EPath grammar in EBNF notation::
  7. literal ::= /[A-Za-z_][A-Za-z_0-9]*/
  8. number ::= /-?\d+/
  9. type ::= literal
  10. attribute ::= literal "?"
  11. all ::= "*"
  12. slice ::= "[" number? (":" number? (":" number?)?)? "]"
  13. range ::= all | slice
  14. query ::= (type | attribute) ("|" (type | attribute))*
  15. selector ::= range | query range?
  16. path ::= "/" selector ("/" selector)*
  17. See the docstring of the epath() function.
  18. """
  19. __slots__ = ("_path", "_epath")
  20. def __new__(cls, path):
  21. """Construct new EPath. """
  22. if isinstance(path, EPath):
  23. return path
  24. if not path:
  25. raise ValueError("empty EPath")
  26. _path = path
  27. if path[0] == '/':
  28. path = path[1:]
  29. else:
  30. raise NotImplementedError("non-root EPath")
  31. epath = []
  32. for selector in path.split('/'):
  33. selector = selector.strip()
  34. if not selector:
  35. raise ValueError("empty selector")
  36. index = 0
  37. for c in selector:
  38. if c.isalnum() or c in ('_', '|', '?'):
  39. index += 1
  40. else:
  41. break
  42. attrs = []
  43. types = []
  44. if index:
  45. elements = selector[:index]
  46. selector = selector[index:]
  47. for element in elements.split('|'):
  48. element = element.strip()
  49. if not element:
  50. raise ValueError("empty element")
  51. if element.endswith('?'):
  52. attrs.append(element[:-1])
  53. else:
  54. types.append(element)
  55. span = None
  56. if selector == '*':
  57. pass
  58. else:
  59. if selector.startswith('['):
  60. try:
  61. i = selector.index(']')
  62. except ValueError:
  63. raise ValueError("expected ']', got EOL")
  64. _span, span = selector[1:i], []
  65. if ':' not in _span:
  66. span = int(_span)
  67. else:
  68. for elt in _span.split(':', 3):
  69. if not elt:
  70. span.append(None)
  71. else:
  72. span.append(int(elt))
  73. span = slice(*span)
  74. selector = selector[i + 1:]
  75. if selector:
  76. raise ValueError("trailing characters in selector")
  77. epath.append((attrs, types, span))
  78. obj = object.__new__(cls)
  79. obj._path = _path
  80. obj._epath = epath
  81. return obj
  82. def __repr__(self):
  83. return "%s(%r)" % (self.__class__.__name__, self._path)
  84. def _get_ordered_args(self, expr):
  85. """Sort ``expr.args`` using printing order. """
  86. if expr.is_Add:
  87. return expr.as_ordered_terms()
  88. elif expr.is_Mul:
  89. return expr.as_ordered_factors()
  90. else:
  91. return expr.args
  92. def _hasattrs(self, expr, attrs):
  93. """Check if ``expr`` has any of ``attrs``. """
  94. for attr in attrs:
  95. if not hasattr(expr, attr):
  96. return False
  97. return True
  98. def _hastypes(self, expr, types):
  99. """Check if ``expr`` is any of ``types``. """
  100. _types = [ cls.__name__ for cls in expr.__class__.mro() ]
  101. return bool(set(_types).intersection(types))
  102. def _has(self, expr, attrs, types):
  103. """Apply ``_hasattrs`` and ``_hastypes`` to ``expr``. """
  104. if not (attrs or types):
  105. return True
  106. if attrs and self._hasattrs(expr, attrs):
  107. return True
  108. if types and self._hastypes(expr, types):
  109. return True
  110. return False
  111. def apply(self, expr, func, args=None, kwargs=None):
  112. """
  113. Modify parts of an expression selected by a path.
  114. Examples
  115. ========
  116. >>> from sympy.simplify.epathtools import EPath
  117. >>> from sympy import sin, cos, E
  118. >>> from sympy.abc import x, y, z, t
  119. >>> path = EPath("/*/[0]/Symbol")
  120. >>> expr = [((x, 1), 2), ((3, y), z)]
  121. >>> path.apply(expr, lambda expr: expr**2)
  122. [((x**2, 1), 2), ((3, y**2), z)]
  123. >>> path = EPath("/*/*/Symbol")
  124. >>> expr = t + sin(x + 1) + cos(x + y + E)
  125. >>> path.apply(expr, lambda expr: 2*expr)
  126. t + sin(2*x + 1) + cos(2*x + 2*y + E)
  127. """
  128. def _apply(path, expr, func):
  129. if not path:
  130. return func(expr)
  131. else:
  132. selector, path = path[0], path[1:]
  133. attrs, types, span = selector
  134. if isinstance(expr, Basic):
  135. if not expr.is_Atom:
  136. args, basic = self._get_ordered_args(expr), True
  137. else:
  138. return expr
  139. elif hasattr(expr, '__iter__'):
  140. args, basic = expr, False
  141. else:
  142. return expr
  143. args = list(args)
  144. if span is not None:
  145. if isinstance(span, slice):
  146. indices = range(*span.indices(len(args)))
  147. else:
  148. indices = [span]
  149. else:
  150. indices = range(len(args))
  151. for i in indices:
  152. try:
  153. arg = args[i]
  154. except IndexError:
  155. continue
  156. if self._has(arg, attrs, types):
  157. args[i] = _apply(path, arg, func)
  158. if basic:
  159. return expr.func(*args)
  160. else:
  161. return expr.__class__(args)
  162. _args, _kwargs = args or (), kwargs or {}
  163. _func = lambda expr: func(expr, *_args, **_kwargs)
  164. return _apply(self._epath, expr, _func)
  165. def select(self, expr):
  166. """
  167. Retrieve parts of an expression selected by a path.
  168. Examples
  169. ========
  170. >>> from sympy.simplify.epathtools import EPath
  171. >>> from sympy import sin, cos, E
  172. >>> from sympy.abc import x, y, z, t
  173. >>> path = EPath("/*/[0]/Symbol")
  174. >>> expr = [((x, 1), 2), ((3, y), z)]
  175. >>> path.select(expr)
  176. [x, y]
  177. >>> path = EPath("/*/*/Symbol")
  178. >>> expr = t + sin(x + 1) + cos(x + y + E)
  179. >>> path.select(expr)
  180. [x, x, y]
  181. """
  182. result = []
  183. def _select(path, expr):
  184. if not path:
  185. result.append(expr)
  186. else:
  187. selector, path = path[0], path[1:]
  188. attrs, types, span = selector
  189. if isinstance(expr, Basic):
  190. args = self._get_ordered_args(expr)
  191. elif hasattr(expr, '__iter__'):
  192. args = expr
  193. else:
  194. return
  195. if span is not None:
  196. if isinstance(span, slice):
  197. args = args[span]
  198. else:
  199. try:
  200. args = [args[span]]
  201. except IndexError:
  202. return
  203. for arg in args:
  204. if self._has(arg, attrs, types):
  205. _select(path, arg)
  206. _select(self._epath, expr)
  207. return result
  208. def epath(path, expr=None, func=None, args=None, kwargs=None):
  209. r"""
  210. Manipulate parts of an expression selected by a path.
  211. Explanation
  212. ===========
  213. This function allows to manipulate large nested expressions in single
  214. line of code, utilizing techniques to those applied in XML processing
  215. standards (e.g. XPath).
  216. If ``func`` is ``None``, :func:`epath` retrieves elements selected by
  217. the ``path``. Otherwise it applies ``func`` to each matching element.
  218. Note that it is more efficient to create an EPath object and use the select
  219. and apply methods of that object, since this will compile the path string
  220. only once. This function should only be used as a convenient shortcut for
  221. interactive use.
  222. This is the supported syntax:
  223. * select all: ``/*``
  224. Equivalent of ``for arg in args:``.
  225. * select slice: ``/[0]`` or ``/[1:5]`` or ``/[1:5:2]``
  226. Supports standard Python's slice syntax.
  227. * select by type: ``/list`` or ``/list|tuple``
  228. Emulates ``isinstance()``.
  229. * select by attribute: ``/__iter__?``
  230. Emulates ``hasattr()``.
  231. Parameters
  232. ==========
  233. path : str | EPath
  234. A path as a string or a compiled EPath.
  235. expr : Basic | iterable
  236. An expression or a container of expressions.
  237. func : callable (optional)
  238. A callable that will be applied to matching parts.
  239. args : tuple (optional)
  240. Additional positional arguments to ``func``.
  241. kwargs : dict (optional)
  242. Additional keyword arguments to ``func``.
  243. Examples
  244. ========
  245. >>> from sympy.simplify.epathtools import epath
  246. >>> from sympy import sin, cos, E
  247. >>> from sympy.abc import x, y, z, t
  248. >>> path = "/*/[0]/Symbol"
  249. >>> expr = [((x, 1), 2), ((3, y), z)]
  250. >>> epath(path, expr)
  251. [x, y]
  252. >>> epath(path, expr, lambda expr: expr**2)
  253. [((x**2, 1), 2), ((3, y**2), z)]
  254. >>> path = "/*/*/Symbol"
  255. >>> expr = t + sin(x + 1) + cos(x + y + E)
  256. >>> epath(path, expr)
  257. [x, x, y]
  258. >>> epath(path, expr, lambda expr: 2*expr)
  259. t + sin(2*x + 1) + cos(2*x + 2*y + E)
  260. """
  261. _epath = EPath(path)
  262. if expr is None:
  263. return _epath
  264. if func is None:
  265. return _epath.select(expr)
  266. else:
  267. return _epath.apply(expr, func, args, kwargs)