123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210 |
- from .pycode import (
- PythonCodePrinter,
- MpmathPrinter,
- )
- from .numpy import NumPyPrinter # NumPyPrinter is imported for backward compatibility
- from sympy.core.sorting import default_sort_key
# Public names of this module. Each name is listed exactly once.
__all__ = [
    'PythonCodePrinter',
    'MpmathPrinter',  # MpmathPrinter is published for backward compatibility
    'NumPyPrinter',
    'LambdaPrinter',
    'IntervalPrinter',
    'lambdarepr',
]
class LambdaPrinter(PythonCodePrinter):
    """
    This printer converts expressions into strings that can be used by
    lambdify.
    """
    printmethod = "_lambdacode"

    def _join_logical_args(self, expr, operator, empty):
        """
        Print the arguments of ``expr`` joined by ``operator``.

        The arguments are printed in ``default_sort_key`` order, each one
        wrapped in its own parentheses, and the joined result is wrapped in
        an outer pair of parentheses, e.g. ``((a) and (b))``.

        ``empty`` is returned when ``expr`` has no arguments, so that the
        result is always a well-formed Python expression.
        """
        operands = [
            '(' + self._print(arg) + ')'
            for arg in sorted(expr.args, key=default_sort_key)
        ]
        if not operands:
            return empty
        return '(' + operator.join(operands) + ')'

    def _print_And(self, expr):
        """Print a conjunction as ``((a) and (b) ...)``."""
        # An empty conjunction is vacuously true.
        return self._join_logical_args(expr, ' and ', 'True')

    def _print_Or(self, expr):
        """Print a disjunction as ``((a) or (b) ...)``."""
        # An empty disjunction is false.
        return self._join_logical_args(expr, ' or ', 'False')

    def _print_Not(self, expr):
        """Print a negation of the first argument as ``(not (a))``."""
        return '(not (' + self._print(expr.args[0]) + '))'

    def _print_BooleanTrue(self, expr):
        """Print the Python literal ``True``."""
        return "True"

    def _print_BooleanFalse(self, expr):
        """Print the Python literal ``False``."""
        return "False"

    def _print_ITE(self, expr):
        """
        Print an if-then-else as a Python conditional expression.

        ``expr.args`` is read as ``(condition, value_if_true,
        value_if_false)`` and printed as ``((t) if (c) else (f))``.
        """
        # Printed in the order the pieces appear in the output string.
        if_true = self._print(expr.args[1])
        condition = self._print(expr.args[0])
        if_false = self._print(expr.args[2])
        return '((' + if_true + ') if (' + condition + ') else (' + if_false + '))'

    def _print_NumberSymbol(self, expr):
        """Print a number symbol using its ``str`` form."""
        return str(expr)

    def _print_Pow(self, expr, **kwargs):
        """Print a power using the implementation above PythonCodePrinter."""
        # XXX Temporary workaround. Should Python math printer be
        # isolated from PythonCodePrinter?
        # ``super(PythonCodePrinter, self)`` deliberately starts the lookup
        # *after* PythonCodePrinter in the MRO, skipping its ``_print_Pow``.
        return super(PythonCodePrinter, self)._print_Pow(expr, **kwargs)
- # numexpr works by altering the string passed to numexpr.evaluate
- # rather than by populating a namespace. Thus a special printer...
class NumExprPrinter(LambdaPrinter):
    """
    Printer whose ``doprint`` output is an ``evaluate('...', truediv=True)``
    call string wrapping the printed expression.
    """
    printmethod = "_numexprcode"

    # Maps the SymPy function name (key) to the numexpr function name (value).
    # Functions not appearing in this dict will raise a TypeError, unless
    # they carry an ``_imp_`` implementation (see ``_print_Function``).
    _numexpr_functions = {
        'sin': 'sin',
        'cos': 'cos',
        'tan': 'tan',
        'asin': 'arcsin',
        'acos': 'arccos',
        'atan': 'arctan',
        'atan2': 'arctan2',
        'sinh': 'sinh',
        'cosh': 'cosh',
        'tanh': 'tanh',
        'asinh': 'arcsinh',
        'acosh': 'arccosh',
        'atanh': 'arctanh',
        'ln': 'log',
        'log': 'log',
        'exp': 'exp',
        'sqrt': 'sqrt',
        'Abs': 'abs',
        'conjugate': 'conj',
        'im': 'imag',
        're': 'real',
        'where': 'where',
        'complex': 'complex',
        'contains': 'contains',
    }

    def _print_ImaginaryUnit(self, expr):
        """Print the imaginary unit as the Python literal ``1j``."""
        return '1j'

    def _print_seq(self, seq, delimiter=', '):
        """
        Print every item of ``seq`` and join the results with ``delimiter``.

        An empty ``seq`` yields the empty string.
        """
        # simplified _print_seq taken from pretty.py
        return delimiter.join([self._print(item) for item in seq])

    def _print_Function(self, e):
        """
        Print a function application using the numexpr function name.

        Functions missing from ``_numexpr_functions`` are printed through
        their ``_imp_`` implementation when they have one; otherwise a
        ``TypeError`` is raised.
        """
        sympy_name = e.func.__name__
        numexpr_name = self._numexpr_functions.get(sympy_name)
        if numexpr_name is not None:
            return "%s(%s)" % (numexpr_name, self._print_seq(e.args))
        # check for implemented_function
        if hasattr(e, '_imp_'):
            return "(%s)" % self._print(e._imp_(*e.args))
        raise TypeError("numexpr does not support function '%s'" %
                        sympy_name)

    def _print_Piecewise(self, expr):
        "Print a Piecewise as a chain of nested ``where(cond, expr, ...)`` calls."
        printed_exprs = [self._print(piece.expr) for piece in expr.args]
        printed_conds = [self._print(piece.cond) for piece in expr.args]
        # A piece whose condition prints as 'True' acts as the default value
        # and ends the chain: any pieces after it are never emitted, so it
        # only behaves as a plain default when it is the last piece.
        wheres = []
        for cond_str, expr_str in zip(printed_conds, printed_exprs):
            if cond_str == 'True':
                fallback = expr_str
                break
            wheres.append('where(%s, %s, ' % (cond_str, expr_str))
        else:
            # No default piece was found.
            # simplest way to put a nan but raises
            # 'RuntimeWarning: invalid value encountered in log'
            fallback = 'log(-1)'
        # Every emitted 'where(' needs exactly one closing parenthesis.
        return ''.join(wheres) + fallback + ')' * len(wheres)

    def _print_ITE(self, expr):
        """Print an if-then-else by first rewriting it as a Piecewise."""
        from sympy.functions.elementary.piecewise import Piecewise
        as_piecewise = expr.rewrite(Piecewise)
        return self._print(as_piecewise)

    def blacklisted(self, expr):
        """Reject ``expr``: always raises ``TypeError`` naming its class."""
        raise TypeError("numexpr cannot be used with %s" %
                        expr.__class__.__name__)

    # blacklist all Matrix printing
    _print_SparseRepMatrix = blacklisted
    _print_MutableSparseMatrix = blacklisted
    _print_ImmutableSparseMatrix = blacklisted
    _print_Matrix = blacklisted
    _print_DenseMatrix = blacklisted
    _print_MutableDenseMatrix = blacklisted
    _print_ImmutableMatrix = blacklisted
    _print_ImmutableDenseMatrix = blacklisted

    # blacklist some Python expressions
    _print_list = blacklisted
    _print_tuple = blacklisted
    _print_Tuple = blacklisted
    _print_dict = blacklisted
    _print_Dict = blacklisted

    def doprint(self, expr):
        """Print ``expr`` and wrap the result in an ``evaluate(...)`` call string."""
        printed = super().doprint(expr)
        return "evaluate('%s', truediv=True)" % printed
class IntervalPrinter(MpmathPrinter, LambdaPrinter):
    """Use ``lambda`` printer but print numbers as ``mpi`` intervals. """

    def _print_Integer(self, expr):
        """Print an integer wrapped as ``mpi('<integer>')``."""
        # Lookup starts after PythonCodePrinter in the MRO.
        plain = super(PythonCodePrinter, self)._print_Integer(expr)
        return "mpi('%s')" % plain

    def _print_Rational(self, expr):
        """Print a rational wrapped as ``mpi('<rational>')``."""
        plain = super(PythonCodePrinter, self)._print_Rational(expr)
        return "mpi('%s')" % plain

    def _print_Half(self, expr):
        """Print one half the same way as any other rational."""
        plain = super(PythonCodePrinter, self)._print_Rational(expr)
        return "mpi('%s')" % plain

    def _print_Pow(self, expr):
        """Print a power with ``rational=True``, skipping MpmathPrinter's version."""
        next_printer = super(MpmathPrinter, self)
        return next_printer._print_Pow(expr, rational=True)
# Give NumExprPrinter a ``_print_<name>`` attribute for every SymPy function
# name in its table, each one bound to the generic ``_print_Function``.
for k in NumExprPrinter._numexpr_functions:
    setattr(
        NumExprPrinter,
        '_print_%s' % k,
        NumExprPrinter._print_Function,
    )
def lambdarepr(expr, **settings):
    """
    Returns a string usable for lambdifying.

    The keyword ``settings`` are passed as a dict to ``LambdaPrinter``,
    whose ``doprint`` produces the returned string for ``expr``.
    """
    printer = LambdaPrinter(settings)
    return printer.doprint(expr)
|