theanocode.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
"""
.. deprecated:: 1.8

  ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  Aesara. Use ``sympy.printing.aesaracode`` instead. See
  :ref:`theanocode-deprecated` for more information.

"""
  7. from typing import Any, Dict as tDict
  8. from sympy.external import import_module
  9. from sympy.printing.printer import Printer
  10. from sympy.utilities.iterables import is_sequence
  11. import sympy
  12. from functools import partial
  13. from sympy.utilities.decorator import doctest_depends_on
  14. from sympy.utilities.exceptions import sympy_deprecation_warning
  15. theano = import_module('theano')
  16. if theano:
  17. ts = theano.scalar
  18. tt = theano.tensor
  19. from theano.sandbox import linalg as tlinalg
  20. mapping = {
  21. sympy.Add: tt.add,
  22. sympy.Mul: tt.mul,
  23. sympy.Abs: tt.abs_,
  24. sympy.sign: tt.sgn,
  25. sympy.ceiling: tt.ceil,
  26. sympy.floor: tt.floor,
  27. sympy.log: tt.log,
  28. sympy.exp: tt.exp,
  29. sympy.sqrt: tt.sqrt,
  30. sympy.cos: tt.cos,
  31. sympy.acos: tt.arccos,
  32. sympy.sin: tt.sin,
  33. sympy.asin: tt.arcsin,
  34. sympy.tan: tt.tan,
  35. sympy.atan: tt.arctan,
  36. sympy.atan2: tt.arctan2,
  37. sympy.cosh: tt.cosh,
  38. sympy.acosh: tt.arccosh,
  39. sympy.sinh: tt.sinh,
  40. sympy.asinh: tt.arcsinh,
  41. sympy.tanh: tt.tanh,
  42. sympy.atanh: tt.arctanh,
  43. sympy.re: tt.real,
  44. sympy.im: tt.imag,
  45. sympy.arg: tt.angle,
  46. sympy.erf: tt.erf,
  47. sympy.gamma: tt.gamma,
  48. sympy.loggamma: tt.gammaln,
  49. sympy.Pow: tt.pow,
  50. sympy.Eq: tt.eq,
  51. sympy.StrictGreaterThan: tt.gt,
  52. sympy.StrictLessThan: tt.lt,
  53. sympy.LessThan: tt.le,
  54. sympy.GreaterThan: tt.ge,
  55. sympy.And: tt.and_,
  56. sympy.Or: tt.or_,
  57. sympy.Max: tt.maximum, # SymPy accept >2 inputs, Theano only 2
  58. sympy.Min: tt.minimum, # SymPy accept >2 inputs, Theano only 2
  59. sympy.conjugate: tt.conj,
  60. sympy.core.numbers.ImaginaryUnit: lambda:tt.complex(0,1),
  61. # Matrices
  62. sympy.MatAdd: tt.Elemwise(ts.add),
  63. sympy.HadamardProduct: tt.Elemwise(ts.mul),
  64. sympy.Trace: tlinalg.trace,
  65. sympy.Determinant : tlinalg.det,
  66. sympy.Inverse: tlinalg.matrix_inverse,
  67. sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
  68. }
  69. class TheanoPrinter(Printer):
  70. """ Code printer which creates Theano symbolic expression graphs.
  71. Parameters
  72. ==========
  73. cache : dict
  74. Cache dictionary to use. If None (default) will use
  75. the global cache. To create a printer which does not depend on or alter
  76. global state pass an empty dictionary. Note: the dictionary is not
  77. copied on initialization of the printer and will be updated in-place,
  78. so using the same dict object when creating multiple printers or making
  79. multiple calls to :func:`.theano_code` or :func:`.theano_function` means
  80. the cache is shared between all these applications.
  81. Attributes
  82. ==========
  83. cache : dict
  84. A cache of Theano variables which have been created for SymPy
  85. symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
  86. :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
  87. ensure that all references to a given symbol in an expression (or
  88. multiple expressions) are printed as the same Theano variable, which is
  89. created only once. Symbols are differentiated only by name and type. The
  90. format of the cache's contents should be considered opaque to the user.
  91. """
  92. printmethod = "_theano"
  93. def __init__(self, *args, **kwargs):
  94. self.cache = kwargs.pop('cache', dict())
  95. super().__init__(*args, **kwargs)
  96. def _get_key(self, s, name=None, dtype=None, broadcastable=None):
  97. """ Get the cache key for a SymPy object.
  98. Parameters
  99. ==========
  100. s : sympy.core.basic.Basic
  101. SymPy object to get key for.
  102. name : str
  103. Name of object, if it does not have a ``name`` attribute.
  104. """
  105. if name is None:
  106. name = s.name
  107. return (name, type(s), s.args, dtype, broadcastable)
  108. def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
  109. """
  110. Get the Theano variable for a SymPy symbol from the cache, or create it
  111. if it does not exist.
  112. """
  113. # Defaults
  114. if name is None:
  115. name = s.name
  116. if dtype is None:
  117. dtype = 'floatX'
  118. if broadcastable is None:
  119. broadcastable = ()
  120. key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)
  121. if key in self.cache:
  122. return self.cache[key]
  123. value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
  124. self.cache[key] = value
  125. return value
  126. def _print_Symbol(self, s, **kwargs):
  127. dtype = kwargs.get('dtypes', {}).get(s)
  128. bc = kwargs.get('broadcastables', {}).get(s)
  129. return self._get_or_create(s, dtype=dtype, broadcastable=bc)
  130. def _print_AppliedUndef(self, s, **kwargs):
  131. name = str(type(s)) + '_' + str(s.args[0])
  132. dtype = kwargs.get('dtypes', {}).get(s)
  133. bc = kwargs.get('broadcastables', {}).get(s)
  134. return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)
  135. def _print_Basic(self, expr, **kwargs):
  136. op = mapping[type(expr)]
  137. children = [self._print(arg, **kwargs) for arg in expr.args]
  138. return op(*children)
  139. def _print_Number(self, n, **kwargs):
  140. # Integers already taken care of below, interpret as float
  141. return float(n.evalf())
  142. def _print_MatrixSymbol(self, X, **kwargs):
  143. dtype = kwargs.get('dtypes', {}).get(X)
  144. return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))
  145. def _print_DenseMatrix(self, X, **kwargs):
  146. if not hasattr(tt, 'stacklists'):
  147. raise NotImplementedError(
  148. "Matrix translation not yet supported in this version of Theano")
  149. return tt.stacklists([
  150. [self._print(arg, **kwargs) for arg in L]
  151. for L in X.tolist()
  152. ])
  153. _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix
  154. def _print_MatMul(self, expr, **kwargs):
  155. children = [self._print(arg, **kwargs) for arg in expr.args]
  156. result = children[0]
  157. for child in children[1:]:
  158. result = tt.dot(result, child)
  159. return result
  160. def _print_MatPow(self, expr, **kwargs):
  161. children = [self._print(arg, **kwargs) for arg in expr.args]
  162. result = 1
  163. if isinstance(children[1], int) and children[1] > 0:
  164. for i in range(children[1]):
  165. result = tt.dot(result, children[0])
  166. else:
  167. raise NotImplementedError('''Only non-negative integer
  168. powers of matrices can be handled by Theano at the moment''')
  169. return result
  170. def _print_MatrixSlice(self, expr, **kwargs):
  171. parent = self._print(expr.parent, **kwargs)
  172. rowslice = self._print(slice(*expr.rowslice), **kwargs)
  173. colslice = self._print(slice(*expr.colslice), **kwargs)
  174. return parent[rowslice, colslice]
  175. def _print_BlockMatrix(self, expr, **kwargs):
  176. nrows, ncols = expr.blocks.shape
  177. blocks = [[self._print(expr.blocks[r, c], **kwargs)
  178. for c in range(ncols)]
  179. for r in range(nrows)]
  180. return tt.join(0, *[tt.join(1, *row) for row in blocks])
  181. def _print_slice(self, expr, **kwargs):
  182. return slice(*[self._print(i, **kwargs)
  183. if isinstance(i, sympy.Basic) else i
  184. for i in (expr.start, expr.stop, expr.step)])
  185. def _print_Pi(self, expr, **kwargs):
  186. return 3.141592653589793
  187. def _print_Exp1(self, expr, **kwargs):
  188. return ts.exp(1)
  189. def _print_Piecewise(self, expr, **kwargs):
  190. import numpy as np
  191. e, cond = expr.args[0].args # First condition and corresponding value
  192. # Print conditional expression and value for first condition
  193. p_cond = self._print(cond, **kwargs)
  194. p_e = self._print(e, **kwargs)
  195. # One condition only
  196. if len(expr.args) == 1:
  197. # Return value if condition else NaN
  198. return tt.switch(p_cond, p_e, np.nan)
  199. # Return value_1 if condition_1 else evaluate remaining conditions
  200. p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
  201. return tt.switch(p_cond, p_e, p_remaining)
  202. def _print_Rational(self, expr, **kwargs):
  203. return tt.true_div(self._print(expr.p, **kwargs),
  204. self._print(expr.q, **kwargs))
  205. def _print_Integer(self, expr, **kwargs):
  206. return expr.p
  207. def _print_factorial(self, expr, **kwargs):
  208. return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)
  209. def _print_Derivative(self, deriv, **kwargs):
  210. rv = self._print(deriv.expr, **kwargs)
  211. for var in deriv.variables:
  212. var = self._print(var, **kwargs)
  213. rv = tt.Rop(rv, var, tt.ones_like(var))
  214. return rv
  215. def emptyPrinter(self, expr):
  216. return expr
  217. def doprint(self, expr, dtypes=None, broadcastables=None):
  218. """ Convert a SymPy expression to a Theano graph variable.
  219. The ``dtypes`` and ``broadcastables`` arguments are used to specify the
  220. data type, dimension, and broadcasting behavior of the Theano variables
  221. corresponding to the free symbols in ``expr``. Each is a mapping from
  222. SymPy symbols to the value of the corresponding argument to
  223. ``theano.tensor.Tensor``.
  224. See the corresponding `documentation page`__ for more information on
  225. broadcasting in Theano.
  226. .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html
  227. Parameters
  228. ==========
  229. expr : sympy.core.expr.Expr
  230. SymPy expression to print.
  231. dtypes : dict
  232. Mapping from SymPy symbols to Theano datatypes to use when creating
  233. new Theano variables for those symbols. Corresponds to the ``dtype``
  234. argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
  235. for symbols not included in the mapping.
  236. broadcastables : dict
  237. Mapping from SymPy symbols to the value of the ``broadcastable``
  238. argument to ``theano.tensor.Tensor`` to use when creating Theano
  239. variables for those symbols. Defaults to the empty tuple for symbols
  240. not included in the mapping (resulting in a scalar).
  241. Returns
  242. =======
  243. theano.gof.graph.Variable
  244. A variable corresponding to the expression's value in a Theano
  245. symbolic expression graph.
  246. """
  247. if dtypes is None:
  248. dtypes = {}
  249. if broadcastables is None:
  250. broadcastables = {}
  251. return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
  252. global_cache = {} # type: tDict[Any, Any]
  253. def theano_code(expr, cache=None, **kwargs):
  254. """
  255. Convert a SymPy expression into a Theano graph variable.
  256. .. deprecated:: 1.8
  257. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  258. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  259. :ref:`theanocode-deprecated` for more information.
  260. Parameters
  261. ==========
  262. expr : sympy.core.expr.Expr
  263. SymPy expression object to convert.
  264. cache : dict
  265. Cached Theano variables (see :class:`TheanoPrinter.cache
  266. <TheanoPrinter>`). Defaults to the module-level global cache.
  267. dtypes : dict
  268. Passed to :meth:`.TheanoPrinter.doprint`.
  269. broadcastables : dict
  270. Passed to :meth:`.TheanoPrinter.doprint`.
  271. Returns
  272. =======
  273. theano.gof.graph.Variable
  274. A variable corresponding to the expression's value in a Theano symbolic
  275. expression graph.
  276. """
  277. sympy_deprecation_warning(
  278. """
  279. sympy.printing.theanocode is deprecated. Theano has been renamed to
  280. Aesara. Use sympy.printing.aesaracode instead.""",
  281. deprecated_since_version="1.8",
  282. active_deprecations_target='theanocode-deprecated')
  283. if not theano:
  284. raise ImportError("theano is required for theano_code")
  285. if cache is None:
  286. cache = global_cache
  287. return TheanoPrinter(cache=cache, settings={}).doprint(expr, **kwargs)
  288. def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
  289. r"""
  290. Get value of ``broadcastables`` argument to :func:`.theano_code` from
  291. keyword arguments to :func:`.theano_function`.
  292. Included for backwards compatibility.
  293. Parameters
  294. ==========
  295. inputs
  296. Sequence of input symbols.
  297. dim : int
  298. Common number of dimensions for all inputs. Overrides other arguments
  299. if given.
  300. dims : dict
  301. Mapping from input symbols to number of dimensions. Overrides
  302. ``broadcastables`` argument if given.
  303. broadcastables : dict
  304. Explicit value of ``broadcastables`` argument to
  305. :meth:`.TheanoPrinter.doprint`. If not None function will return this value unchanged.
  306. Returns
  307. =======
  308. dict
  309. Dictionary mapping elements of ``inputs`` to their "broadcastable"
  310. values (tuple of ``bool``\ s).
  311. """
  312. if dim is not None:
  313. return {s: (False,) * dim for s in inputs}
  314. if dims is not None:
  315. maxdim = max(dims.values())
  316. return {
  317. s: (False,) * d + (True,) * (maxdim - d)
  318. for s, d in dims.items()
  319. }
  320. if broadcastables is not None:
  321. return broadcastables
  322. return {}
  323. @doctest_depends_on(modules=('theano',))
  324. def theano_function(inputs, outputs, scalar=False, *,
  325. dim=None, dims=None, broadcastables=None, **kwargs):
  326. """
  327. Create a Theano function from SymPy expressions.
  328. .. deprecated:: 1.8
  329. ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
  330. Aesara. Use ``sympy.printing.aesaracode`` instead. See
  331. :ref:`theanocode-deprecated` for more information.
  332. The inputs and outputs are converted to Theano variables using
  333. :func:`.theano_code` and then passed to ``theano.function``.
  334. Parameters
  335. ==========
  336. inputs
  337. Sequence of symbols which constitute the inputs of the function.
  338. outputs
  339. Sequence of expressions which constitute the outputs(s) of the
  340. function. The free symbols of each expression must be a subset of
  341. ``inputs``.
  342. scalar : bool
  343. Convert 0-dimensional arrays in output to scalars. This will return a
  344. Python wrapper function around the Theano function object.
  345. cache : dict
  346. Cached Theano variables (see :class:`TheanoPrinter.cache
  347. <TheanoPrinter>`). Defaults to the module-level global cache.
  348. dtypes : dict
  349. Passed to :meth:`.TheanoPrinter.doprint`.
  350. broadcastables : dict
  351. Passed to :meth:`.TheanoPrinter.doprint`.
  352. dims : dict
  353. Alternative to ``broadcastables`` argument. Mapping from elements of
  354. ``inputs`` to integers indicating the dimension of their associated
  355. arrays/tensors. Overrides ``broadcastables`` argument if given.
  356. dim : int
  357. Another alternative to the ``broadcastables`` argument. Common number of
  358. dimensions to use for all arrays/tensors.
  359. ``theano_function([x, y], [...], dim=2)`` is equivalent to using
  360. ``broadcastables={x: (False, False), y: (False, False)}``.
  361. Returns
  362. =======
  363. callable
  364. A callable object which takes values of ``inputs`` as positional
  365. arguments and returns an output array for each of the expressions
  366. in ``outputs``. If ``outputs`` is a single expression the function will
  367. return a Numpy array, if it is a list of multiple expressions the
  368. function will return a list of arrays. See description of the ``squeeze``
  369. argument above for the behavior when a single output is passed in a list.
  370. The returned object will either be an instance of
  371. ``theano.compile.function_module.Function`` or a Python wrapper
  372. function around one. In both cases, the returned value will have a
  373. ``theano_function`` attribute which points to the return value of
  374. ``theano.function``.
  375. Examples
  376. ========
  377. >>> from sympy.abc import x, y, z
  378. >>> from sympy.printing.theanocode import theano_function
  379. A simple function with one input and one output:
  380. >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
  381. >>> f1(3)
  382. 8.0
  383. A function with multiple inputs and one output:
  384. >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
  385. >>> f2(3, 4, 2)
  386. 5.0
  387. A function with multiple inputs and multiple outputs:
  388. >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
  389. >>> f3(2, 3)
  390. [13.0, -5.0]
  391. See also
  392. ========
  393. dim_handling
  394. """
  395. sympy_deprecation_warning(
  396. """
  397. sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""",
  398. deprecated_since_version="1.8",
  399. active_deprecations_target='theanocode-deprecated')
  400. if not theano:
  401. raise ImportError("theano is required for theano_function")
  402. # Pop off non-theano keyword args
  403. cache = kwargs.pop('cache', {})
  404. dtypes = kwargs.pop('dtypes', {})
  405. broadcastables = dim_handling(
  406. inputs, dim=dim, dims=dims, broadcastables=broadcastables,
  407. )
  408. # Print inputs/outputs
  409. code = partial(theano_code, cache=cache, dtypes=dtypes,
  410. broadcastables=broadcastables)
  411. tinputs = list(map(code, inputs))
  412. toutputs = list(map(code, outputs))
  413. #fix constant expressions as variables
  414. toutputs = [output if isinstance(output, theano.Variable) else tt.as_tensor_variable(output) for output in toutputs]
  415. if len(toutputs) == 1:
  416. toutputs = toutputs[0]
  417. # Compile theano func
  418. func = theano.function(tinputs, toutputs, **kwargs)
  419. is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]
  420. # No wrapper required
  421. if not scalar or not any(is_0d):
  422. func.theano_function = func
  423. return func
  424. # Create wrapper to convert 0-dimensional outputs to scalars
  425. def wrapper(*args):
  426. out = func(*args)
  427. # out can be array(1.0) or [array(1.0), array(2.0)]
  428. if is_sequence(out):
  429. return [o[()] if is_0d[i] else o for i, o in enumerate(out)]
  430. else:
  431. return out[()]
  432. wrapper.__wrapped__ = func
  433. wrapper.__doc__ = func.__doc__
  434. wrapper.theano_function = func
  435. return wrapper