aesaracode.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. from typing import Any, Dict as tDict
  2. from sympy.external import import_module
  3. from sympy.printing.printer import Printer
  4. from sympy.utilities.iterables import is_sequence
  5. import sympy
  6. from functools import partial
  7. aesara = import_module('aesara')
  8. if aesara:
  9. aes = aesara.scalar
  10. aet = aesara.tensor
  11. from aesara.tensor import nlinalg
  12. from aesara.tensor.elemwise import Elemwise
  13. from aesara.tensor.elemwise import DimShuffle
  14. mapping = {
  15. sympy.Add: aet.add,
  16. sympy.Mul: aet.mul,
  17. sympy.Abs: aet.abs_,
  18. sympy.sign: aet.sgn,
  19. sympy.ceiling: aet.ceil,
  20. sympy.floor: aet.floor,
  21. sympy.log: aet.log,
  22. sympy.exp: aet.exp,
  23. sympy.sqrt: aet.sqrt,
  24. sympy.cos: aet.cos,
  25. sympy.acos: aet.arccos,
  26. sympy.sin: aet.sin,
  27. sympy.asin: aet.arcsin,
  28. sympy.tan: aet.tan,
  29. sympy.atan: aet.arctan,
  30. sympy.atan2: aet.arctan2,
  31. sympy.cosh: aet.cosh,
  32. sympy.acosh: aet.arccosh,
  33. sympy.sinh: aet.sinh,
  34. sympy.asinh: aet.arcsinh,
  35. sympy.tanh: aet.tanh,
  36. sympy.atanh: aet.arctanh,
  37. sympy.re: aet.real,
  38. sympy.im: aet.imag,
  39. sympy.arg: aet.angle,
  40. sympy.erf: aet.erf,
  41. sympy.gamma: aet.gamma,
  42. sympy.loggamma: aet.gammaln,
  43. sympy.Pow: aet.pow,
  44. sympy.Eq: aet.eq,
  45. sympy.StrictGreaterThan: aet.gt,
  46. sympy.StrictLessThan: aet.lt,
  47. sympy.LessThan: aet.le,
  48. sympy.GreaterThan: aet.ge,
  49. sympy.And: aet.and_, # bitwise
  50. sympy.Or: aet.or_, # bitwise
  51. sympy.Not: aet.invert, # bitwise
  52. sympy.Xor: aet.xor, # bitwise
  53. sympy.Max: aet.maximum, # Sympy accept >2 inputs, Aesara only 2
  54. sympy.Min: aet.minimum, # Sympy accept >2 inputs, Aesara only 2
  55. sympy.conjugate: aet.conj,
  56. sympy.core.numbers.ImaginaryUnit: lambda:aet.complex(0,1),
  57. # Matrices
  58. sympy.MatAdd: Elemwise(aes.add),
  59. sympy.HadamardProduct: Elemwise(aes.mul),
  60. sympy.Trace: nlinalg.trace,
  61. sympy.Determinant : nlinalg.det,
  62. sympy.Inverse: nlinalg.matrix_inverse,
  63. sympy.Transpose: DimShuffle((False, False), [1, 0]),
  64. }
  65. class AesaraPrinter(Printer):
  66. """ Code printer which creates Aesara symbolic expression graphs.
  67. Parameters
  68. ==========
  69. cache : dict
  70. Cache dictionary to use. If None (default) will use
  71. the global cache. To create a printer which does not depend on or alter
  72. global state pass an empty dictionary. Note: the dictionary is not
  73. copied on initialization of the printer and will be updated in-place,
  74. so using the same dict object when creating multiple printers or making
  75. multiple calls to :func:`.aesara_code` or :func:`.aesara_function` means
  76. the cache is shared between all these applications.
  77. Attributes
  78. ==========
  79. cache : dict
  80. A cache of Aesara variables which have been created for SymPy
  81. symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
  82. :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
  83. ensure that all references to a given symbol in an expression (or
  84. multiple expressions) are printed as the same Aesara variable, which is
  85. created only once. Symbols are differentiated only by name and type. The
  86. format of the cache's contents should be considered opaque to the user.
  87. """
  88. printmethod = "_aesara"
  89. def __init__(self, *args, **kwargs):
  90. self.cache = kwargs.pop('cache', dict())
  91. super().__init__(*args, **kwargs)
  92. def _get_key(self, s, name=None, dtype=None, broadcastable=None):
  93. """ Get the cache key for a SymPy object.
  94. Parameters
  95. ==========
  96. s : sympy.core.basic.Basic
  97. SymPy object to get key for.
  98. name : str
  99. Name of object, if it does not have a ``name`` attribute.
  100. """
  101. if name is None:
  102. name = s.name
  103. return (name, type(s), s.args, dtype, broadcastable)
  104. def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
  105. """
  106. Get the Aesara variable for a SymPy symbol from the cache, or create it
  107. if it does not exist.
  108. """
  109. # Defaults
  110. if name is None:
  111. name = s.name
  112. if dtype is None:
  113. dtype = 'floatX'
  114. if broadcastable is None:
  115. broadcastable = ()
  116. key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
  117. if key in self.cache:
  118. return self.cache[key]
  119. value = aet.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
  120. self.cache[key] = value
  121. return value
  122. def _print_Symbol(self, s, **kwargs):
  123. dtype = kwargs.get('dtypes', {}).get(s)
  124. bc = kwargs.get('broadcastables', {}).get(s)
  125. return self._get_or_create(s, dtype=dtype, broadcastable=bc)
  126. def _print_AppliedUndef(self, s, **kwargs):
  127. name = str(type(s)) + '_' + str(s.args[0])
  128. dtype = kwargs.get('dtypes', {}).get(s)
  129. bc = kwargs.get('broadcastables', {}).get(s)
  130. return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
  131. def _print_Basic(self, expr, **kwargs):
  132. op = mapping[type(expr)]
  133. children = [self._print(arg, **kwargs) for arg in expr.args]
  134. return op(*children)
  135. def _print_Number(self, n, **kwargs):
  136. # Integers already taken care of below, interpret as float
  137. return float(n.evalf())
  138. def _print_MatrixSymbol(self, X, **kwargs):
  139. dtype = kwargs.get('dtypes', {}).get(X)
  140. return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
  141. def _print_DenseMatrix(self, X, **kwargs):
  142. if not hasattr(aet, 'stacklists'):
  143. raise NotImplementedError(
  144. "Matrix translation not yet supported in this version of Aesara")
  145. return aet.stacklists([
  146. [self._print(arg, **kwargs) for arg in L]
  147. for L in X.tolist()
  148. ])
  149. _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
  150. def _print_MatMul(self, expr, **kwargs):
  151. children = [self._print(arg, **kwargs) for arg in expr.args]
  152. result = children[0]
  153. for child in children[1:]:
  154. result = aet.dot(result, child)
  155. return result
  156. def _print_MatPow(self, expr, **kwargs):
  157. children = [self._print(arg, **kwargs) for arg in expr.args]
  158. result = 1
  159. if isinstance(children[1], int) and children[1] > 0:
  160. for i in range(children[1]):
  161. result = aet.dot(result, children[0])
  162. else:
  163. raise NotImplementedError('''Only non-negative integer
  164. powers of matrices can be handled by Aesara at the moment''')
  165. return result
  166. def _print_MatrixSlice(self, expr, **kwargs):
  167. parent = self._print(expr.parent, **kwargs)
  168. rowslice = self._print(slice(*expr.rowslice), **kwargs)
  169. colslice = self._print(slice(*expr.colslice), **kwargs)
  170. return parent[rowslice, colslice]
  171. def _print_BlockMatrix(self, expr, **kwargs):
  172. nrows, ncols = expr.blocks.shape
  173. blocks = [[self._print(expr.blocks[r, c], **kwargs)
  174. for c in range(ncols)]
  175. for r in range(nrows)]
  176. return aet.join(0, *[aet.join(1, *row) for row in blocks])
  177. def _print_slice(self, expr, **kwargs):
  178. return slice(*[self._print(i, **kwargs)
  179. if isinstance(i, sympy.Basic) else i
  180. for i in (expr.start, expr.stop, expr.step)])
  181. def _print_Pi(self, expr, **kwargs):
  182. return 3.141592653589793
  183. def _print_Piecewise(self, expr, **kwargs):
  184. import numpy as np
  185. e, cond = expr.args[0].args # First condition and corresponding value
  186. # Print conditional expression and value for first condition
  187. p_cond = self._print(cond, **kwargs)
  188. p_e = self._print(e, **kwargs)
  189. # One condition only
  190. if len(expr.args) == 1:
  191. # Return value if condition else NaN
  192. return aet.switch(p_cond, p_e, np.nan)
  193. # Return value_1 if condition_1 else evaluate remaining conditions
  194. p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
  195. return aet.switch(p_cond, p_e, p_remaining)
  196. def _print_Rational(self, expr, **kwargs):
  197. return aet.true_div(self._print(expr.p, **kwargs),
  198. self._print(expr.q, **kwargs))
  199. def _print_Integer(self, expr, **kwargs):
  200. return expr.p
  201. def _print_factorial(self, expr, **kwargs):
  202. return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
  203. def _print_Derivative(self, deriv, **kwargs):
  204. from aesara.gradient import Rop
  205. rv = self._print(deriv.expr, **kwargs)
  206. for var in deriv.variables:
  207. var = self._print(var, **kwargs)
  208. rv = Rop(rv, var, aet.ones_like(var))
  209. return rv
  210. def emptyPrinter(self, expr):
  211. return expr
  212. def doprint(self, expr, dtypes=None, broadcastables=None):
  213. """ Convert a SymPy expression to a Aesara graph variable.
  214. The ``dtypes`` and ``broadcastables`` arguments are used to specify the
  215. data type, dimension, and broadcasting behavior of the Aesara variables
  216. corresponding to the free symbols in ``expr``. Each is a mapping from
  217. SymPy symbols to the value of the corresponding argument to
  218. ``aesara.tensor.var.TensorVariable``.
  219. See the corresponding `documentation page`__ for more information on
  220. broadcasting in Aesara.
  221. .. __: https://aesara.readthedocs.io/en/latest/tutorial/broadcasting.html
  222. Parameters
  223. ==========
  224. expr : sympy.core.expr.Expr
  225. SymPy expression to print.
  226. dtypes : dict
  227. Mapping from SymPy symbols to Aesara datatypes to use when creating
  228. new Aesara variables for those symbols. Corresponds to the ``dtype``
  229. argument to ``aesara.tensor.var.TensorVariable``. Defaults to ``'floatX'``
  230. for symbols not included in the mapping.
  231. broadcastables : dict
  232. Mapping from SymPy symbols to the value of the ``broadcastable``
  233. argument to ``aesara.tensor.var.TensorVariable`` to use when creating Aesara
  234. variables for those symbols. Defaults to the empty tuple for symbols
  235. not included in the mapping (resulting in a scalar).
  236. Returns
  237. =======
  238. aesara.graph.basic.Variable
  239. A variable corresponding to the expression's value in a Aesara
  240. symbolic expression graph.
  241. """
  242. if dtypes is None:
  243. dtypes = {}
  244. if broadcastables is None:
  245. broadcastables = {}
  246. return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  247. global_cache = {} # type: tDict[Any, Any]
  248. def aesara_code(expr, cache=None, **kwargs):
  249. """
  250. Convert a SymPy expression into a Aesara graph variable.
  251. Parameters
  252. ==========
  253. expr : sympy.core.expr.Expr
  254. SymPy expression object to convert.
  255. cache : dict
  256. Cached Aesara variables (see :class:`AesaraPrinter.cache
  257. <AesaraPrinter>`). Defaults to the module-level global cache.
  258. dtypes : dict
  259. Passed to :meth:`.AesaraPrinter.doprint`.
  260. broadcastables : dict
  261. Passed to :meth:`.AesaraPrinter.doprint`.
  262. Returns
  263. =======
  264. aesara.graph.basic.Variable
  265. A variable corresponding to the expression's value in a Aesara symbolic
  266. expression graph.
  267. """
  268. if not aesara:
  269. raise ImportError("aesara is required for aesara_code")
  270. if cache is None:
  271. cache = global_cache
  272. return AesaraPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
  273. def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
  274. r"""
  275. Get value of ``broadcastables`` argument to :func:`.aesara_code` from
  276. keyword arguments to :func:`.aesara_function`.
  277. Included for backwards compatibility.
  278. Parameters
  279. ==========
  280. inputs
  281. Sequence of input symbols.
  282. dim : int
  283. Common number of dimensions for all inputs. Overrides other arguments
  284. if given.
  285. dims : dict
  286. Mapping from input symbols to number of dimensions. Overrides
  287. ``broadcastables`` argument if given.
  288. broadcastables : dict
  289. Explicit value of ``broadcastables`` argument to
  290. :meth:`.AesaraPrinter.doprint`. If not None function will return this value unchanged.
  291. Returns
  292. =======
  293. dict
  294. Dictionary mapping elements of ``inputs`` to their "broadcastable"
  295. values (tuple of ``bool``\ s).
  296. """
  297. if dim is not None:
  298. return {s: (False,) * dim for s in inputs}
  299. if dims is not None:
  300. maxdim = max(dims.values())
  301. return {
  302. s: (False,) * d + (True,) * (maxdim - d)
  303. for s, d in dims.items()
  304. }
  305. if broadcastables is not None:
  306. return broadcastables
  307. return {}
  308. def aesara_function(inputs, outputs, scalar=False, *,
  309. dim=None, dims=None, broadcastables=None, **kwargs):
  310. """
  311. Create a Aesara function from SymPy expressions.
  312. The inputs and outputs are converted to Aesara variables using
  313. :func:`.aesara_code` and then passed to ``aesara.function``.
  314. Parameters
  315. ==========
  316. inputs
  317. Sequence of symbols which constitute the inputs of the function.
  318. outputs
  319. Sequence of expressions which constitute the outputs(s) of the
  320. function. The free symbols of each expression must be a subset of
  321. ``inputs``.
  322. scalar : bool
  323. Convert 0-dimensional arrays in output to scalars. This will return a
  324. Python wrapper function around the Aesara function object.
  325. cache : dict
  326. Cached Aesara variables (see :class:`AesaraPrinter.cache
  327. <AesaraPrinter>`). Defaults to the module-level global cache.
  328. dtypes : dict
  329. Passed to :meth:`.AesaraPrinter.doprint`.
  330. broadcastables : dict
  331. Passed to :meth:`.AesaraPrinter.doprint`.
  332. dims : dict
  333. Alternative to ``broadcastables`` argument. Mapping from elements of
  334. ``inputs`` to integers indicating the dimension of their associated
  335. arrays/tensors. Overrides ``broadcastables`` argument if given.
  336. dim : int
  337. Another alternative to the ``broadcastables`` argument. Common number of
  338. dimensions to use for all arrays/tensors.
  339. ``aesara_function([x, y], [...], dim=2)`` is equivalent to using
  340. ``broadcastables={x: (False, False), y: (False, False)}``.
  341. Returns
  342. =======
  343. callable
  344. A callable object which takes values of ``inputs`` as positional
  345. arguments and returns an output array for each of the expressions
  346. in ``outputs``. If ``outputs`` is a single expression the function will
  347. return a Numpy array, if it is a list of multiple expressions the
  348. function will return a list of arrays. See description of the ``squeeze``
  349. argument above for the behavior when a single output is passed in a list.
  350. The returned object will either be an instance of
  351. ``aesara.compile.function.types.Function`` or a Python wrapper
  352. function around one. In both cases, the returned value will have a
  353. ``aesara_function`` attribute which points to the return value of
  354. ``aesara.function``.
  355. Examples
  356. ========
  357. >>> from sympy.abc import x, y, z
  358. >>> from sympy.printing.aesaracode import aesara_function
  359. A simple function with one input and one output:
  360. >>> f1 = aesara_function([x], [x**2 - 1], scalar=True)
  361. >>> f1(3)
  362. 8.0
  363. A function with multiple inputs and one output:
  364. >>> f2 = aesara_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
  365. >>> f2(3, 4, 2)
  366. 5.0
  367. A function with multiple inputs and multiple outputs:
  368. >>> f3 = aesara_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
  369. >>> f3(2, 3)
  370. [13.0, -5.0]
  371. See also
  372. ========
  373. dim_handling
  374. """
  375. if not aesara:
  376. raise ImportError("Aesara is required for aesara_function")
  377. # Pop off non-aesara keyword args
  378. cache = kwargs.pop('cache', {})
  379. dtypes = kwargs.pop('dtypes', {})
  380. broadcastables = dim_handling(
  381. inputs, dim=dim, dims=dims, broadcastables=broadcastables,
  382. )
  383. # Print inputs/outputs
  384. code = partial(aesara_code, cache=cache, dtypes=dtypes,
  385. broadcastables=broadcastables)
  386. tinputs = list(map(code, inputs))
  387. toutputs = list(map(code, outputs))
  388. #fix constant expressions as variables
  389. toutputs = [output if isinstance(output, aesara.graph.basic.Variable) else aet.as_tensor_variable(output) for output in toutputs]
  390. if len(toutputs) == 1:
  391. toutputs = toutputs[0]
  392. # Compile aesara func
  393. func = aesara.function(tinputs, toutputs, **kwargs)
  394. is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
  395. # No wrapper required
  396. if not scalar or not any(is_0d):
  397. func.aesara_function = func
  398. return func
  399. # Create wrapper to convert 0-dimensional outputs to scalars
  400. def wrapper(*args):
  401. out = func(*args)
  402. # out can be array(1.0) or [array(1.0), array(2.0)]
  403. if is_sequence(out):
  404. return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
  405. else:
  406. return out[()]
  407. wrapper.__wrapped__ = func
  408. wrapper.__doc__ = func.__doc__
  409. wrapper.aesara_function = func
  410. return wrapper