pycode.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639
  1. """
  2. Python code printers
  3. This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
  4. """
  5. from collections import defaultdict
  6. from itertools import chain
  7. from sympy.core import S
  8. from .precedence import precedence
  9. from .codeprinter import CodePrinter
  10. _kw = {
  11. 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
  12. 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in',
  13. 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
  14. 'with', 'yield', 'None', 'False', 'nonlocal', 'True'
  15. }
  16. _known_functions = {
  17. 'Abs': 'abs',
  18. }
  19. _known_functions_math = {
  20. 'acos': 'acos',
  21. 'acosh': 'acosh',
  22. 'asin': 'asin',
  23. 'asinh': 'asinh',
  24. 'atan': 'atan',
  25. 'atan2': 'atan2',
  26. 'atanh': 'atanh',
  27. 'ceiling': 'ceil',
  28. 'cos': 'cos',
  29. 'cosh': 'cosh',
  30. 'erf': 'erf',
  31. 'erfc': 'erfc',
  32. 'exp': 'exp',
  33. 'expm1': 'expm1',
  34. 'factorial': 'factorial',
  35. 'floor': 'floor',
  36. 'gamma': 'gamma',
  37. 'hypot': 'hypot',
  38. 'loggamma': 'lgamma',
  39. 'log': 'log',
  40. 'ln': 'log',
  41. 'log10': 'log10',
  42. 'log1p': 'log1p',
  43. 'log2': 'log2',
  44. 'sin': 'sin',
  45. 'sinh': 'sinh',
  46. 'Sqrt': 'sqrt',
  47. 'tan': 'tan',
  48. 'tanh': 'tanh'
  49. } # Not used from ``math``: [copysign isclose isfinite isinf isnan ldexp frexp pow modf
  50. # radians trunc fmod fsum gcd degrees fabs]
  51. _known_constants_math = {
  52. 'Exp1': 'e',
  53. 'Pi': 'pi',
  54. 'E': 'e',
  55. 'Infinity': 'inf',
  56. 'NaN': 'nan',
  57. 'ComplexInfinity': 'nan'
  58. }
  59. def _print_known_func(self, expr):
  60. known = self.known_functions[expr.__class__.__name__]
  61. return '{name}({args})'.format(name=self._module_format(known),
  62. args=', '.join(map(lambda arg: self._print(arg), expr.args)))
  63. def _print_known_const(self, expr):
  64. known = self.known_constants[expr.__class__.__name__]
  65. return self._module_format(known)
  66. class AbstractPythonCodePrinter(CodePrinter):
  67. printmethod = "_pythoncode"
  68. language = "Python"
  69. reserved_words = _kw
  70. modules = None # initialized to a set in __init__
  71. tab = ' '
  72. _kf = dict(chain(
  73. _known_functions.items(),
  74. [(k, 'math.' + v) for k, v in _known_functions_math.items()]
  75. ))
  76. _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
  77. _operators = {'and': 'and', 'or': 'or', 'not': 'not'}
  78. _default_settings = dict(
  79. CodePrinter._default_settings,
  80. user_functions={},
  81. precision=17,
  82. inline=True,
  83. fully_qualified_modules=True,
  84. contract=False,
  85. standard='python3',
  86. )
  87. def __init__(self, settings=None):
  88. super().__init__(settings)
  89. # Python standard handler
  90. std = self._settings['standard']
  91. if std is None:
  92. import sys
  93. std = 'python{}'.format(sys.version_info.major)
  94. if std != 'python3':
  95. raise ValueError('Only Python 3 is supported.')
  96. self.standard = std
  97. self.module_imports = defaultdict(set)
  98. # Known functions and constants handler
  99. self.known_functions = dict(self._kf, **(settings or {}).get(
  100. 'user_functions', {}))
  101. self.known_constants = dict(self._kc, **(settings or {}).get(
  102. 'user_constants', {}))
  103. def _declare_number_const(self, name, value):
  104. return "%s = %s" % (name, value)
  105. def _module_format(self, fqn, register=True):
  106. parts = fqn.split('.')
  107. if register and len(parts) > 1:
  108. self.module_imports['.'.join(parts[:-1])].add(parts[-1])
  109. if self._settings['fully_qualified_modules']:
  110. return fqn
  111. else:
  112. return fqn.split('(')[0].split('[')[0].split('.')[-1]
  113. def _format_code(self, lines):
  114. return lines
  115. def _get_statement(self, codestring):
  116. return "{}".format(codestring)
  117. def _get_comment(self, text):
  118. return " # {}".format(text)
  119. def _expand_fold_binary_op(self, op, args):
  120. """
  121. This method expands a fold on binary operations.
  122. ``functools.reduce`` is an example of a folded operation.
  123. For example, the expression
  124. `A + B + C + D`
  125. is folded into
  126. `((A + B) + C) + D`
  127. """
  128. if len(args) == 1:
  129. return self._print(args[0])
  130. else:
  131. return "%s(%s, %s)" % (
  132. self._module_format(op),
  133. self._expand_fold_binary_op(op, args[:-1]),
  134. self._print(args[-1]),
  135. )
  136. def _expand_reduce_binary_op(self, op, args):
  137. """
  138. This method expands a reductin on binary operations.
  139. Notice: this is NOT the same as ``functools.reduce``.
  140. For example, the expression
  141. `A + B + C + D`
  142. is reduced into:
  143. `(A + B) + (C + D)`
  144. """
  145. if len(args) == 1:
  146. return self._print(args[0])
  147. else:
  148. N = len(args)
  149. Nhalf = N // 2
  150. return "%s(%s, %s)" % (
  151. self._module_format(op),
  152. self._expand_reduce_binary_op(args[:Nhalf]),
  153. self._expand_reduce_binary_op(args[Nhalf:]),
  154. )
  155. def _get_einsum_string(self, subranks, contraction_indices):
  156. letters = self._get_letter_generator_for_einsum()
  157. contraction_string = ""
  158. counter = 0
  159. d = {j: min(i) for i in contraction_indices for j in i}
  160. indices = []
  161. for rank_arg in subranks:
  162. lindices = []
  163. for i in range(rank_arg):
  164. if counter in d:
  165. lindices.append(d[counter])
  166. else:
  167. lindices.append(counter)
  168. counter += 1
  169. indices.append(lindices)
  170. mapping = {}
  171. letters_free = []
  172. letters_dum = []
  173. for i in indices:
  174. for j in i:
  175. if j not in mapping:
  176. l = next(letters)
  177. mapping[j] = l
  178. else:
  179. l = mapping[j]
  180. contraction_string += l
  181. if j in d:
  182. if l not in letters_dum:
  183. letters_dum.append(l)
  184. else:
  185. letters_free.append(l)
  186. contraction_string += ","
  187. contraction_string = contraction_string[:-1]
  188. return contraction_string, letters_free, letters_dum
  189. def _print_NaN(self, expr):
  190. return "float('nan')"
  191. def _print_Infinity(self, expr):
  192. return "float('inf')"
  193. def _print_NegativeInfinity(self, expr):
  194. return "float('-inf')"
  195. def _print_ComplexInfinity(self, expr):
  196. return self._print_NaN(expr)
  197. def _print_Mod(self, expr):
  198. PREC = precedence(expr)
  199. return ('{} % {}'.format(*map(lambda x: self.parenthesize(x, PREC), expr.args)))
  200. def _print_Piecewise(self, expr):
  201. result = []
  202. i = 0
  203. for arg in expr.args:
  204. e = arg.expr
  205. c = arg.cond
  206. if i == 0:
  207. result.append('(')
  208. result.append('(')
  209. result.append(self._print(e))
  210. result.append(')')
  211. result.append(' if ')
  212. result.append(self._print(c))
  213. result.append(' else ')
  214. i += 1
  215. result = result[:-1]
  216. if result[-1] == 'True':
  217. result = result[:-2]
  218. result.append(')')
  219. else:
  220. result.append(' else None)')
  221. return ''.join(result)
  222. def _print_Relational(self, expr):
  223. "Relational printer for Equality and Unequality"
  224. op = {
  225. '==' :'equal',
  226. '!=' :'not_equal',
  227. '<' :'less',
  228. '<=' :'less_equal',
  229. '>' :'greater',
  230. '>=' :'greater_equal',
  231. }
  232. if expr.rel_op in op:
  233. lhs = self._print(expr.lhs)
  234. rhs = self._print(expr.rhs)
  235. return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
  236. return super()._print_Relational(expr)
  237. def _print_ITE(self, expr):
  238. from sympy.functions.elementary.piecewise import Piecewise
  239. return self._print(expr.rewrite(Piecewise))
  240. def _print_Sum(self, expr):
  241. loops = (
  242. 'for {i} in range({a}, {b}+1)'.format(
  243. i=self._print(i),
  244. a=self._print(a),
  245. b=self._print(b))
  246. for i, a, b in expr.limits)
  247. return '(builtins.sum({function} {loops}))'.format(
  248. function=self._print(expr.function),
  249. loops=' '.join(loops))
  250. def _print_ImaginaryUnit(self, expr):
  251. return '1j'
  252. def _print_KroneckerDelta(self, expr):
  253. a, b = expr.args
  254. return '(1 if {a} == {b} else 0)'.format(
  255. a = self._print(a),
  256. b = self._print(b)
  257. )
  258. def _print_MatrixBase(self, expr):
  259. name = expr.__class__.__name__
  260. func = self.known_functions.get(name, name)
  261. return "%s(%s)" % (func, self._print(expr.tolist()))
  262. _print_SparseRepMatrix = \
  263. _print_MutableSparseMatrix = \
  264. _print_ImmutableSparseMatrix = \
  265. _print_Matrix = \
  266. _print_DenseMatrix = \
  267. _print_MutableDenseMatrix = \
  268. _print_ImmutableMatrix = \
  269. _print_ImmutableDenseMatrix = \
  270. lambda self, expr: self._print_MatrixBase(expr)
  271. def _indent_codestring(self, codestring):
  272. return '\n'.join([self.tab + line for line in codestring.split('\n')])
  273. def _print_FunctionDefinition(self, fd):
  274. body = '\n'.join(map(lambda arg: self._print(arg), fd.body))
  275. return "def {name}({parameters}):\n{body}".format(
  276. name=self._print(fd.name),
  277. parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
  278. body=self._indent_codestring(body)
  279. )
  280. def _print_While(self, whl):
  281. body = '\n'.join(map(lambda arg: self._print(arg), whl.body))
  282. return "while {cond}:\n{body}".format(
  283. cond=self._print(whl.condition),
  284. body=self._indent_codestring(body)
  285. )
  286. def _print_Declaration(self, decl):
  287. return '%s = %s' % (
  288. self._print(decl.variable.symbol),
  289. self._print(decl.variable.value)
  290. )
  291. def _print_Return(self, ret):
  292. arg, = ret.args
  293. return 'return %s' % self._print(arg)
  294. def _print_Print(self, prnt):
  295. print_args = ', '.join(map(lambda arg: self._print(arg), prnt.print_args))
  296. if prnt.format_string != None: # Must be '!= None', cannot be 'is not None'
  297. print_args = '{} % ({})'.format(
  298. self._print(prnt.format_string), print_args)
  299. if prnt.file != None: # Must be '!= None', cannot be 'is not None'
  300. print_args += ', file=%s' % self._print(prnt.file)
  301. return 'print(%s)' % print_args
  302. def _print_Stream(self, strm):
  303. if str(strm.name) == 'stdout':
  304. return self._module_format('sys.stdout')
  305. elif str(strm.name) == 'stderr':
  306. return self._module_format('sys.stderr')
  307. else:
  308. return self._print(strm.name)
  309. def _print_NoneToken(self, arg):
  310. return 'None'
  311. def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
  312. """Printing helper function for ``Pow``
  313. Notes
  314. =====
  315. This only preprocesses the ``sqrt`` as math formatter
  316. Examples
  317. ========
  318. >>> from sympy import sqrt
  319. >>> from sympy.printing.pycode import PythonCodePrinter
  320. >>> from sympy.abc import x
  321. Python code printer automatically looks up ``math.sqrt``.
  322. >>> printer = PythonCodePrinter()
  323. >>> printer._hprint_Pow(sqrt(x), rational=True)
  324. 'x**(1/2)'
  325. >>> printer._hprint_Pow(sqrt(x), rational=False)
  326. 'math.sqrt(x)'
  327. >>> printer._hprint_Pow(1/sqrt(x), rational=True)
  328. 'x**(-1/2)'
  329. >>> printer._hprint_Pow(1/sqrt(x), rational=False)
  330. '1/math.sqrt(x)'
  331. Using sqrt from numpy or mpmath
  332. >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
  333. 'numpy.sqrt(x)'
  334. >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
  335. 'mpmath.sqrt(x)'
  336. See Also
  337. ========
  338. sympy.printing.str.StrPrinter._print_Pow
  339. """
  340. PREC = precedence(expr)
  341. if expr.exp == S.Half and not rational:
  342. func = self._module_format(sqrt)
  343. arg = self._print(expr.base)
  344. return '{func}({arg})'.format(func=func, arg=arg)
  345. if expr.is_commutative:
  346. if -expr.exp is S.Half and not rational:
  347. func = self._module_format(sqrt)
  348. num = self._print(S.One)
  349. arg = self._print(expr.base)
  350. return "{num}/{func}({arg})".format(
  351. num=num, func=func, arg=arg)
  352. base_str = self.parenthesize(expr.base, PREC, strict=False)
  353. exp_str = self.parenthesize(expr.exp, PREC, strict=False)
  354. return "{}**{}".format(base_str, exp_str)
  355. class PythonCodePrinter(AbstractPythonCodePrinter):
  356. def _print_sign(self, e):
  357. return '(0.0 if {e} == 0 else {f}(1, {e}))'.format(
  358. f=self._module_format('math.copysign'), e=self._print(e.args[0]))
  359. def _print_Not(self, expr):
  360. PREC = precedence(expr)
  361. return self._operators['not'] + self.parenthesize(expr.args[0], PREC)
  362. def _print_Indexed(self, expr):
  363. base = expr.args[0]
  364. index = expr.args[1:]
  365. return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index]))
  366. def _print_Pow(self, expr, rational=False):
  367. return self._hprint_Pow(expr, rational=rational)
  368. def _print_Rational(self, expr):
  369. return '{}/{}'.format(expr.p, expr.q)
  370. def _print_Half(self, expr):
  371. return self._print_Rational(expr)
  372. def _print_frac(self, expr):
  373. from sympy.core.mod import Mod
  374. return self._print_Mod(Mod(expr.args[0], 1))
  375. def _print_Symbol(self, expr):
  376. name = super()._print_Symbol(expr)
  377. if name in self.reserved_words:
  378. if self._settings['error_on_reserved']:
  379. msg = ('This expression includes the symbol "{}" which is a '
  380. 'reserved keyword in this language.')
  381. raise ValueError(msg.format(name))
  382. return name + self._settings['reserved_word_suffix']
  383. elif '{' in name: # Remove curly braces from subscripted variables
  384. return name.replace('{', '').replace('}', '')
  385. else:
  386. return name
  387. _print_lowergamma = CodePrinter._print_not_supported
  388. _print_uppergamma = CodePrinter._print_not_supported
  389. _print_fresnelc = CodePrinter._print_not_supported
  390. _print_fresnels = CodePrinter._print_not_supported
  391. for k in PythonCodePrinter._kf:
  392. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_func)
  393. for k in _known_constants_math:
  394. setattr(PythonCodePrinter, '_print_%s' % k, _print_known_const)
  395. def pycode(expr, **settings):
  396. """ Converts an expr to a string of Python code
  397. Parameters
  398. ==========
  399. expr : Expr
  400. A SymPy expression.
  401. fully_qualified_modules : bool
  402. Whether or not to write out full module names of functions
  403. (``math.sin`` vs. ``sin``). default: ``True``.
  404. standard : str or None, optional
  405. Only 'python3' (default) is supported.
  406. This parameter may be removed in the future.
  407. Examples
  408. ========
  409. >>> from sympy import pycode, tan, Symbol
  410. >>> pycode(tan(Symbol('x')) + 1)
  411. 'math.tan(x) + 1'
  412. """
  413. return PythonCodePrinter(settings).doprint(expr)
  414. _not_in_mpmath = 'log1p log2'.split()
  415. _in_mpmath = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_mpmath]
  416. _known_functions_mpmath = dict(_in_mpmath, **{
  417. 'beta': 'beta',
  418. 'frac': 'frac',
  419. 'fresnelc': 'fresnelc',
  420. 'fresnels': 'fresnels',
  421. 'sign': 'sign',
  422. 'loggamma': 'loggamma',
  423. 'hyper': 'hyper',
  424. 'meijerg': 'meijerg',
  425. 'besselj': 'besselj',
  426. 'bessely': 'bessely',
  427. 'besseli': 'besseli',
  428. 'besselk': 'besselk',
  429. })
  430. _known_constants_mpmath = {
  431. 'Exp1': 'e',
  432. 'Pi': 'pi',
  433. 'GoldenRatio': 'phi',
  434. 'EulerGamma': 'euler',
  435. 'Catalan': 'catalan',
  436. 'NaN': 'nan',
  437. 'Infinity': 'inf',
  438. 'NegativeInfinity': 'ninf'
  439. }
  440. def _unpack_integral_limits(integral_expr):
  441. """ helper function for _print_Integral that
  442. - accepts an Integral expression
  443. - returns a tuple of
  444. - a list variables of integration
  445. - a list of tuples of the upper and lower limits of integration
  446. """
  447. integration_vars = []
  448. limits = []
  449. for integration_range in integral_expr.limits:
  450. if len(integration_range) == 3:
  451. integration_var, lower_limit, upper_limit = integration_range
  452. else:
  453. raise NotImplementedError("Only definite integrals are supported")
  454. integration_vars.append(integration_var)
  455. limits.append((lower_limit, upper_limit))
  456. return integration_vars, limits
  457. class MpmathPrinter(PythonCodePrinter):
  458. """
  459. Lambda printer for mpmath which maintains precision for floats
  460. """
  461. printmethod = "_mpmathcode"
  462. language = "Python with mpmath"
  463. _kf = dict(chain(
  464. _known_functions.items(),
  465. [(k, 'mpmath.' + v) for k, v in _known_functions_mpmath.items()]
  466. ))
  467. _kc = {k: 'mpmath.'+v for k, v in _known_constants_mpmath.items()}
  468. def _print_Float(self, e):
  469. # XXX: This does not handle setting mpmath.mp.dps. It is assumed that
  470. # the caller of the lambdified function will have set it to sufficient
  471. # precision to match the Floats in the expression.
  472. # Remove 'mpz' if gmpy is installed.
  473. args = str(tuple(map(int, e._mpf_)))
  474. return '{func}({args})'.format(func=self._module_format('mpmath.mpf'), args=args)
  475. def _print_Rational(self, e):
  476. return "{func}({p})/{func}({q})".format(
  477. func=self._module_format('mpmath.mpf'),
  478. q=self._print(e.q),
  479. p=self._print(e.p)
  480. )
  481. def _print_Half(self, e):
  482. return self._print_Rational(e)
  483. def _print_uppergamma(self, e):
  484. return "{}({}, {}, {})".format(
  485. self._module_format('mpmath.gammainc'),
  486. self._print(e.args[0]),
  487. self._print(e.args[1]),
  488. self._module_format('mpmath.inf'))
  489. def _print_lowergamma(self, e):
  490. return "{}({}, 0, {})".format(
  491. self._module_format('mpmath.gammainc'),
  492. self._print(e.args[0]),
  493. self._print(e.args[1]))
  494. def _print_log2(self, e):
  495. return '{0}({1})/{0}(2)'.format(
  496. self._module_format('mpmath.log'), self._print(e.args[0]))
  497. def _print_log1p(self, e):
  498. return '{}({}+1)'.format(
  499. self._module_format('mpmath.log'), self._print(e.args[0]))
  500. def _print_Pow(self, expr, rational=False):
  501. return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')
  502. def _print_Integral(self, e):
  503. integration_vars, limits = _unpack_integral_limits(e)
  504. return "{}(lambda {}: {}, {})".format(
  505. self._module_format("mpmath.quad"),
  506. ", ".join(map(self._print, integration_vars)),
  507. self._print(e.args[0]),
  508. ", ".join("(%s, %s)" % tuple(map(self._print, l)) for l in limits))
  509. for k in MpmathPrinter._kf:
  510. setattr(MpmathPrinter, '_print_%s' % k, _print_known_func)
  511. for k in _known_constants_mpmath:
  512. setattr(MpmathPrinter, '_print_%s' % k, _print_known_const)
  513. class SymPyPrinter(AbstractPythonCodePrinter):
  514. language = "Python with SymPy"
  515. def _print_Function(self, expr):
  516. mod = expr.func.__module__ or ''
  517. return '%s(%s)' % (self._module_format(mod + ('.' if mod else '') + expr.func.__name__),
  518. ', '.join(map(lambda arg: self._print(arg), expr.args)))
  519. def _print_Pow(self, expr, rational=False):
  520. return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')