tensorflow.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284
  1. from sympy.external.importtools import version_tuple
  2. from collections.abc import Iterable
  3. from sympy.core.mul import Mul
  4. from sympy.core.singleton import S
  5. from sympy.codegen.cfunctions import Sqrt
  6. from sympy.external import import_module
  7. from sympy.printing.precedence import PRECEDENCE
  8. from sympy.printing.pycode import AbstractPythonCodePrinter
  9. import sympy
# Optional dependency: import_module may return None (e.g. when TensorFlow is
# not installed), which is why TensorflowPrinter.__init__ tests this value for
# truthiness before reading ``tensorflow.__version__``.
tensorflow = import_module('tensorflow')
  11. class TensorflowPrinter(AbstractPythonCodePrinter):
  12. """
  13. Tensorflow printer which handles vectorized piecewise functions,
  14. logical operators, max/min, and relational operators.
  15. """
  16. printmethod = "_tensorflowcode"
  17. mapping = {
  18. sympy.Abs: "tensorflow.math.abs",
  19. sympy.sign: "tensorflow.math.sign",
  20. # XXX May raise error for ints.
  21. sympy.ceiling: "tensorflow.math.ceil",
  22. sympy.floor: "tensorflow.math.floor",
  23. sympy.log: "tensorflow.math.log",
  24. sympy.exp: "tensorflow.math.exp",
  25. Sqrt: "tensorflow.math.sqrt",
  26. sympy.cos: "tensorflow.math.cos",
  27. sympy.acos: "tensorflow.math.acos",
  28. sympy.sin: "tensorflow.math.sin",
  29. sympy.asin: "tensorflow.math.asin",
  30. sympy.tan: "tensorflow.math.tan",
  31. sympy.atan: "tensorflow.math.atan",
  32. sympy.atan2: "tensorflow.math.atan2",
  33. # XXX Also may give NaN for complex results.
  34. sympy.cosh: "tensorflow.math.cosh",
  35. sympy.acosh: "tensorflow.math.acosh",
  36. sympy.sinh: "tensorflow.math.sinh",
  37. sympy.asinh: "tensorflow.math.asinh",
  38. sympy.tanh: "tensorflow.math.tanh",
  39. sympy.atanh: "tensorflow.math.atanh",
  40. sympy.re: "tensorflow.math.real",
  41. sympy.im: "tensorflow.math.imag",
  42. sympy.arg: "tensorflow.math.angle",
  43. # XXX May raise error for ints and complexes
  44. sympy.erf: "tensorflow.math.erf",
  45. sympy.loggamma: "tensorflow.math.lgamma",
  46. sympy.Eq: "tensorflow.math.equal",
  47. sympy.Ne: "tensorflow.math.not_equal",
  48. sympy.StrictGreaterThan: "tensorflow.math.greater",
  49. sympy.StrictLessThan: "tensorflow.math.less",
  50. sympy.LessThan: "tensorflow.math.less_equal",
  51. sympy.GreaterThan: "tensorflow.math.greater_equal",
  52. sympy.And: "tensorflow.math.logical_and",
  53. sympy.Or: "tensorflow.math.logical_or",
  54. sympy.Not: "tensorflow.math.logical_not",
  55. sympy.Max: "tensorflow.math.maximum",
  56. sympy.Min: "tensorflow.math.minimum",
  57. # Matrices
  58. sympy.MatAdd: "tensorflow.math.add",
  59. sympy.HadamardProduct: "tensorflow.math.multiply",
  60. sympy.Trace: "tensorflow.linalg.trace",
  61. # XXX May raise error for integer matrices.
  62. sympy.Determinant : "tensorflow.linalg.det",
  63. }
  64. _default_settings = dict(
  65. AbstractPythonCodePrinter._default_settings,
  66. tensorflow_version=None
  67. )
  68. def __init__(self, settings=None):
  69. super().__init__(settings)
  70. version = self._settings['tensorflow_version']
  71. if version is None and tensorflow:
  72. version = tensorflow.__version__
  73. self.tensorflow_version = version
  74. def _print_Function(self, expr):
  75. op = self.mapping.get(type(expr), None)
  76. if op is None:
  77. return super()._print_Basic(expr)
  78. children = [self._print(arg) for arg in expr.args]
  79. if len(children) == 1:
  80. return "%s(%s)" % (
  81. self._module_format(op),
  82. children[0]
  83. )
  84. else:
  85. return self._expand_fold_binary_op(op, children)
  86. _print_Expr = _print_Function
  87. _print_Application = _print_Function
  88. _print_MatrixExpr = _print_Function
  89. # TODO: a better class structure would avoid this mess:
  90. _print_Relational = _print_Function
  91. _print_Not = _print_Function
  92. _print_And = _print_Function
  93. _print_Or = _print_Function
  94. _print_HadamardProduct = _print_Function
  95. _print_Trace = _print_Function
  96. _print_Determinant = _print_Function
  97. def _print_Inverse(self, expr):
  98. op = self._module_format('tensorflow.linalg.inv')
  99. return "{}({})".format(op, self._print(expr.arg))
  100. def _print_Transpose(self, expr):
  101. version = self.tensorflow_version
  102. if version and version_tuple(version) < version_tuple('1.14'):
  103. op = self._module_format('tensorflow.matrix_transpose')
  104. else:
  105. op = self._module_format('tensorflow.linalg.matrix_transpose')
  106. return "{}({})".format(op, self._print(expr.arg))
  107. def _print_Derivative(self, expr):
  108. variables = expr.variables
  109. if any(isinstance(i, Iterable) for i in variables):
  110. raise NotImplementedError("derivation by multiple variables is not supported")
  111. def unfold(expr, args):
  112. if not args:
  113. return self._print(expr)
  114. return "%s(%s, %s)[0]" % (
  115. self._module_format("tensorflow.gradients"),
  116. unfold(expr, args[:-1]),
  117. self._print(args[-1]),
  118. )
  119. return unfold(expr.expr, variables)
  120. def _print_Piecewise(self, expr):
  121. version = self.tensorflow_version
  122. if version and version_tuple(version) < version_tuple('1.0'):
  123. tensorflow_piecewise = "tensorflow.select"
  124. else:
  125. tensorflow_piecewise = "tensorflow.where"
  126. from sympy.functions.elementary.piecewise import Piecewise
  127. e, cond = expr.args[0].args
  128. if len(expr.args) == 1:
  129. return '{}({}, {}, {})'.format(
  130. self._module_format(tensorflow_piecewise),
  131. self._print(cond),
  132. self._print(e),
  133. 0)
  134. return '{}({}, {}, {})'.format(
  135. self._module_format(tensorflow_piecewise),
  136. self._print(cond),
  137. self._print(e),
  138. self._print(Piecewise(*expr.args[1:])))
  139. def _print_Pow(self, expr):
  140. # XXX May raise error for
  141. # int**float or int**complex or float**complex
  142. base, exp = expr.args
  143. if expr.exp == S.Half:
  144. return "{}({})".format(
  145. self._module_format("tensorflow.math.sqrt"), self._print(base))
  146. return "{}({}, {})".format(
  147. self._module_format("tensorflow.math.pow"),
  148. self._print(base), self._print(exp))
  149. def _print_MatrixBase(self, expr):
  150. tensorflow_f = "tensorflow.Variable" if expr.free_symbols else "tensorflow.constant"
  151. data = "["+", ".join(["["+", ".join([self._print(j) for j in i])+"]" for i in expr.tolist()])+"]"
  152. return "%s(%s)" % (
  153. self._module_format(tensorflow_f),
  154. data,
  155. )
  156. def _print_MatMul(self, expr):
  157. from sympy.matrices.expressions import MatrixExpr
  158. mat_args = [arg for arg in expr.args if isinstance(arg, MatrixExpr)]
  159. args = [arg for arg in expr.args if arg not in mat_args]
  160. if args:
  161. return "%s*%s" % (
  162. self.parenthesize(Mul.fromiter(args), PRECEDENCE["Mul"]),
  163. self._expand_fold_binary_op(
  164. "tensorflow.linalg.matmul", mat_args)
  165. )
  166. else:
  167. return self._expand_fold_binary_op(
  168. "tensorflow.linalg.matmul", mat_args)
  169. def _print_MatPow(self, expr):
  170. return self._expand_fold_binary_op(
  171. "tensorflow.linalg.matmul", [expr.base]*expr.exp)
  172. def _print_Assignment(self, expr):
  173. # TODO: is this necessary?
  174. return "%s = %s" % (
  175. self._print(expr.lhs),
  176. self._print(expr.rhs),
  177. )
  178. def _print_CodeBlock(self, expr):
  179. # TODO: is this necessary?
  180. ret = []
  181. for subexpr in expr.args:
  182. ret.append(self._print(subexpr))
  183. return "\n".join(ret)
  184. def _get_letter_generator_for_einsum(self):
  185. for i in range(97, 123):
  186. yield chr(i)
  187. for i in range(65, 91):
  188. yield chr(i)
  189. raise ValueError("out of letters")
  190. def _print_ArrayTensorProduct(self, expr):
  191. letters = self._get_letter_generator_for_einsum()
  192. contraction_string = ",".join(["".join([next(letters) for j in range(i)]) for i in expr.subranks])
  193. return '%s("%s", %s)' % (
  194. self._module_format('tensorflow.linalg.einsum'),
  195. contraction_string,
  196. ", ".join([self._print(arg) for arg in expr.args])
  197. )
  198. def _print_ArrayContraction(self, expr):
  199. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  200. base = expr.expr
  201. contraction_indices = expr.contraction_indices
  202. contraction_string, letters_free, letters_dum = self._get_einsum_string(base.subranks, contraction_indices)
  203. if not contraction_indices:
  204. return self._print(base)
  205. if isinstance(base, ArrayTensorProduct):
  206. elems = ["%s" % (self._print(arg)) for arg in base.args]
  207. return "%s(\"%s\", %s)" % (
  208. self._module_format("tensorflow.linalg.einsum"),
  209. contraction_string,
  210. ", ".join(elems)
  211. )
  212. raise NotImplementedError()
  213. def _print_ArrayDiagonal(self, expr):
  214. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  215. diagonal_indices = list(expr.diagonal_indices)
  216. if len(diagonal_indices) > 1:
  217. # TODO: this should be handled in sympy.codegen.array_utils,
  218. # possibly by creating the possibility of unfolding the
  219. # ArrayDiagonal object into nested ones. Same reasoning for
  220. # the array contraction.
  221. raise NotImplementedError
  222. if len(diagonal_indices[0]) != 2:
  223. raise NotImplementedError
  224. if isinstance(expr.expr, ArrayTensorProduct):
  225. subranks = expr.expr.subranks
  226. elems = expr.expr.args
  227. else:
  228. subranks = expr.subranks
  229. elems = [expr.expr]
  230. diagonal_string, letters_free, letters_dum = self._get_einsum_string(subranks, diagonal_indices)
  231. elems = [self._print(i) for i in elems]
  232. return '%s("%s", %s)' % (
  233. self._module_format("tensorflow.linalg.einsum"),
  234. "{}->{}{}".format(diagonal_string, "".join(letters_free), "".join(letters_dum)),
  235. ", ".join(elems)
  236. )
  237. def _print_PermuteDims(self, expr):
  238. return "%s(%s, %s)" % (
  239. self._module_format("tensorflow.transpose"),
  240. self._print(expr.expr),
  241. self._print(expr.permutation.array_form),
  242. )
  243. def _print_ArrayAdd(self, expr):
  244. return self._expand_fold_binary_op('tensorflow.math.add', expr.args)
  245. def tensorflow_code(expr, **settings):
  246. printer = TensorflowPrinter(settings)
  247. return printer.doprint(expr)