python.py 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293
  1. import keyword as kw
  2. import sympy
  3. from .repr import ReprPrinter
  4. from .str import StrPrinter
  5. # A list of classes that should be printed using StrPrinter
  6. STRPRINT = ("Add", "Infinity", "Integer", "Mul", "NegativeInfinity",
  7. "Pow", "Zero")
  8. class PythonPrinter(ReprPrinter, StrPrinter):
  9. """A printer which converts an expression into its Python interpretation."""
  10. def __init__(self, settings=None):
  11. super().__init__(settings)
  12. self.symbols = []
  13. self.functions = []
  14. # Create print methods for classes that should use StrPrinter instead
  15. # of ReprPrinter.
  16. for name in STRPRINT:
  17. f_name = "_print_%s" % name
  18. f = getattr(StrPrinter, f_name)
  19. setattr(PythonPrinter, f_name, f)
  20. def _print_Function(self, expr):
  21. func = expr.func.__name__
  22. if not hasattr(sympy, func) and func not in self.functions:
  23. self.functions.append(func)
  24. return StrPrinter._print_Function(self, expr)
  25. # procedure (!) for defining symbols which have be defined in print_python()
  26. def _print_Symbol(self, expr):
  27. symbol = self._str(expr)
  28. if symbol not in self.symbols:
  29. self.symbols.append(symbol)
  30. return StrPrinter._print_Symbol(self, expr)
  31. def _print_module(self, expr):
  32. raise ValueError('Modules in the expression are unacceptable')
  33. def python(expr, **settings):
  34. """Return Python interpretation of passed expression
  35. (can be passed to the exec() function without any modifications)"""
  36. printer = PythonPrinter(settings)
  37. exprp = printer.doprint(expr)
  38. result = ''
  39. # Returning found symbols and functions
  40. renamings = {}
  41. for symbolname in printer.symbols:
  42. # Remove curly braces from subscripted variables
  43. if '{' in symbolname:
  44. newsymbolname = symbolname.replace('{', '').replace('}', '')
  45. renamings[sympy.Symbol(symbolname)] = newsymbolname
  46. else:
  47. newsymbolname = symbolname
  48. # Escape symbol names that are reserved Python keywords
  49. if kw.iskeyword(newsymbolname):
  50. while True:
  51. newsymbolname += "_"
  52. if (newsymbolname not in printer.symbols and
  53. newsymbolname not in printer.functions):
  54. renamings[sympy.Symbol(
  55. symbolname)] = sympy.Symbol(newsymbolname)
  56. break
  57. result += newsymbolname + ' = Symbol(\'' + symbolname + '\')\n'
  58. for functionname in printer.functions:
  59. newfunctionname = functionname
  60. # Escape function names that are reserved Python keywords
  61. if kw.iskeyword(newfunctionname):
  62. while True:
  63. newfunctionname += "_"
  64. if (newfunctionname not in printer.symbols and
  65. newfunctionname not in printer.functions):
  66. renamings[sympy.Function(
  67. functionname)] = sympy.Function(newfunctionname)
  68. break
  69. result += newfunctionname + ' = Function(\'' + functionname + '\')\n'
  70. if renamings:
  71. exprp = expr.subs(renamings)
  72. result += 'e = ' + printer._str(exprp)
  73. return result
  74. def print_python(expr, **settings):
  75. """Print output of python() function"""
  76. print(python(expr, **settings))