123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293 |
- import keyword as kw
- import sympy
- from .repr import ReprPrinter
- from .str import StrPrinter
# Expression classes that PythonPrinter prints with StrPrinter's method
# (StrPrinter._print_<name>) instead of ReprPrinter's.
STRPRINT = (
    "Add",
    "Infinity",
    "Integer",
    "Mul",
    "NegativeInfinity",
    "Pow",
    "Zero",
)
class PythonPrinter(ReprPrinter, StrPrinter):
    """A printer which converts an expression into its Python interpretation.

    While printing, an instance records in first-seen order:

    * ``self.symbols``: the names of the symbols it printed;
    * ``self.functions``: the names of functions that are not attributes of
      the ``sympy`` module.

    Source code that defines these names can then be generated.
    """

    def __init__(self, settings=None):
        super().__init__(settings)
        self.symbols = []
        self.functions = []

        # ReprPrinter comes first among the bases, so any print method it
        # defines would normally win. For each class named in STRPRINT, copy
        # StrPrinter's method onto PythonPrinter so that form is used instead.
        for class_name in STRPRINT:
            method_name = "_print_" + class_name
            setattr(PythonPrinter, method_name,
                    getattr(StrPrinter, method_name))

    def _print_Function(self, expr):
        # Names found on the sympy module need no definition in the generated
        # code; every other function name is remembered once.
        name = expr.func.__name__
        if not (hasattr(sympy, name) or name in self.functions):
            self.functions.append(name)
        return StrPrinter._print_Function(self, expr)

    # Record each printed symbol name so that python() can emit a
    # Symbol(...) definition for it ahead of the expression itself.
    def _print_Symbol(self, expr):
        seen = self.symbols
        name = self._str(expr)
        if name not in seen:
            seen.append(name)
        return StrPrinter._print_Symbol(self, expr)

    def _print_module(self, expr):
        # Module objects cannot be represented in the generated source.
        raise ValueError('Modules in the expression are unacceptable')
def _fresh_name(candidate, original, taken):
    """Return the variable name to use in generated code for *candidate*.

    *original* is the name as it appears in the printed expression, and
    *taken* holds every name already in use.

    If *candidate* equals *original* and is not a Python keyword, it is
    returned unchanged. Otherwise underscores are appended until the result
    is neither a keyword nor a member of *taken*. This way a renamed variable
    can never shadow another symbol, function or renamed variable.
    """
    if candidate != original or kw.iskeyword(candidate):
        while kw.iskeyword(candidate) or candidate in taken:
            candidate += "_"
    return candidate


def python(expr, **settings):
    """Return Python interpretation of passed expression
    (can be passed to the exec() function without any modifications)

    The returned source contains, in order:

    * a ``name = Symbol('...')`` line for every symbol;
    * a ``name = Function('...')`` line for every function that is not an
      attribute of the ``sympy`` module;
    * ``e = <expression>``.

    Curly braces are stripped from symbol names, and names that are Python
    keywords (or that would collide after renaming) get trailing
    underscores. The printed expression is rewritten to use the new names.
    ``settings`` are passed on to ``PythonPrinter``.
    """
    printer = PythonPrinter(settings)
    exprp = printer.doprint(expr)

    # Names a renamed variable must avoid: every name seen while printing,
    # plus every variable name assigned below.
    taken = set(printer.symbols) | set(printer.functions)
    lines = []
    symbol_renames = {}    # printed symbol name -> variable name
    function_renames = {}  # function name -> variable name

    # Emit a definition for every symbol and function found while printing
    for symbolname in printer.symbols:
        candidate = symbolname
        # Remove curly braces from subscripted variables
        if '{' in symbolname:
            candidate = symbolname.replace('{', '').replace('}', '')
        newsymbolname = _fresh_name(candidate, symbolname, taken)
        taken.add(newsymbolname)
        if newsymbolname != symbolname:
            symbol_renames[symbolname] = newsymbolname
        # %r quotes the name correctly even if it contains quotes/backslashes
        lines.append('%s = Symbol(%r)' % (newsymbolname, symbolname))

    for functionname in printer.functions:
        newfunctionname = _fresh_name(functionname, functionname, taken)
        taken.add(newfunctionname)
        if newfunctionname != functionname:
            function_renames[functionname] = newfunctionname
        lines.append('%s = Function(%r)' % (newfunctionname, functionname))

    if symbol_renames or function_renames:
        renamings = {}
        # Match names against the Symbol objects actually in the expression,
        # so symbols carrying assumptions (which Symbol(name) would not equal)
        # are renamed too. Map to Symbols rather than strings: subs() would
        # parse a string such as 'I' or 'pi' into a SymPy constant.
        for sym in expr.atoms(sympy.Symbol):
            newname = symbol_renames.get(printer._str(sym))
            if newname is not None:
                renamings[sym] = sympy.Symbol(newname, **sym.assumptions0)
        for oldname, newname in function_renames.items():
            renamings[sympy.Function(oldname)] = sympy.Function(newname)
        exprp = expr.subs(renamings)

    lines.append('e = ' + printer._str(exprp))
    return '\n'.join(lines)
def print_python(expr, **settings):
    """Write the source produced by ``python(expr, **settings)`` to stdout."""
    source = python(expr, **settings)
    print(source)
|