123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353 |
- """
- Mathematica code printer
- """
- from typing import Any, Dict as tDict, Set as tSet, Tuple as tTuple
- from sympy.core import Basic, Expr, Float
- from sympy.core.sorting import default_sort_key
- from sympy.printing.codeprinter import CodePrinter
- from sympy.printing.precedence import precedence
# Mapping from SymPy function class names to Mathematica function names.
# Used in MCodePrinter._print_Function(self): each value is a list of
# ``(condition, mathematica_name)`` pairs; the first condition that accepts
# the expression's arguments selects the Mathematica name.  Single-argument
# lambdas (``lambda x``) deliberately reject calls with any other arity;
# ``lambda *x`` accepts any number of arguments.
known_functions = {
    # Exponential / logarithm
    "exp": [(lambda x: True, "Exp")],
    "log": [(lambda x: True, "Log")],
    # Trigonometric and inverse trigonometric
    "sin": [(lambda x: True, "Sin")],
    "cos": [(lambda x: True, "Cos")],
    "tan": [(lambda x: True, "Tan")],
    "cot": [(lambda x: True, "Cot")],
    "sec": [(lambda x: True, "Sec")],
    "csc": [(lambda x: True, "Csc")],
    "asin": [(lambda x: True, "ArcSin")],
    "acos": [(lambda x: True, "ArcCos")],
    "atan": [(lambda x: True, "ArcTan")],
    "acot": [(lambda x: True, "ArcCot")],
    "asec": [(lambda x: True, "ArcSec")],
    "acsc": [(lambda x: True, "ArcCsc")],
    "atan2": [(lambda *x: True, "ArcTan")],
    # Hyperbolic and inverse hyperbolic
    "sinh": [(lambda x: True, "Sinh")],
    "cosh": [(lambda x: True, "Cosh")],
    "tanh": [(lambda x: True, "Tanh")],
    "coth": [(lambda x: True, "Coth")],
    "sech": [(lambda x: True, "Sech")],
    "csch": [(lambda x: True, "Csch")],
    "asinh": [(lambda x: True, "ArcSinh")],
    "acosh": [(lambda x: True, "ArcCosh")],
    "atanh": [(lambda x: True, "ArcTanh")],
    "acoth": [(lambda x: True, "ArcCoth")],
    "asech": [(lambda x: True, "ArcSech")],
    "acsch": [(lambda x: True, "ArcCsch")],
    "sinc": [(lambda x: True, "Sinc")],
    "conjugate": [(lambda x: True, "Conjugate")],
    "Max": [(lambda *x: True, "Max")],
    "Min": [(lambda *x: True, "Min")],
    # Error functions
    "erf": [(lambda x: True, "Erf")],
    "erf2": [(lambda *x: True, "Erf")],
    "erfc": [(lambda x: True, "Erfc")],
    "erfi": [(lambda x: True, "Erfi")],
    "erfinv": [(lambda x: True, "InverseErf")],
    "erfcinv": [(lambda x: True, "InverseErfc")],
    "erf2inv": [(lambda *x: True, "InverseErf")],
    # Exponential, trigonometric and logarithmic integrals
    "expint": [(lambda *x: True, "ExpIntegralE")],
    "Ei": [(lambda x: True, "ExpIntegralEi")],
    "fresnelc": [(lambda x: True, "FresnelC")],
    "fresnels": [(lambda x: True, "FresnelS")],
    # Gamma and related functions
    "gamma": [(lambda x: True, "Gamma")],
    "uppergamma": [(lambda *x: True, "Gamma")],
    "polygamma": [(lambda *x: True, "PolyGamma")],
    "loggamma": [(lambda x: True, "LogGamma")],
    "beta": [(lambda *x: True, "Beta")],
    "Ci": [(lambda x: True, "CosIntegral")],
    "Si": [(lambda x: True, "SinIntegral")],
    "Chi": [(lambda x: True, "CoshIntegral")],
    "Shi": [(lambda x: True, "SinhIntegral")],
    "li": [(lambda x: True, "LogIntegral")],
    # Combinatorial functions
    "factorial": [(lambda x: True, "Factorial")],
    "factorial2": [(lambda x: True, "Factorial2")],
    "subfactorial": [(lambda x: True, "Subfactorial")],
    "catalan": [(lambda x: True, "CatalanNumber")],
    "harmonic": [(lambda *x: True, "HarmonicNumber")],
    "lucas": [(lambda x: True, "LucasL")],
    "RisingFactorial": [(lambda *x: True, "Pochhammer")],
    "FallingFactorial": [(lambda *x: True, "FactorialPower")],
    # Orthogonal polynomials
    "laguerre": [(lambda *x: True, "LaguerreL")],
    "assoc_laguerre": [(lambda *x: True, "LaguerreL")],
    "hermite": [(lambda *x: True, "HermiteH")],
    "jacobi": [(lambda *x: True, "JacobiP")],
    "gegenbauer": [(lambda *x: True, "GegenbauerC")],
    "chebyshevt": [(lambda *x: True, "ChebyshevT")],
    "chebyshevu": [(lambda *x: True, "ChebyshevU")],
    "legendre": [(lambda *x: True, "LegendreP")],
    "assoc_legendre": [(lambda *x: True, "LegendreP")],
    # Mathieu functions
    "mathieuc": [(lambda *x: True, "MathieuC")],
    "mathieus": [(lambda *x: True, "MathieuS")],
    "mathieucprime": [(lambda *x: True, "MathieuCPrime")],
    "mathieusprime": [(lambda *x: True, "MathieuSPrime")],
    # stieltjes(n) and the generalized stieltjes(n, a) both map to
    # StieltjesGamma, so the condition must accept one or two arguments.
    "stieltjes": [(lambda *x: True, "StieltjesGamma")],
    # Elliptic integrals.  elliptic_f is the incomplete integral of the
    # first kind, which Mathematica names EllipticF (not EllipticE).
    "elliptic_e": [(lambda *x: True, "EllipticE")],
    "elliptic_f": [(lambda *x: True, "EllipticF")],
    "elliptic_k": [(lambda x: True, "EllipticK")],
    "elliptic_pi": [(lambda *x: True, "EllipticPi")],
    # Zeta-type functions
    "zeta": [(lambda *x: True, "Zeta")],
    "dirichlet_eta": [(lambda x: True, "DirichletEta")],
    "riemann_xi": [(lambda x: True, "RiemannXi")],
    # Bessel-type and Airy functions
    "besseli": [(lambda *x: True, "BesselI")],
    "besselj": [(lambda *x: True, "BesselJ")],
    "besselk": [(lambda *x: True, "BesselK")],
    "bessely": [(lambda *x: True, "BesselY")],
    "hankel1": [(lambda *x: True, "HankelH1")],
    "hankel2": [(lambda *x: True, "HankelH2")],
    "airyai": [(lambda x: True, "AiryAi")],
    "airybi": [(lambda x: True, "AiryBi")],
    "airyaiprime": [(lambda x: True, "AiryAiPrime")],
    "airybiprime": [(lambda x: True, "AiryBiPrime")],
    "polylog": [(lambda *x: True, "PolyLog")],
    "lerchphi": [(lambda *x: True, "LerchPhi")],
    # Number theory
    "gcd": [(lambda *x: True, "GCD")],
    "lcm": [(lambda *x: True, "LCM")],
    "jn": [(lambda *x: True, "SphericalBesselJ")],
    "yn": [(lambda *x: True, "SphericalBesselY")],
    # Hypergeometric-type functions
    "hyper": [(lambda *x: True, "HypergeometricPFQ")],
    "meijerg": [(lambda *x: True, "MeijerG")],
    "appellf1": [(lambda *x: True, "AppellF1")],
    # Distributions / special symbols
    "DiracDelta": [(lambda x: True, "DiracDelta")],
    "Heaviside": [(lambda x: True, "HeavisideTheta")],
    "KroneckerDelta": [(lambda *x: True, "KroneckerDelta")],
    "sqrt": [(lambda x: True, "Sqrt")],  # For automatic rewrites
}
class MCodePrinter(CodePrinter):
    """Printer that converts SymPy expressions to Wolfram Language
    (Mathematica) code strings.

    Function names are translated through the module-level
    ``known_functions`` table, which the ``user_functions`` setting can
    extend or override.
    """
    printmethod = "_mcode"
    language = "Wolfram Language"

    _default_settings = {
        'order': None,
        'full_prec': 'auto',
        'precision': 15,
        'user_functions': {},
        'human': True,
        'allow_unknown_functions': False,
    }  # type: tDict[str, Any]

    _number_symbols = set()  # type: tSet[tTuple[Expr, Float]]
    _not_supported = set()  # type: tSet[Basic]

    def __init__(self, settings=None):
        """Register function mappings supplied by the user.

        ``settings['user_functions']`` may map a SymPy function name either
        to a plain Mathematica name (accepted for any arguments) or to a
        list of ``(condition, name)`` pairs in the ``known_functions``
        format.
        """
        # ``None`` instead of a shared mutable ``{}`` default.
        if settings is None:
            settings = {}
        CodePrinter.__init__(self, settings)
        self.known_functions = dict(known_functions)
        # Normalise bare names into the ``[(condition, name)]`` form.
        userfuncs = {
            fname: (mapping if isinstance(mapping, list)
                    else [(lambda *x: True, mapping)])
            for fname, mapping in settings.get('user_functions', {}).items()
        }
        self.known_functions.update(userfuncs)

    def _format_code(self, lines):
        return lines

    def _print_Pow(self, expr):
        PREC = precedence(expr)
        return '%s^%s' % (self.parenthesize(expr.base, PREC),
                          self.parenthesize(expr.exp, PREC))

    def _print_Mul(self, expr):
        """Print commutative factors with ``*`` and non-commutative factors
        joined by ``**`` (Mathematica's NonCommutativeMultiply)."""
        PREC = precedence(expr)
        c, nc = expr.args_cnc()
        res = super()._print_Mul(expr.func(*c))
        if nc:
            res += '*'
            res += '**'.join(self.parenthesize(a, PREC) for a in nc)
        return res

    def _print_Relational(self, expr):
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        return "{} {} {}".format(lhs_code, op, rhs_code)

    # Primitive numbers
    def _print_Zero(self, expr):
        return '0'

    def _print_One(self, expr):
        return '1'

    def _print_NegativeOne(self, expr):
        return '-1'

    def _print_Half(self, expr):
        return '1/2'

    def _print_ImaginaryUnit(self, expr):
        return 'I'

    # Infinity and invalid numbers
    def _print_Infinity(self, expr):
        return 'Infinity'

    def _print_NegativeInfinity(self, expr):
        return '-Infinity'

    def _print_ComplexInfinity(self, expr):
        return 'ComplexInfinity'

    def _print_NaN(self, expr):
        return 'Indeterminate'

    # Mathematical constants
    def _print_Exp1(self, expr):
        return 'E'

    def _print_Pi(self, expr):
        return 'Pi'

    def _print_GoldenRatio(self, expr):
        return 'GoldenRatio'

    def _print_TribonacciConstant(self, expr):
        # Print the closed form obtained from ``expand(func=True)``.
        expanded = expr.expand(func=True)
        PREC = precedence(expr)
        return self.parenthesize(expanded, PREC)

    def _print_EulerGamma(self, expr):
        return 'EulerGamma'

    def _print_Catalan(self, expr):
        return 'Catalan'

    def _print_list(self, expr):
        """Print a Python/SymPy sequence as a Mathematica list ``{...}``."""
        return '{' + ', '.join(self.doprint(a) for a in expr) + '}'
    _print_tuple = _print_list
    _print_Tuple = _print_list

    def _print_ImmutableDenseMatrix(self, expr):
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseMatrix(self, expr):
        """Print as ``SparseArray[{{i, j} -> v, ...}, {rows, cols}]``."""
        def print_rule(pos, val):
            # Shift 0-based Python indices to 1-based Mathematica indices.
            return '{} -> {}'.format(
                self.doprint((pos[0]+1, pos[1]+1)), self.doprint(val))

        def print_data():
            items = sorted(expr.todok().items(), key=default_sort_key)
            return '{' + \
                ', '.join(print_rule(k, v) for k, v in items) + \
                '}'

        def print_dims():
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_ImmutableDenseNDimArray(self, expr):
        return self.doprint(expr.tolist())

    def _print_ImmutableSparseNDimArray(self, expr):
        def print_string_list(string_list):
            return '{' + ', '.join(a for a in string_list) + '}'

        def to_mathematica_index(*args):
            """Helper function to change Python style indexing to
            Mathematica indexing.

            Python indexing (0, 1 ... n-1)
            -> Mathematica indexing (1, 2 ... n)
            """
            return tuple(i + 1 for i in args)

        def print_rule(pos, val):
            """Helper function to print a rule of Mathematica"""
            return '{} -> {}'.format(self.doprint(pos), self.doprint(val))

        def print_data():
            """Helper function to print data part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html

            ``data`` must be formatted with rule.
            """
            return print_string_list(
                [print_rule(
                    to_mathematica_index(*(expr._get_tuple_index(key))),
                    value)
                 for key, value in sorted(expr._sparse_array.items())]
            )

        def print_dims():
            """Helper function to print dimensions part of Mathematica
            sparse array.

            It uses the fourth notation ``SparseArray[data,{d1,d2,...}]``
            from
            https://reference.wolfram.com/language/ref/SparseArray.html
            """
            return self.doprint(expr.shape)

        return 'SparseArray[{}, {}]'.format(print_data(), print_dims())

    def _print_Function(self, expr):
        """Print an applied function as ``Name[arg1, arg2, ...]``.

        Known names are translated via ``self.known_functions``; functions
        with a registered rewrite are rewritten when the target is
        printable; anything else keeps its SymPy class name.
        """
        if expr.func.__name__ in self.known_functions:
            cond_mfunc = self.known_functions[expr.func.__name__]
            # The first (condition, name) pair accepting the args wins.
            for cond, mfunc in cond_mfunc:
                if cond(*expr.args):
                    return "%s[%s]" % (mfunc, self.stringify(expr.args, ", "))
        elif expr.func.__name__ in self._rewriteable_functions:
            # Simple rewrite to supported function possible
            target_f, required_fs = self._rewriteable_functions[expr.func.__name__]
            if self._can_print(target_f) and all(self._can_print(f) for f in required_fs):
                return self._print(expr.rewrite(target_f))
        return expr.func.__name__ + "[%s]" % self.stringify(expr.args, ", ")

    _print_MinMaxBase = _print_Function

    def _print_atan2(self, expr):
        """Print ``atan2(y, x)`` as ``ArcTan[x, y]``.

        Mathematica's two-argument ``ArcTan[x, y]`` takes the abscissa
        first, whereas SymPy's ``atan2(y, x)`` takes the ordinate first, so
        the arguments are swapped.  A user-supplied ``atan2`` mapping is
        honoured unchanged through ``_print_Function``.
        """
        if 'atan2' in self._settings.get('user_functions', {}):
            return self._print_Function(expr)
        y, x = expr.args
        return "ArcTan[{}, {}]".format(self._print(x), self._print(y))

    def _print_LambertW(self, expr):
        """Print ``LambertW(x)`` / ``LambertW(x, k)`` as ``ProductLog``;
        Mathematica puts the branch index ``k`` first."""
        if len(expr.args) == 1:
            return "ProductLog[{}]".format(self._print(expr.args[0]))
        return "ProductLog[{}, {}]".format(
            self._print(expr.args[1]), self._print(expr.args[0]))

    def _print_Integral(self, expr):
        """Print an unevaluated integral wrapped in ``Hold[Integrate[...]]``."""
        # A single variable without bounds prints as Integrate[f, x].
        if len(expr.variables) == 1 and not expr.limits[0][1:]:
            args = [expr.args[0], expr.variables[0]]
        else:
            args = expr.args
        return "Hold[Integrate[" + ', '.join(self.doprint(a) for a in args) + "]]"

    def _print_Sum(self, expr):
        return "Hold[Sum[" + ', '.join(self.doprint(a) for a in expr.args) + "]]"

    def _print_Derivative(self, expr):
        """Print an unevaluated derivative wrapped in ``Hold[D[...]]``."""
        dexpr = expr.expr
        # First-order variables print bare; higher orders as (var, count).
        dvars = [i[0] if i[1] == 1 else i for i in expr.variable_count]
        return "Hold[D[" + ', '.join(self.doprint(a) for a in [dexpr] + dvars) + "]]"

    def _get_comment(self, text):
        return "(* {} *)".format(text)
def mathematica_code(expr, **settings):
    r"""Convert ``expr`` into a string of Wolfram Mathematica code.

    Parameters
    ==========

    expr
        The SymPy expression to print.
    **settings
        Options forwarded verbatim to ``MCodePrinter`` (for example
        ``user_functions``).

    Examples
    ========

    >>> from sympy import mathematica_code as mcode, symbols, sin
    >>> x = symbols('x')
    >>> mcode(sin(x).series(x).removeO())
    '(1/120)*x^5 - 1/6*x^3 + x'
    """
    printer = MCodePrinter(settings)
    return printer.doprint(expr)
|