123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639 |
- """
- Python code printers
- This module contains Python code printers for plain Python as well as NumPy & SciPy enabled code.
- """
- from collections import defaultdict
- from itertools import chain
- from sympy.core import S
- from .precedence import precedence
- from .codeprinter import CodePrinter
- _kw = {
- 'and', 'as', 'assert', 'break', 'class', 'continue', 'def', 'del', 'elif',
- 'else', 'except', 'finally', 'for', 'from', 'global', 'if', 'import', 'in',
- 'is', 'lambda', 'not', 'or', 'pass', 'raise', 'return', 'try', 'while',
- 'with', 'yield', 'None', 'False', 'nonlocal', 'True'
- }
# SymPy class name -> Python builtin that implements it (no import needed).
_known_functions = dict(Abs='abs')
# SymPy function class name -> name of the callable in the ``math`` module.
# Most names are spelled identically, so the table starts as an identity map
# and the few differing spellings are patched afterwards (updating an existing
# key keeps its original position in the dict).
_known_functions_math = {name: name for name in (
    'acos acosh asin asinh atan atan2 atanh ceiling cos cosh erf erfc exp '
    'expm1 factorial floor gamma hypot loggamma log ln log10 log1p log2 '
    'sin sinh Sqrt tan tanh'
).split()}
_known_functions_math.update(
    ceiling='ceil',
    loggamma='lgamma',
    ln='log',
    Sqrt='sqrt',
)
# ``math`` functions not listed in this table: [copysign isclose isfinite
# isinf isnan ldexp frexp pow modf radians trunc fmod fsum gcd degrees fabs]
# (``copysign`` is still used directly by ``PythonCodePrinter._print_sign``).
# SymPy constant class name -> attribute of the ``math`` module. ``E`` and
# ``Exp1`` both map to ``math.e``; ``ComplexInfinity`` is rendered as ``nan``.
_known_constants_math = dict(
    Exp1='e',
    Pi='pi',
    E='e',
    Infinity='inf',
    NaN='nan',
    ComplexInfinity='nan',
)
def _print_known_func(self, expr):
    """Print ``expr`` as a call to the function registered for its class.

    The callee is looked up in ``self.known_functions`` by the class name of
    ``expr`` and passed through ``self._module_format`` (which may record an
    import). Each element of ``expr.args`` is printed with ``self._print``.
    """
    callee = self._module_format(
        self.known_functions[expr.__class__.__name__])
    printed_args = [self._print(arg) for arg in expr.args]
    return '%s(%s)' % (callee, ', '.join(printed_args))
def _print_known_const(self, expr):
    """Print ``expr`` as the constant registered for its class name.

    The target name comes from ``self.known_constants`` and is returned in
    the form produced by ``self._module_format``.
    """
    qualified_name = self.known_constants[expr.__class__.__name__]
    return self._module_format(qualified_name)
class AbstractPythonCodePrinter(CodePrinter):
    """Base class for printers that generate Python source code.

    Subclasses choose the module that known functions and constants resolve
    to through the ``_kf`` and ``_kc`` tables. This class holds the printing
    logic they share and records in ``module_imports`` which names the
    generated code needs to import.
    """
    printmethod = "_pythoncode"
    language = "Python"
    reserved_words = _kw
    # NOTE: ``__init__`` never assigns this attribute; the names the printed
    # code needs are tracked per instance in ``module_imports`` instead.
    modules = None
    # Indentation for nested bodies (four spaces, as recommended by PEP 8).
    tab = '    '
    # SymPy class name -> fully qualified callable for known functions.
    _kf = dict(chain(
        _known_functions.items(),
        [(k, 'math.' + v) for k, v in _known_functions_math.items()]
    ))
    # SymPy class name -> fully qualified known constant.
    _kc = {k: 'math.'+v for k, v in _known_constants_math.items()}
    _operators = {'and': 'and', 'or': 'or', 'not': 'not'}

    _default_settings = dict(
        CodePrinter._default_settings,
        user_functions={},
        precision=17,
        inline=True,
        fully_qualified_modules=True,
        contract=False,
        standard='python3',
    )

    def __init__(self, settings=None):
        super().__init__(settings)

        # Python standard handler: ``None`` means "the running interpreter".
        std = self._settings['standard']
        if std is None:
            import sys
            std = 'python{}'.format(sys.version_info.major)
        if std != 'python3':
            raise ValueError('Only Python 3 is supported.')
        self.standard = std

        # module path -> set of names imported from it; filled in by
        # ``_module_format`` while qualified names are printed.
        self.module_imports = defaultdict(set)

        # Known functions and constants handler. Entries passed via the
        # ``user_functions`` / ``user_constants`` settings override the
        # class-level defaults.
        self.known_functions = dict(self._kf, **(settings or {}).get(
            'user_functions', {}))
        self.known_constants = dict(self._kc, **(settings or {}).get(
            'user_constants', {}))

    def _declare_number_const(self, name, value):
        return "%s = %s" % (name, value)

    def _module_format(self, fqn, register=True):
        """Return the printed form of the dotted name ``fqn``.

        If ``register`` is true and ``fqn`` has a module part, the final
        component is recorded in ``module_imports`` under that module path.
        With the ``fully_qualified_modules`` setting disabled, only the final
        dotted component is returned, after cutting ``fqn`` at the first
        ``(`` and ``[``.
        """
        parts = fqn.split('.')
        if register and len(parts) > 1:
            self.module_imports['.'.join(parts[:-1])].add(parts[-1])

        if self._settings['fully_qualified_modules']:
            return fqn
        else:
            return fqn.split('(')[0].split('[')[0].split('.')[-1]

    def _format_code(self, lines):
        return lines

    def _get_statement(self, codestring):
        return "{}".format(codestring)

    def _get_comment(self, text):
        # Two spaces before an inline comment, as recommended by PEP 8.
        return "  # {}".format(text)

    def _expand_fold_binary_op(self, op, args):
        """
        This method expands a fold on binary operations.

        ``functools.reduce`` is an example of a folded operation.

        For example, the expression

        `A + B + C + D`

        is folded into

        `((A + B) + C) + D`
        """
        if len(args) == 1:
            return self._print(args[0])
        else:
            return "%s(%s, %s)" % (
                self._module_format(op),
                self._expand_fold_binary_op(op, args[:-1]),
                self._print(args[-1]),
            )

    def _expand_reduce_binary_op(self, op, args):
        """
        This method expands a reduction on binary operations.

        Notice: this is NOT the same as ``functools.reduce``.

        For example, the expression

        `A + B + C + D`

        is reduced into:

        `(A + B) + (C + D)`
        """
        if len(args) == 1:
            return self._print(args[0])
        else:
            N = len(args)
            Nhalf = N // 2
            # Both recursive calls must forward ``op``. Omitting it makes
            # every call with two or more arguments raise TypeError.
            return "%s(%s, %s)" % (
                self._module_format(op),
                self._expand_reduce_binary_op(op, args[:Nhalf]),
                self._expand_reduce_binary_op(op, args[Nhalf:]),
            )

    def _get_einsum_string(self, subranks, contraction_indices):
        """Build an einsum-style subscript string.

        ``subranks`` holds the rank of each operand. ``contraction_indices``
        holds groups of axis positions, numbered consecutively across all
        operands, that are contracted together. Returns a tuple
        ``(subscripts, letters_free, letters_dum)``:

        * ``subscripts`` is the comma-separated subscripts, one per operand;
        * ``letters_free`` lists the letters of uncontracted axes;
        * ``letters_dum`` lists the distinct letters of contracted axes.
        """
        # NOTE(review): ``_get_letter_generator_for_einsum`` is not defined
        # in this class; presumably supplied by a base class or subclass.
        letters = self._get_letter_generator_for_einsum()
        contraction_string = ""
        counter = 0
        # Every axis of a contraction group is represented by the group's
        # smallest position, so all of them end up sharing one letter.
        d = {j: min(i) for i in contraction_indices for j in i}
        indices = []
        for rank_arg in subranks:
            lindices = []
            for _ in range(rank_arg):
                if counter in d:
                    lindices.append(d[counter])
                else:
                    lindices.append(counter)
                counter += 1
            indices.append(lindices)
        mapping = {}
        letters_free = []
        letters_dum = []
        for i in indices:
            for j in i:
                if j not in mapping:
                    l = next(letters)
                    mapping[j] = l
                else:
                    l = mapping[j]
                contraction_string += l
                if j in d:
                    if l not in letters_dum:
                        letters_dum.append(l)
                else:
                    letters_free.append(l)
            contraction_string += ","
        # Drop the trailing comma added after the last operand.
        contraction_string = contraction_string[:-1]
        return contraction_string, letters_free, letters_dum

    def _print_NaN(self, expr):
        return "float('nan')"

    def _print_Infinity(self, expr):
        return "float('inf')"

    def _print_NegativeInfinity(self, expr):
        return "float('-inf')"

    def _print_ComplexInfinity(self, expr):
        return self._print_NaN(expr)

    def _print_Mod(self, expr):
        PREC = precedence(expr)
        return ('{} % {}'.format(*map(lambda x: self.parenthesize(x, PREC), expr.args)))

    def _print_Piecewise(self, expr):
        """Print a Piecewise as a chain of conditional expressions.

        ``Piecewise((a, c1), (b, True))`` prints as ``((a) if c1 else (b))``.
        Without a final ``True`` condition the chain ends in ``else None``.
        """
        result = []
        i = 0
        for arg in expr.args:
            e = arg.expr
            c = arg.cond
            if i == 0:
                result.append('(')
            result.append('(')
            result.append(self._print(e))
            result.append(')')
            result.append(' if ')
            result.append(self._print(c))
            result.append(' else ')
            i += 1
        # Remove the dangling ' else ' after the last branch.
        result = result[:-1]
        if result[-1] == 'True':
            # Last condition is always true: drop its ' if True' test.
            result = result[:-2]
            result.append(')')
        else:
            result.append(' else None)')
        return ''.join(result)

    def _print_Relational(self, expr):
        """Print ==, !=, <, <=, > and >= as parenthesized Python infix
        comparisons; any other relational is handled by the base class."""
        op = {
            '==': 'equal',
            '!=': 'not_equal',
            '<': 'less',
            '<=': 'less_equal',
            '>': 'greater',
            '>=': 'greater_equal',
        }
        if expr.rel_op in op:
            lhs = self._print(expr.lhs)
            rhs = self._print(expr.rhs)
            return '({lhs} {op} {rhs})'.format(op=expr.rel_op, lhs=lhs, rhs=rhs)
        return super()._print_Relational(expr)

    def _print_ITE(self, expr):
        from sympy.functions.elementary.piecewise import Piecewise
        return self._print(expr.rewrite(Piecewise))

    def _print_Sum(self, expr):
        """Print a Sum as a generator expression fed to ``builtins.sum``;
        each limit ``(i, a, b)`` becomes ``for i in range(a, b+1)``."""
        loops = (
            'for {i} in range({a}, {b}+1)'.format(
                i=self._print(i),
                a=self._print(a),
                b=self._print(b))
            for i, a, b in expr.limits)
        # NOTE(review): ``builtins.sum`` is emitted literally rather than via
        # ``_module_format``, so ``builtins`` is not recorded in
        # ``module_imports``; presumably the consumer provides it.
        return '(builtins.sum({function} {loops}))'.format(
            function=self._print(expr.function),
            loops=' '.join(loops))

    def _print_ImaginaryUnit(self, expr):
        return '1j'

    def _print_KroneckerDelta(self, expr):
        a, b = expr.args

        return '(1 if {a} == {b} else 0)'.format(
            a=self._print(a),
            b=self._print(b)
        )

    def _print_MatrixBase(self, expr):
        # The constructor name is looked up in known_functions (falling back
        # to the class name) and the entries are printed as a nested list.
        name = expr.__class__.__name__
        func = self.known_functions.get(name, name)
        return "%s(%s)" % (func, self._print(expr.tolist()))

    _print_SparseRepMatrix = \
        _print_MutableSparseMatrix = \
        _print_ImmutableSparseMatrix = \
        _print_Matrix = \
        _print_DenseMatrix = \
        _print_MutableDenseMatrix = \
        _print_ImmutableMatrix = \
        _print_ImmutableDenseMatrix = \
        lambda self, expr: self._print_MatrixBase(expr)

    def _indent_codestring(self, codestring):
        return '\n'.join([self.tab + line for line in codestring.split('\n')])

    def _print_FunctionDefinition(self, fd):
        body = '\n'.join(map(lambda arg: self._print(arg), fd.body))
        return "def {name}({parameters}):\n{body}".format(
            name=self._print(fd.name),
            parameters=', '.join([self._print(var.symbol) for var in fd.parameters]),
            body=self._indent_codestring(body)
        )

    def _print_While(self, whl):
        body = '\n'.join(map(lambda arg: self._print(arg), whl.body))
        return "while {cond}:\n{body}".format(
            cond=self._print(whl.condition),
            body=self._indent_codestring(body)
        )

    def _print_Declaration(self, decl):
        return '%s = %s' % (
            self._print(decl.variable.symbol),
            self._print(decl.variable.value)
        )

    def _print_Return(self, ret):
        arg, = ret.args
        return 'return %s' % self._print(arg)

    def _print_Print(self, prnt):
        print_args = ', '.join(map(lambda arg: self._print(arg), prnt.print_args))
        if prnt.format_string != None:  # Must be '!= None', cannot be 'is not None'
            print_args = '{} % ({})'.format(
                self._print(prnt.format_string), print_args)
        if prnt.file != None:  # Must be '!= None', cannot be 'is not None'
            print_args += ', file=%s' % self._print(prnt.file)
        return 'print(%s)' % print_args

    def _print_Stream(self, strm):
        # stdout/stderr map to the ``sys`` streams; other names print as-is.
        if str(strm.name) == 'stdout':
            return self._module_format('sys.stdout')
        elif str(strm.name) == 'stderr':
            return self._module_format('sys.stderr')
        else:
            return self._print(strm.name)

    def _print_NoneToken(self, arg):
        return 'None'

    def _hprint_Pow(self, expr, rational=False, sqrt='math.sqrt'):
        """Printing helper function for ``Pow``

        Notes
        =====

        An exponent of 1/2 (and, for commutative expressions, -1/2) is printed
        through the ``sqrt`` callable unless ``rational`` is true; every other
        power is printed with ``**``.

        Examples
        ========

        >>> from sympy import sqrt
        >>> from sympy.printing.pycode import PythonCodePrinter
        >>> from sympy.abc import x

        Python code printer automatically looks up ``math.sqrt``.

        >>> printer = PythonCodePrinter()
        >>> printer._hprint_Pow(sqrt(x), rational=True)
        'x**(1/2)'
        >>> printer._hprint_Pow(sqrt(x), rational=False)
        'math.sqrt(x)'
        >>> printer._hprint_Pow(1/sqrt(x), rational=True)
        'x**(-1/2)'
        >>> printer._hprint_Pow(1/sqrt(x), rational=False)
        '1/math.sqrt(x)'

        Using sqrt from numpy or mpmath

        >>> printer._hprint_Pow(sqrt(x), sqrt='numpy.sqrt')
        'numpy.sqrt(x)'
        >>> printer._hprint_Pow(sqrt(x), sqrt='mpmath.sqrt')
        'mpmath.sqrt(x)'

        See Also
        ========

        sympy.printing.str.StrPrinter._print_Pow
        """
        PREC = precedence(expr)

        if expr.exp == S.Half and not rational:
            func = self._module_format(sqrt)
            arg = self._print(expr.base)
            return '{func}({arg})'.format(func=func, arg=arg)

        if expr.is_commutative:
            if -expr.exp is S.Half and not rational:
                func = self._module_format(sqrt)
                num = self._print(S.One)
                arg = self._print(expr.base)
                return "{num}/{func}({arg})".format(
                    num=num, func=func, arg=arg)

        base_str = self.parenthesize(expr.base, PREC, strict=False)
        exp_str = self.parenthesize(expr.exp, PREC, strict=False)
        return "{}**{}".format(base_str, exp_str)
class PythonCodePrinter(AbstractPythonCodePrinter):
    """Printer for plain Python code whose known functions and constants
    come from the ``math`` module."""

    def _print_sign(self, e):
        # sign(0) prints as 0.0; any other value goes through math.copysign.
        return '(0.0 if {e} == 0 else {f}(1, {e}))'.format(
            f=self._module_format('math.copysign'), e=self._print(e.args[0]))

    def _print_Not(self, expr):
        PREC = precedence(expr)
        # A separating space is required: concatenating directly would turn
        # ``Not(x)`` into the unrelated identifier ``notx``.
        return self._operators['not'] + ' ' + self.parenthesize(expr.args[0], PREC)

    def _print_Indexed(self, expr):
        # The base is rendered with ``str``; the indices go through _print.
        base = expr.args[0]
        index = expr.args[1:]
        return "{}[{}]".format(str(base), ", ".join([self._print(ind) for ind in index]))

    def _print_Pow(self, expr, rational=False):
        return self._hprint_Pow(expr, rational=rational)

    def _print_Rational(self, expr):
        return '{}/{}'.format(expr.p, expr.q)

    def _print_Half(self, expr):
        return self._print_Rational(expr)

    def _print_frac(self, expr):
        # frac(x) is printed as ``x % 1``.
        from sympy.core.mod import Mod
        return self._print_Mod(Mod(expr.args[0], 1))

    def _print_Symbol(self, expr):
        """Print a symbol name.

        Reserved words are rejected when ``error_on_reserved`` is set and
        otherwise get ``reserved_word_suffix`` appended. Curly braces from
        subscripted names such as ``x_{1}`` are removed.
        """
        name = super()._print_Symbol(expr)

        if name in self.reserved_words:
            if self._settings['error_on_reserved']:
                msg = ('This expression includes the symbol "{}" which is a '
                       'reserved keyword in this language.')
                raise ValueError(msg.format(name))
            return name + self._settings['reserved_word_suffix']
        elif '{' in name:   # Remove curly braces from subscripted variables
            return name.replace('{', '').replace('}', '')
        else:
            return name

    _print_lowergamma = CodePrinter._print_not_supported
    _print_uppergamma = CodePrinter._print_not_supported
    _print_fresnelc = CodePrinter._print_not_supported
    _print_fresnels = CodePrinter._print_not_supported
# Route every known function and constant of the ``math``-based printer to
# the table-driven printing helpers. The temporaries are removed afterwards;
# ``k`` is left bound exactly as a plain pair of loops would leave it.
for _names, _method in ((PythonCodePrinter._kf, _print_known_func),
                        (_known_constants_math, _print_known_const)):
    for k in _names:
        setattr(PythonCodePrinter, '_print_' + k, _method)
del _names, _method
def pycode(expr, **settings):
    """Convert a SymPy expression into a string of Python code.

    All keyword arguments are handed to ``PythonCodePrinter`` as printer
    settings.

    Parameters
    ==========

    expr : Expr
        A SymPy expression.
    fully_qualified_modules : bool
        Whether or not to write out full module names of functions
        (``math.sin`` vs. ``sin``). default: ``True``.
    standard : str or None, optional
        Only 'python3' (default) is supported.
        This parameter may be removed in the future.

    Examples
    ========

    >>> from sympy import pycode, tan, Symbol
    >>> pycode(tan(Symbol('x')) + 1)
    'math.tan(x) + 1'

    """
    printer = PythonCodePrinter(settings)
    return printer.doprint(expr)
# ``math``-table entries that are left out of the mpmath table; MpmathPrinter
# prints these through its own ``_print_log1p`` / ``_print_log2`` methods.
_not_in_mpmath = ['log1p', 'log2']
_in_mpmath = [
    (name, target) for name, target in _known_functions_math.items()
    if name not in _not_in_mpmath
]
# Shared ``math`` names plus mpmath-specific special functions. The
# ``loggamma`` entry overrides the ``lgamma`` spelling from the math table
# while keeping its position in the dict.
_known_functions_mpmath = {
    **dict(_in_mpmath),
    'beta': 'beta',
    'frac': 'frac',
    'fresnelc': 'fresnelc',
    'fresnels': 'fresnels',
    'sign': 'sign',
    'loggamma': 'loggamma',
    'hyper': 'hyper',
    'meijerg': 'meijerg',
    'besselj': 'besselj',
    'bessely': 'bessely',
    'besseli': 'besseli',
    'besselk': 'besselk',
}
# SymPy constant class name -> attribute of the ``mpmath`` module.
_known_constants_mpmath = dict(
    Exp1='e',
    Pi='pi',
    GoldenRatio='phi',
    EulerGamma='euler',
    Catalan='catalan',
    NaN='nan',
    Infinity='inf',
    NegativeInfinity='ninf',
)
def _unpack_integral_limits(integral_expr):
    """Split the limits of a definite integral (helper for ``_print_Integral``).

    Returns a pair ``(integration_vars, limits)``:

    * ``integration_vars`` is a list of the variables of integration;
    * ``limits`` is a list of ``(lower, upper)`` tuples.

    Raises ``NotImplementedError`` as soon as a limit is not a 3-element
    ``(var, lower, upper)`` entry, i.e. for indefinite integrals.
    """
    integration_vars = []
    limits = []
    for bounds in integral_expr.limits:
        if len(bounds) != 3:
            raise NotImplementedError("Only definite integrals are supported")
        var, lower, upper = bounds
        integration_vars.append(var)
        limits.append((lower, upper))
    return integration_vars, limits
class MpmathPrinter(PythonCodePrinter):
    """Printer targeting mpmath that keeps Float values at full precision.

    Known functions and constants resolve to the ``mpmath`` namespace.
    Floats and Rationals are emitted as ``mpmath.mpf`` values rather than
    Python floats.
    """
    printmethod = "_mpmathcode"

    language = "Python with mpmath"

    _kf = {
        **_known_functions,
        **{name: 'mpmath.' + target
           for name, target in _known_functions_mpmath.items()},
    }

    _kc = {name: 'mpmath.' + target
           for name, target in _known_constants_mpmath.items()}

    def _print_Float(self, e):
        # The generated code does not touch ``mpmath.mp.dps``. Whoever calls
        # the lambdified function is assumed to have set a precision that is
        # high enough for the Floats in the expression.
        # ``int()`` converts each component of ``_mpf_`` to a plain int, which
        # strips the ``mpz`` type when gmpy is installed.
        mpf_tuple = tuple(int(component) for component in e._mpf_)
        return '%s(%s)' % (self._module_format('mpmath.mpf'), str(mpf_tuple))

    def _print_Rational(self, e):
        mpf = self._module_format('mpmath.mpf')
        denominator = self._print(e.q)
        numerator = self._print(e.p)
        return '%s(%s)/%s(%s)' % (mpf, numerator, mpf, denominator)

    def _print_Half(self, e):
        return self._print_Rational(e)

    def _print_uppergamma(self, e):
        # Printed as gammainc(s, x, inf): the integral from x to infinity.
        gammainc = self._module_format('mpmath.gammainc')
        s = self._print(e.args[0])
        x = self._print(e.args[1])
        infinity = self._module_format('mpmath.inf')
        return '%s(%s, %s, %s)' % (gammainc, s, x, infinity)

    def _print_lowergamma(self, e):
        # Printed as gammainc(s, 0, x): the integral from 0 to x.
        gammainc = self._module_format('mpmath.gammainc')
        s = self._print(e.args[0])
        x = self._print(e.args[1])
        return '%s(%s, 0, %s)' % (gammainc, s, x)

    def _print_log2(self, e):
        # log2(x) is written as log(x)/log(2).
        log = self._module_format('mpmath.log')
        return '%s(%s)/%s(2)' % (log, self._print(e.args[0]), log)

    def _print_log1p(self, e):
        # log1p(x) is written as log(x+1).
        log = self._module_format('mpmath.log')
        return '%s(%s+1)' % (log, self._print(e.args[0]))

    def _print_Pow(self, expr, rational=False):
        return self._hprint_Pow(expr, rational=rational, sqrt='mpmath.sqrt')

    def _print_Integral(self, e):
        # Emitted as mpmath.quad(lambda vars: integrand, (lo, hi), ...).
        integration_vars, limits = _unpack_integral_limits(e)
        quad = self._module_format("mpmath.quad")
        params = ", ".join(self._print(var) for var in integration_vars)
        integrand = self._print(e.args[0])
        ranges = ", ".join(
            "(%s, %s)" % (self._print(lower), self._print(upper))
            for lower, upper in limits)
        return "%s(lambda %s: %s, %s)" % (quad, params, integrand, ranges)
# Route every known mpmath function and constant to the table-driven
# printing helpers. The temporaries are removed afterwards; ``k`` is left
# bound exactly as a plain pair of loops would leave it.
for _names, _method in ((MpmathPrinter._kf, _print_known_func),
                        (_known_constants_mpmath, _print_known_const)):
    for k in _names:
        setattr(MpmathPrinter, '_print_' + k, _method)
del _names, _method
class SymPyPrinter(AbstractPythonCodePrinter):
    """Printer whose output calls functions through the module that defines
    their class."""

    language = "Python with SymPy"

    def _print_Function(self, expr):
        # Qualify the callee with its class's ``__module__``. When that is
        # empty or None, the bare class name is used.
        module = expr.func.__module__ or ''
        name = expr.func.__name__
        qualified = module + '.' + name if module else name
        callee = self._module_format(qualified)
        printed_args = [self._print(arg) for arg in expr.args]
        return '%s(%s)' % (callee, ', '.join(printed_args))

    def _print_Pow(self, expr, rational=False):
        return self._hprint_Pow(expr, rational=rational, sqrt='sympy.sqrt')
|