123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564 |
- """
- .. deprecated:: 1.8
- ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
- Aesara. Use ``sympy.printing.aesaracode`` instead. See
- :ref:`theanocode-deprecated` for more information.
- """
- from typing import Any, Dict as tDict
- from sympy.external import import_module
- from sympy.printing.printer import Printer
- from sympy.utilities.iterables import is_sequence
- import sympy
- from functools import partial
- from sympy.utilities.decorator import doctest_depends_on
- from sympy.utilities.exceptions import sympy_deprecation_warning
# Theano is an optional dependency: ``import_module`` yields a falsy value when
# it cannot be imported, in which case none of the names below are defined.
theano = import_module('theano')

if theano:
    ts, tt = theano.scalar, theano.tensor
    from theano.sandbox import linalg as tlinalg

    # Translation table used by ``TheanoPrinter._print_Basic``: it is indexed
    # by the exact SymPy type of a node and yields a callable that is applied
    # to the already-printed arguments of that node.
    mapping = {
        # Arithmetic and rounding
        sympy.Add: tt.add,
        sympy.Mul: tt.mul,
        sympy.Abs: tt.abs_,
        sympy.sign: tt.sgn,
        sympy.ceiling: tt.ceil,
        sympy.floor: tt.floor,
        # Exponentials, logarithms and roots
        sympy.log: tt.log,
        sympy.exp: tt.exp,
        sympy.sqrt: tt.sqrt,
        # Trigonometric functions and their inverses
        sympy.cos: tt.cos,
        sympy.acos: tt.arccos,
        sympy.sin: tt.sin,
        sympy.asin: tt.arcsin,
        sympy.tan: tt.tan,
        sympy.atan: tt.arctan,
        sympy.atan2: tt.arctan2,
        # Hyperbolic functions and their inverses
        sympy.cosh: tt.cosh,
        sympy.acosh: tt.arccosh,
        sympy.sinh: tt.sinh,
        sympy.asinh: tt.arcsinh,
        sympy.tanh: tt.tanh,
        sympy.atanh: tt.arctanh,
        # Complex parts
        sympy.re: tt.real,
        sympy.im: tt.imag,
        sympy.arg: tt.angle,
        # Special functions
        sympy.erf: tt.erf,
        sympy.gamma: tt.gamma,
        sympy.loggamma: tt.gammaln,
        sympy.Pow: tt.pow,
        # Relations and boolean connectives
        sympy.Eq: tt.eq,
        sympy.StrictGreaterThan: tt.gt,
        sympy.StrictLessThan: tt.lt,
        sympy.LessThan: tt.le,
        sympy.GreaterThan: tt.ge,
        sympy.And: tt.and_,
        sympy.Or: tt.or_,
        sympy.Max: tt.maximum,  # SymPy accepts >2 inputs, Theano only 2
        sympy.Min: tt.minimum,  # SymPy accepts >2 inputs, Theano only 2
        sympy.conjugate: tt.conj,
        # The imaginary unit has no arguments, hence the zero-argument lambda.
        sympy.core.numbers.ImaginaryUnit: lambda: tt.complex(0, 1),
        # Matrices
        sympy.MatAdd: tt.Elemwise(ts.add),
        sympy.HadamardProduct: tt.Elemwise(ts.mul),
        sympy.Trace: tlinalg.trace,
        sympy.Determinant: tlinalg.det,
        sympy.Inverse: tlinalg.matrix_inverse,
        sympy.Transpose: tt.DimShuffle((False, False), [1, 0]),
    }
class TheanoPrinter(Printer):
    """ Code printer which creates Theano symbolic expression graphs.

    Parameters
    ==========

    cache : dict
        Cache dictionary to use. If omitted or None, a new empty dictionary
        private to this printer is used, so the printer neither depends on nor
        alters global state. Note: a dictionary that is passed in is not
        copied on initialization of the printer and will be updated in-place,
        so using the same dict object when creating multiple printers or making
        multiple calls to :func:`.theano_code` or :func:`.theano_function` means
        the cache is shared between all these applications.

    Attributes
    ==========

    cache : dict
        A cache of Theano variables which have been created for SymPy
        symbol-like objects (e.g. :class:`sympy.core.symbol.Symbol` or
        :class:`sympy.matrices.expressions.MatrixSymbol`). This is used to
        ensure that all references to a given symbol in an expression (or
        multiple expressions) are printed as the same Theano variable, which is
        created only once. Cache keys combine the name, type and arguments of
        the SymPy object with the requested dtype and broadcastable pattern.
        The format of the cache's contents should be considered opaque to the
        user.
    """
    printmethod = "_theano"

    def __init__(self, *args, **kwargs):
        # ``cache`` is consumed here; everything else goes to ``Printer``.
        # An explicit ``cache=None`` is treated like an omitted argument
        # instead of being stored as None (which would fail on first lookup).
        cache = kwargs.pop('cache', None)
        self.cache = {} if cache is None else cache
        super().__init__(*args, **kwargs)

    def _get_key(self, s, name=None, dtype=None, broadcastable=None):
        """ Get the cache key for a SymPy object.

        Parameters
        ==========

        s : sympy.core.basic.Basic
            SymPy object to get key for.

        name : str
            Name of object, if it does not have a ``name`` attribute.

        dtype
            Data type requested for the Theano variable; part of the key.

        broadcastable
            Broadcastable pattern requested for the Theano variable; part of
            the key.

        Returns
        =======

        tuple
            ``(name, type(s), s.args, dtype, broadcastable)``.
        """
        if name is None:
            name = s.name

        return (name, type(s), s.args, dtype, broadcastable)

    def _get_or_create(self, s, name=None, dtype=None, broadcastable=None):
        """
        Get the Theano variable for a SymPy symbol from the cache, or create it
        if it does not exist.

        ``name`` defaults to ``s.name``, ``dtype`` to ``'floatX'`` and
        ``broadcastable`` to the empty tuple. The defaults are applied before
        the cache key is computed.
        """
        # Defaults
        if name is None:
            name = s.name
        if dtype is None:
            dtype = 'floatX'
        if broadcastable is None:
            broadcastable = ()

        key = self._get_key(s, name, dtype=dtype, broadcastable=broadcastable)

        if key in self.cache:
            return self.cache[key]

        value = tt.tensor(name=name, dtype=dtype, broadcastable=broadcastable)
        self.cache[key] = value
        return value

    def _print_Symbol(self, s, **kwargs):
        # Per-symbol dtype/broadcastable overrides arrive through ``doprint``.
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, dtype=dtype, broadcastable=bc)

    def _print_AppliedUndef(self, s, **kwargs):
        # The variable is named "<function>_<first argument>". An applied
        # function without arguments has no first argument, so it is named
        # after the function alone instead of raising IndexError.
        name = str(type(s))
        if s.args:
            name += '_' + str(s.args[0])
        dtype = kwargs.get('dtypes', {}).get(s)
        bc = kwargs.get('broadcastables', {}).get(s)
        return self._get_or_create(s, name=name, dtype=dtype, broadcastable=bc)

    def _print_Basic(self, expr, **kwargs):
        # Generic fallback: look the exact type up in ``mapping`` (KeyError if
        # it is not listed) and apply the result to the printed arguments.
        op = mapping[type(expr)]
        children = [self._print(arg, **kwargs) for arg in expr.args]
        return op(*children)

    def _print_Number(self, n, **kwargs):
        # Integers already taken care of below, interpret as float
        return float(n.evalf())

    def _print_MatrixSymbol(self, X, **kwargs):
        dtype = kwargs.get('dtypes', {}).get(X)
        return self._get_or_create(X, dtype=dtype, broadcastable=(None, None))

    def _print_DenseMatrix(self, X, **kwargs):
        if not hasattr(tt, 'stacklists'):
            raise NotImplementedError(
                "Matrix translation not yet supported in this version of Theano")

        # Print every entry and stack the nested row lists.
        return tt.stacklists([
            [self._print(arg, **kwargs) for arg in L]
            for L in X.tolist()
        ])

    _print_ImmutableMatrix = _print_ImmutableDenseMatrix = _print_DenseMatrix

    def _print_MatMul(self, expr, **kwargs):
        # Fold the printed factors from left to right with ``tt.dot``.
        children = [self._print(arg, **kwargs) for arg in expr.args]
        result = children[0]
        for child in children[1:]:
            result = tt.dot(result, child)
        return result

    def _print_MatPow(self, expr, **kwargs):
        children = [self._print(arg, **kwargs) for arg in expr.args]
        base, exponent = children[0], children[1]

        # Only exponents that print as a Python int greater than zero are
        # accepted; the message now matches that condition.
        if not (isinstance(exponent, int) and exponent > 0):
            raise NotImplementedError(
                "Only positive integer powers of matrices can be handled "
                "by Theano at the moment")

        # Repeated ``tt.dot``, seeded with the scalar 1.
        result = 1
        for _ in range(exponent):
            result = tt.dot(result, base)
        return result

    def _print_MatrixSlice(self, expr, **kwargs):
        parent = self._print(expr.parent, **kwargs)
        rowslice = self._print(slice(*expr.rowslice), **kwargs)
        colslice = self._print(slice(*expr.colslice), **kwargs)
        return parent[rowslice, colslice]

    def _print_BlockMatrix(self, expr, **kwargs):
        nrows, ncols = expr.blocks.shape
        blocks = [[self._print(expr.blocks[r, c], **kwargs)
                        for c in range(ncols)]
                        for r in range(nrows)]
        # Join each row of blocks along axis 1, then the rows along axis 0.
        return tt.join(0, *[tt.join(1, *row) for row in blocks])

    def _print_slice(self, expr, **kwargs):
        # Only SymPy objects are printed; plain Python values (including
        # None) are passed through unchanged.
        return slice(*[self._print(i, **kwargs)
                        if isinstance(i, sympy.Basic) else i
                        for i in (expr.start, expr.stop, expr.step)])

    def _print_Pi(self, expr, **kwargs):
        return 3.141592653589793

    def _print_Exp1(self, expr, **kwargs):
        return ts.exp(1)

    def _print_Piecewise(self, expr, **kwargs):
        import numpy as np
        e, cond = expr.args[0].args  # First condition and corresponding value

        # Print conditional expression and value for first condition
        p_cond = self._print(cond, **kwargs)
        p_e = self._print(e, **kwargs)

        # One condition only
        if len(expr.args) == 1:
            # Return value if condition else NaN
            return tt.switch(p_cond, p_e, np.nan)

        # Return value_1 if condition_1 else evaluate remaining conditions
        p_remaining = self._print(sympy.Piecewise(*expr.args[1:]), **kwargs)
        return tt.switch(p_cond, p_e, p_remaining)

    def _print_Rational(self, expr, **kwargs):
        return tt.true_div(self._print(expr.p, **kwargs),
                           self._print(expr.q, **kwargs))

    def _print_Integer(self, expr, **kwargs):
        return expr.p

    def _print_factorial(self, expr, **kwargs):
        # Printed through the gamma function: n! == gamma(n + 1).
        return self._print(sympy.gamma(expr.args[0] + 1), **kwargs)

    def _print_Derivative(self, deriv, **kwargs):
        rv = self._print(deriv.expr, **kwargs)
        # Apply ``tt.Rop`` once per differentiation variable, in order.
        for var in deriv.variables:
            var = self._print(var, **kwargs)
            rv = tt.Rop(rv, var, tt.ones_like(var))
        return rv

    def emptyPrinter(self, expr):
        # Objects without a dedicated print method are returned unchanged.
        return expr

    def doprint(self, expr, dtypes=None, broadcastables=None):
        """ Convert a SymPy expression to a Theano graph variable.

        The ``dtypes`` and ``broadcastables`` arguments are used to specify the
        data type, dimension, and broadcasting behavior of the Theano variables
        corresponding to the free symbols in ``expr``. Each is a mapping from
        SymPy symbols to the value of the corresponding argument to
        ``theano.tensor.Tensor``.

        See the corresponding `documentation page`__ for more information on
        broadcasting in Theano.

        .. __: http://deeplearning.net/software/theano/tutorial/broadcasting.html

        Parameters
        ==========

        expr : sympy.core.expr.Expr
            SymPy expression to print.

        dtypes : dict
            Mapping from SymPy symbols to Theano datatypes to use when creating
            new Theano variables for those symbols. Corresponds to the ``dtype``
            argument to ``theano.tensor.Tensor``. Defaults to ``'floatX'``
            for symbols not included in the mapping.

        broadcastables : dict
            Mapping from SymPy symbols to the value of the ``broadcastable``
            argument to ``theano.tensor.Tensor`` to use when creating Theano
            variables for those symbols. Defaults to the empty tuple for symbols
            not included in the mapping (resulting in a scalar).

        Returns
        =======

        theano.gof.graph.Variable
            A variable corresponding to the expression's value in a Theano
            symbolic expression graph.

        """
        if dtypes is None:
            dtypes = {}
        if broadcastables is None:
            broadcastables = {}

        return self._print(expr, dtypes=dtypes, broadcastables=broadcastables)
# Cache used by ``theano_code`` whenever the caller does not supply one.
global_cache: tDict[Any, Any] = {}
def theano_code(expr, cache=None, **kwargs):
    """
    Translate a SymPy expression into a Theano graph variable.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    Parameters
    ==========

    expr : sympy.core.expr.Expr
        SymPy expression object to convert.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). When None, the module-level global cache is used.

    dtypes : dict
        Forwarded to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Forwarded to :meth:`.TheanoPrinter.doprint`.

    Returns
    =======

    theano.gof.graph.Variable
        A variable corresponding to the expression's value in a Theano symbolic
        expression graph.

    Raises
    ======

    ImportError
        If Theano could not be imported.

    """
    # The deprecation warning is issued first, before checking for Theano.
    sympy_deprecation_warning(
        """
        sympy.printing.theanocode is deprecated. Theano has been renamed to
        Aesara. Use sympy.printing.aesaracode instead.""",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_code")

    printer = TheanoPrinter(
        cache=global_cache if cache is None else cache,
        settings={},
    )
    return printer.doprint(expr, **kwargs)
def dim_handling(inputs, dim=None, dims=None, broadcastables=None):
    r"""
    Get value of ``broadcastables`` argument to :func:`.theano_code` from
    keyword arguments to :func:`.theano_function`.

    Included for backwards compatibility.

    Parameters
    ==========

    inputs
        Sequence of input symbols. Only consulted when ``dim`` is given.

    dim : int
        Common number of dimensions for all inputs. Overrides other arguments
        if given.

    dims : dict
        Mapping from input symbols to number of dimensions. Overrides
        ``broadcastables`` argument if given. Each symbol is padded with
        broadcastable (``True``) dimensions up to the largest number of
        dimensions in the mapping. An empty mapping yields an empty result.

    broadcastables : dict
        Explicit value of ``broadcastables`` argument to
        :meth:`.TheanoPrinter.doprint`. If not None function will return this
        value unchanged.

    Returns
    =======
    dict
        Dictionary mapping elements of ``inputs`` to their "broadcastable"
        values (tuple of ``bool``\ s). Empty if none of ``dim``, ``dims`` and
        ``broadcastables`` is given.
    """
    # Precedence: dim, then dims, then broadcastables.
    if dim is not None:
        return {s: (False,) * dim for s in inputs}

    if dims is not None:
        # ``default=0`` keeps an empty mapping from raising ValueError; the
        # comprehension below then simply produces an empty dict.
        maxdim = max(dims.values(), default=0)
        return {
            s: (False,) * d + (True,) * (maxdim - d)
            for s, d in dims.items()
        }

    if broadcastables is not None:
        return broadcastables

    return {}
@doctest_depends_on(modules=('theano',))
def theano_function(inputs, outputs, scalar=False, *,
        dim=None, dims=None, broadcastables=None, **kwargs):
    """
    Create a Theano function from SymPy expressions.

    .. deprecated:: 1.8

      ``sympy.printing.theanocode`` is deprecated. Theano has been renamed to
      Aesara. Use ``sympy.printing.aesaracode`` instead. See
      :ref:`theanocode-deprecated` for more information.

    The inputs and outputs are converted to Theano variables using
    :func:`.theano_code` and then passed to ``theano.function``.

    Parameters
    ==========

    inputs
        Sequence of symbols which constitute the inputs of the function.

    outputs
        Sequence of expressions which constitute the outputs(s) of the
        function. The free symbols of each expression must be a subset of
        ``inputs``.

    scalar : bool
        Convert 0-dimensional arrays in output to scalars. This will return a
        Python wrapper function around the Theano function object.

    cache : dict
        Cached Theano variables (see :class:`TheanoPrinter.cache
        <TheanoPrinter>`). Defaults to a new empty dictionary, so variables
        are not shared with other calls unless a cache is passed explicitly.

    dtypes : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    broadcastables : dict
        Passed to :meth:`.TheanoPrinter.doprint`.

    dims : dict
        Alternative to ``broadcastables`` argument. Mapping from elements of
        ``inputs`` to integers indicating the dimension of their associated
        arrays/tensors. Overrides ``broadcastables`` argument if given.

    dim : int
        Another alternative to the ``broadcastables`` argument. Common number of
        dimensions to use for all arrays/tensors.
        ``theano_function([x, y], [...], dim=2)`` is equivalent to using
        ``broadcastables={x: (False, False), y: (False, False)}``.

    Any remaining keyword arguments are passed on to ``theano.function``.

    Returns
    =======
    callable
        A callable object which takes values of ``inputs`` as positional
        arguments and returns an output array for each of the expressions
        in ``outputs``. If ``outputs`` contains exactly one expression the
        function will return a single Numpy array, if it contains multiple
        expressions the function will return a list of arrays. With
        ``scalar=True``, 0-dimensional output arrays are converted to scalars.

        The returned object will either be an instance of
        ``theano.compile.function_module.Function`` or a Python wrapper
        function around one. In both cases, the returned value will have a
        ``theano_function`` attribute which points to the return value of
        ``theano.function``.

    Raises
    ======

    ImportError
        If Theano could not be imported.

    Examples
    ========

    >>> from sympy.abc import x, y, z
    >>> from sympy.printing.theanocode import theano_function

    A simple function with one input and one output:

    >>> f1 = theano_function([x], [x**2 - 1], scalar=True)
    >>> f1(3)
    8.0

    A function with multiple inputs and one output:

    >>> f2 = theano_function([x, y, z], [(x**z + y**z)**(1/z)], scalar=True)
    >>> f2(3, 4, 2)
    5.0

    A function with multiple inputs and multiple outputs:

    >>> f3 = theano_function([x, y], [x**2 + y**2, x**2 - y**2], scalar=True)
    >>> f3(2, 3)
    [13.0, -5.0]

    See also
    ========

    dim_handling

    """
    sympy_deprecation_warning(
        """
        sympy.printing.theanocode is deprecated. Theano has been renamed to Aesara. Use sympy.printing.aesaracode instead""",
        deprecated_since_version="1.8",
        active_deprecations_target='theanocode-deprecated')

    if not theano:
        raise ImportError("theano is required for theano_function")

    # Pop off non-theano keyword args
    cache = kwargs.pop('cache', {})
    dtypes = kwargs.pop('dtypes', {})

    broadcastables = dim_handling(
        inputs, dim=dim, dims=dims, broadcastables=broadcastables,
    )

    # Print inputs/outputs
    code = partial(theano_code, cache=cache, dtypes=dtypes,
                   broadcastables=broadcastables)
    tinputs = list(map(code, inputs))
    toutputs = list(map(code, outputs))

    # Constant expressions print as plain Python values; wrap them so that
    # every output is a Theano variable.
    toutputs = [
        output if isinstance(output, theano.Variable)
        else tt.as_tensor_variable(output)
        for output in toutputs
    ]

    # A single output is passed bare, so the compiled function returns one
    # array rather than a one-element list.
    if len(toutputs) == 1:
        toutputs = toutputs[0]

    # Compile theano func
    func = theano.function(tinputs, toutputs, **kwargs)

    is_0d = [len(o.variable.broadcastable) == 0 for o in func.outputs]

    # No wrapper required
    if not scalar or not any(is_0d):
        func.theano_function = func
        return func

    # Create wrapper to convert 0-dimensional outputs to scalars
    def wrapper(*args):
        out = func(*args)
        # out can be array(1.0) or [array(1.0), array(2.0)]

        if is_sequence(out):
            return [o[()] if zero_d else o for zero_d, o in zip(is_0d, out)]
        else:
            return out[()]

    wrapper.__wrapped__ = func
    wrapper.__doc__ = func.__doc__
    wrapper.theano_function = func
    return wrapper
|