123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284 |
- from sympy.external.importtools import version_tuple
- from collections.abc import Iterable
- from sympy.core.mul import Mul
- from sympy.core.singleton import S
- from sympy.codegen.cfunctions import Sqrt
- from sympy.external import import_module
- from sympy.printing.precedence import PRECEDENCE
- from sympy.printing.pycode import AbstractPythonCodePrinter
- import sympy
- tensorflow = import_module('tensorflow')
- class TensorflowPrinter(AbstractPythonCodePrinter):
- """
- Tensorflow printer which handles vectorized piecewise functions,
- logical operators, max/min, and relational operators.
- """
- printmethod = "_tensorflowcode"
- mapping = {
- sympy.Abs: "tensorflow.math.abs",
- sympy.sign: "tensorflow.math.sign",
- # XXX May raise error for ints.
- sympy.ceiling: "tensorflow.math.ceil",
- sympy.floor: "tensorflow.math.floor",
- sympy.log: "tensorflow.math.log",
- sympy.exp: "tensorflow.math.exp",
- Sqrt: "tensorflow.math.sqrt",
- sympy.cos: "tensorflow.math.cos",
- sympy.acos: "tensorflow.math.acos",
- sympy.sin: "tensorflow.math.sin",
- sympy.asin: "tensorflow.math.asin",
- sympy.tan: "tensorflow.math.tan",
- sympy.atan: "tensorflow.math.atan",
- sympy.atan2: "tensorflow.math.atan2",
- # XXX Also may give NaN for complex results.
- sympy.cosh: "tensorflow.math.cosh",
- sympy.acosh: "tensorflow.math.acosh",
- sympy.sinh: "tensorflow.math.sinh",
- sympy.asinh: "tensorflow.math.asinh",
- sympy.tanh: "tensorflow.math.tanh",
- sympy.atanh: "tensorflow.math.atanh",
- sympy.re: "tensorflow.math.real",
- sympy.im: "tensorflow.math.imag",
- sympy.arg: "tensorflow.math.angle",
- # XXX May raise error for ints and complexes
- sympy.erf: "tensorflow.math.erf",
- sympy.loggamma: "tensorflow.math.lgamma",
- sympy.Eq: "tensorflow.math.equal",
- sympy.Ne: "tensorflow.math.not_equal",
- sympy.StrictGreaterThan: "tensorflow.math.greater",
- sympy.StrictLessThan: "tensorflow.math.less",
- sympy.LessThan: "tensorflow.math.less_equal",
- sympy.GreaterThan: "tensorflow.math.greater_equal",
- sympy.And: "tensorflow.math.logical_and",
- sympy.Or: "tensorflow.math.logical_or",
- sympy.Not: "tensorflow.math.logical_not",
- sympy.Max: "tensorflow.math.maximum",
- sympy.Min: "tensorflow.math.minimum",
- # Matrices
- sympy.MatAdd: "tensorflow.math.add",
- sympy.HadamardProduct: "tensorflow.math.multiply",
- sympy.Trace: "tensorflow.linalg.trace",
- # XXX May raise error for integer matrices.
- sympy.Determinant : "tensorflow.linalg.det",
- }
- _default_settings = dict(
- AbstractPythonCodePrinter._default_settings,
- tensorflow_version=None
- )
- def __init__(self, settings=None):
- super().__init__(settings)
- version = self._settings['tensorflow_version']
- if version is None and tensorflow:
- version = tensorflow.__version__
- self.tensorflow_version = version
- def _print_Function(self, expr):
- op = self.mapping.get(type(expr), None)
- if op is None:
- return super()._print_Basic(expr)
- children = [self._print(arg) for arg in expr.args]
- if len(children) == 1:
- return "%s(%s)" % (
- self._module_format(op),
- children[0]
- )
- else:
- return self._expand_fold_binary_op(op, children)
- _print_Expr = _print_Function
- _print_Application = _print_Function
- _print_MatrixExpr = _print_Function
- # TODO: a better class structure would avoid this mess:
- _print_Relational = _print_Function
- _print_Not = _print_Function
- _print_And = _print_Function
- _print_Or = _print_Function
- _print_HadamardProduct = _print_Function
- _print_Trace = _print_Function
- _print_Determinant = _print_Function
- def _print_Inverse(self, expr):
- op = self._module_format('tensorflow.linalg.inv')
- return "{}({})".format(op, self._print(expr.arg))
- def _print_Transpose(self, expr):
- version = self.tensorflow_version
- if version and version_tuple(version) < version_tuple('1.14'):
- op = self._module_format('tensorflow.matrix_transpose')
- else:
- op = self._module_format('tensorflow.linalg.matrix_transpose')
- return "{}({})".format(op, self._print(expr.arg))
- def _print_Derivative(self, expr):
- variables = expr.variables
- if any(isinstance(i, Iterable) for i in variables):
- raise NotImplementedError("derivation by multiple variables is not supported")
- def unfold(expr, args):
- if not args:
- return self._print(expr)
- return "%s(%s, %s)[0]" % (
- self._module_format("tensorflow.gradients"),
- unfold(expr, args[:-1]),
- self._print(args[-1]),
- )
- return unfold(expr.expr, variables)
- def _print_Piecewise(self, expr):
- version = self.tensorflow_version
- if version and version_tuple(version) < version_tuple('1.0'):
- tensorflow_piecewise = "tensorflow.select"
- else:
- tensorflow_piecewise = "tensorflow.where"
- from sympy.functions.elementary.piecewise import Piecewise
- e, cond = expr.args[0].args
- if len(expr.args) == 1:
- return '{}({}, {}, {})'.format(
- self._module_format(tensorflow_piecewise),
- self._print(cond),
- self._print(e),
- 0)
- return '{}({}, {}, {})'.format(
- self._module_format(tensorflow_piecewise),
- self._print(cond),
- self._print(e),
- self._print(Piecewise(*expr.args[1:])))
- def _print_Pow(self, expr):
- # XXX May raise error for
- # int**float or int**complex or float**complex
- base, exp = expr.args
- if expr.exp == S.Half:
- return "{}({})".format(
- self._module_format("tensorflow.math.sqrt"), self._print(base))
- return "{}({}, {})".format(
- self._module_format("tensorflow.math.pow"),
- self._print(base), self._print(exp))
- def _print_MatrixBase(self, expr):
- tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
- data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]"
- return "%s(%s)" % (
- self._module_format(tensorflow_f),
- data,
- )
- def _print_MatMul(self, expr):
- from sympy.matrices.expressions import MatrixExpr
- mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
- args = [arg for arg in expr.args if arg not in mat_args]
- if args:
- return "%s*%s" % (
- self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
- self._expand_fold_binary_op(
- "tensorflow.linalg.matmul", mat_args)
- )
- else:
- return self._expand_fold_binary_op(
- "tensorflow.linalg.matmul", mat_args)
- def _print_MatPow(self, expr):
- return self._expand_fold_binary_op(
- "tensorflow.linalg.matmul", [expr.base]*expr.exp)
- def _print_Assignment(self, expr):
- # TODO: is this necessary?
- return "%s = %s" % (
- self._print(expr.lhs),
- self._print(expr.rhs),
- )
- def _print_CodeBlock(self, expr):
- # TODO: is this necessary?
- ret = []
- for subexpr in expr.args:
- ret.append(self._print(subexpr))
- return "\n".join(ret)
- def _get_letter_generator_for_einsum(self):
- for i in range(97, 123):
- yield chr(i)
- for i in range(65, 91):
- yield chr(i)
- raise ValueError("out of letters")
- def _print_ArrayTensorProduct(self, expr):
- letters = self._get_letter_generator_for_einsum()
- contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks])
- return '%s("%s", %s)' % (
- self._module_format('tensorflow.linalg.einsum'),
- contraction_string,
- ", ".join([self._print(arg) for arg in expr.args])
- )
- def _print_ArrayContraction(self, expr):
- from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
- base = expr.expr
- contraction_indices = expr.contraction_indices
- contraction_string, letters_free, letters_dum = self._get_einsum_string(base.subranks, contraction_indices)
- if not contraction_indices:
- return self._print(base)
- if isinstance(base, ArrayTensorProduct):
- elems = ["%s" % (self._print(arg)) for arg in base.args]
- return "%s(\"%s\", %s)" % (
- self._module_format("tensorflow.linalg.einsum"),
- contraction_string,
- ", ".join(elems)
- )
- raise NotImplementedError()
- def _print_ArrayDiagonal(self, expr):
- from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
- diagonal_indices = list(expr.diagonal_indices)
- if len(diagonal_indices) > 1:
- # TODO: this should be handled in sympy.codegen.array_utils,
- # possibly by creating the possibility of unfolding the
- # ArrayDiagonal object into nested ones. Same reasoning for
- # the array contraction.
- raise NotImplementedError
- if len(diagonal_indices[0]) != 2:
- raise NotImplementedError
- if isinstance(expr.expr, ArrayTensorProduct):
- subranks = expr.expr.subranks
- elems = expr.expr.args
- else:
- subranks = expr.subranks
- elems = [expr.expr]
- diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices)
- elems = [self._print(i) for i in elems]
- return '%s("%s", %s)' % (
- self._module_format("tensorflow.linalg.einsum"),
- "{}->{}{}".format(diagonal_string, "".join(letters_free), "".join(letters_dum)),
- ", ".join(elems)
- )
- def _print_PermuteDims(self, expr):
- return "%s(%s, %s)" % (
- self._module_format("tensorflow.transpose"),
- self._print(expr.expr),
- self._print(expr.permutation.array_form),
- )
- def _print_ArrayAdd(self, expr):
- return self._expand_fold_binary_op('tensorflow.math.add', expr.args)
- def tensorflow_code(expr, **settings):
- printer = TensorflowPrinter(settings)
- return printer.doprint(expr)
|