repr.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347
"""
A Printer for generating executable code.

The most important function here is srepr, which returns a string such that
the relation ``eval(srepr(expr)) == expr`` holds in an appropriate
environment (one in which the printed class names are defined).
"""
  6. from typing import Any, Dict as tDict
  7. from sympy.core.function import AppliedUndef
  8. from sympy.core.mul import Mul
  9. from mpmath.libmp import repr_dps, to_str as mlib_to_str
  10. from .printer import Printer, print_function
  11. class ReprPrinter(Printer):
  12. printmethod = "_sympyrepr"
  13. _default_settings = {
  14. "order": None,
  15. "perm_cyclic" : True,
  16. } # type: tDict[str, Any]
  17. def reprify(self, args, sep):
  18. """
  19. Prints each item in `args` and joins them with `sep`.
  20. """
  21. return sep.join([self.doprint(item) for item in args])
  22. def emptyPrinter(self, expr):
  23. """
  24. The fallback printer.
  25. """
  26. if isinstance(expr, str):
  27. return expr
  28. elif hasattr(expr, "__srepr__"):
  29. return expr.__srepr__()
  30. elif hasattr(expr, "args") and hasattr(expr.args, "__iter__"):
  31. l = []
  32. for o in expr.args:
  33. l.append(self._print(o))
  34. return expr.__class__.__name__ + '(%s)' % ', '.join(l)
  35. elif hasattr(expr, "__module__") and hasattr(expr, "__name__"):
  36. return "<'%s.%s'>" % (expr.__module__, expr.__name__)
  37. else:
  38. return str(expr)
  39. def _print_Add(self, expr, order=None):
  40. args = self._as_ordered_terms(expr, order=order)
  41. nargs = len(args)
  42. args = map(self._print, args)
  43. clsname = type(expr).__name__
  44. if nargs > 255: # Issue #10259, Python < 3.7
  45. return clsname + "(*[%s])" % ", ".join(args)
  46. return clsname + "(%s)" % ", ".join(args)
  47. def _print_Cycle(self, expr):
  48. return expr.__repr__()
  49. def _print_Permutation(self, expr):
  50. from sympy.combinatorics.permutations import Permutation, Cycle
  51. from sympy.utilities.exceptions import sympy_deprecation_warning
  52. perm_cyclic = Permutation.print_cyclic
  53. if perm_cyclic is not None:
  54. sympy_deprecation_warning(
  55. f"""
  56. Setting Permutation.print_cyclic is deprecated. Instead use
  57. init_printing(perm_cyclic={perm_cyclic}).
  58. """,
  59. deprecated_since_version="1.6",
  60. active_deprecations_target="deprecated-permutation-print_cyclic",
  61. stacklevel=7,
  62. )
  63. else:
  64. perm_cyclic = self._settings.get("perm_cyclic", True)
  65. if perm_cyclic:
  66. if not expr.size:
  67. return 'Permutation()'
  68. # before taking Cycle notation, see if the last element is
  69. # a singleton and move it to the head of the string
  70. s = Cycle(expr)(expr.size - 1).__repr__()[len('Cycle'):]
  71. last = s.rfind('(')
  72. if not last == 0 and ',' not in s[last:]:
  73. s = s[last:] + s[:last]
  74. return 'Permutation%s' %s
  75. else:
  76. s = expr.support()
  77. if not s:
  78. if expr.size < 5:
  79. return 'Permutation(%s)' % str(expr.array_form)
  80. return 'Permutation([], size=%s)' % expr.size
  81. trim = str(expr.array_form[:s[-1] + 1]) + ', size=%s' % expr.size
  82. use = full = str(expr.array_form)
  83. if len(trim) < len(full):
  84. use = trim
  85. return 'Permutation(%s)' % use
  86. def _print_Function(self, expr):
  87. r = self._print(expr.func)
  88. r += '(%s)' % ', '.join([self._print(a) for a in expr.args])
  89. return r
  90. def _print_Heaviside(self, expr):
  91. # Same as _print_Function but uses pargs to suppress default value for
  92. # 2nd arg.
  93. r = self._print(expr.func)
  94. r += '(%s)' % ', '.join([self._print(a) for a in expr.pargs])
  95. return r
  96. def _print_FunctionClass(self, expr):
  97. if issubclass(expr, AppliedUndef):
  98. return 'Function(%r)' % (expr.__name__)
  99. else:
  100. return expr.__name__
  101. def _print_Half(self, expr):
  102. return 'Rational(1, 2)'
  103. def _print_RationalConstant(self, expr):
  104. return str(expr)
  105. def _print_AtomicExpr(self, expr):
  106. return str(expr)
  107. def _print_NumberSymbol(self, expr):
  108. return str(expr)
  109. def _print_Integer(self, expr):
  110. return 'Integer(%i)' % expr.p
  111. def _print_Complexes(self, expr):
  112. return 'Complexes'
  113. def _print_Integers(self, expr):
  114. return 'Integers'
  115. def _print_Naturals(self, expr):
  116. return 'Naturals'
  117. def _print_Naturals0(self, expr):
  118. return 'Naturals0'
  119. def _print_Rationals(self, expr):
  120. return 'Rationals'
  121. def _print_Reals(self, expr):
  122. return 'Reals'
  123. def _print_EmptySet(self, expr):
  124. return 'EmptySet'
  125. def _print_UniversalSet(self, expr):
  126. return 'UniversalSet'
  127. def _print_EmptySequence(self, expr):
  128. return 'EmptySequence'
  129. def _print_list(self, expr):
  130. return "[%s]" % self.reprify(expr, ", ")
  131. def _print_dict(self, expr):
  132. sep = ", "
  133. dict_kvs = ["%s: %s" % (self.doprint(key), self.doprint(value)) for key, value in expr.items()]
  134. return "{%s}" % sep.join(dict_kvs)
  135. def _print_set(self, expr):
  136. if not expr:
  137. return "set()"
  138. return "{%s}" % self.reprify(expr, ", ")
  139. def _print_MatrixBase(self, expr):
  140. # special case for some empty matrices
  141. if (expr.rows == 0) ^ (expr.cols == 0):
  142. return '%s(%s, %s, %s)' % (expr.__class__.__name__,
  143. self._print(expr.rows),
  144. self._print(expr.cols),
  145. self._print([]))
  146. l = []
  147. for i in range(expr.rows):
  148. l.append([])
  149. for j in range(expr.cols):
  150. l[-1].append(expr[i, j])
  151. return '%s(%s)' % (expr.__class__.__name__, self._print(l))
  152. def _print_BooleanTrue(self, expr):
  153. return "true"
  154. def _print_BooleanFalse(self, expr):
  155. return "false"
  156. def _print_NaN(self, expr):
  157. return "nan"
  158. def _print_Mul(self, expr, order=None):
  159. if self.order not in ('old', 'none'):
  160. args = expr.as_ordered_factors()
  161. else:
  162. # use make_args in case expr was something like -x -> x
  163. args = Mul.make_args(expr)
  164. nargs = len(args)
  165. args = map(self._print, args)
  166. clsname = type(expr).__name__
  167. if nargs > 255: # Issue #10259, Python < 3.7
  168. return clsname + "(*[%s])" % ", ".join(args)
  169. return clsname + "(%s)" % ", ".join(args)
  170. def _print_Rational(self, expr):
  171. return 'Rational(%s, %s)' % (self._print(expr.p), self._print(expr.q))
  172. def _print_PythonRational(self, expr):
  173. return "%s(%d, %d)" % (expr.__class__.__name__, expr.p, expr.q)
  174. def _print_Fraction(self, expr):
  175. return 'Fraction(%s, %s)' % (self._print(expr.numerator), self._print(expr.denominator))
  176. def _print_Float(self, expr):
  177. r = mlib_to_str(expr._mpf_, repr_dps(expr._prec))
  178. return "%s('%s', precision=%i)" % (expr.__class__.__name__, r, expr._prec)
  179. def _print_Sum2(self, expr):
  180. return "Sum2(%s, (%s, %s, %s))" % (self._print(expr.f), self._print(expr.i),
  181. self._print(expr.a), self._print(expr.b))
  182. def _print_Str(self, s):
  183. return "%s(%s)" % (s.__class__.__name__, self._print(s.name))
  184. def _print_Symbol(self, expr):
  185. d = expr._assumptions.generator
  186. # print the dummy_index like it was an assumption
  187. if expr.is_Dummy:
  188. d['dummy_index'] = expr.dummy_index
  189. if d == {}:
  190. return "%s(%s)" % (expr.__class__.__name__, self._print(expr.name))
  191. else:
  192. attr = ['%s=%s' % (k, v) for k, v in d.items()]
  193. return "%s(%s, %s)" % (expr.__class__.__name__,
  194. self._print(expr.name), ', '.join(attr))
  195. def _print_CoordinateSymbol(self, expr):
  196. d = expr._assumptions.generator
  197. if d == {}:
  198. return "%s(%s, %s)" % (
  199. expr.__class__.__name__,
  200. self._print(expr.coord_sys),
  201. self._print(expr.index)
  202. )
  203. else:
  204. attr = ['%s=%s' % (k, v) for k, v in d.items()]
  205. return "%s(%s, %s, %s)" % (
  206. expr.__class__.__name__,
  207. self._print(expr.coord_sys),
  208. self._print(expr.index),
  209. ', '.join(attr)
  210. )
  211. def _print_Predicate(self, expr):
  212. return "Q.%s" % expr.name
  213. def _print_AppliedPredicate(self, expr):
  214. # will be changed to just expr.args when args overriding is removed
  215. args = expr._args
  216. return "%s(%s)" % (expr.__class__.__name__, self.reprify(args, ", "))
  217. def _print_str(self, expr):
  218. return repr(expr)
  219. def _print_tuple(self, expr):
  220. if len(expr) == 1:
  221. return "(%s,)" % self._print(expr[0])
  222. else:
  223. return "(%s)" % self.reprify(expr, ", ")
  224. def _print_WildFunction(self, expr):
  225. return "%s('%s')" % (expr.__class__.__name__, expr.name)
  226. def _print_AlgebraicNumber(self, expr):
  227. return "%s(%s, %s)" % (expr.__class__.__name__,
  228. self._print(expr.root), self._print(expr.coeffs()))
  229. def _print_PolyRing(self, ring):
  230. return "%s(%s, %s, %s)" % (ring.__class__.__name__,
  231. self._print(ring.symbols), self._print(ring.domain), self._print(ring.order))
  232. def _print_FracField(self, field):
  233. return "%s(%s, %s, %s)" % (field.__class__.__name__,
  234. self._print(field.symbols), self._print(field.domain), self._print(field.order))
  235. def _print_PolyElement(self, poly):
  236. terms = list(poly.terms())
  237. terms.sort(key=poly.ring.order, reverse=True)
  238. return "%s(%s, %s)" % (poly.__class__.__name__, self._print(poly.ring), self._print(terms))
  239. def _print_FracElement(self, frac):
  240. numer_terms = list(frac.numer.terms())
  241. numer_terms.sort(key=frac.field.order, reverse=True)
  242. denom_terms = list(frac.denom.terms())
  243. denom_terms.sort(key=frac.field.order, reverse=True)
  244. numer = self._print(numer_terms)
  245. denom = self._print(denom_terms)
  246. return "%s(%s, %s, %s)" % (frac.__class__.__name__, self._print(frac.field), numer, denom)
  247. def _print_FractionField(self, domain):
  248. cls = domain.__class__.__name__
  249. field = self._print(domain.field)
  250. return "%s(%s)" % (cls, field)
  251. def _print_PolynomialRingBase(self, ring):
  252. cls = ring.__class__.__name__
  253. dom = self._print(ring.domain)
  254. gens = ', '.join(map(self._print, ring.gens))
  255. order = str(ring.order)
  256. if order != ring.default_order:
  257. orderstr = ", order=" + order
  258. else:
  259. orderstr = ""
  260. return "%s(%s, %s%s)" % (cls, dom, gens, orderstr)
  261. def _print_DMP(self, p):
  262. cls = p.__class__.__name__
  263. rep = self._print(p.rep)
  264. dom = self._print(p.dom)
  265. if p.ring is not None:
  266. ringstr = ", ring=" + self._print(p.ring)
  267. else:
  268. ringstr = ""
  269. return "%s(%s, %s%s)" % (cls, rep, dom, ringstr)
  270. def _print_MonogenicFiniteExtension(self, ext):
  271. # The expanded tree shown by srepr(ext.modulus)
  272. # is not practical.
  273. return "FiniteExtension(%s)" % str(ext.modulus)
  274. def _print_ExtensionElement(self, f):
  275. rep = self._print(f.rep)
  276. ext = self._print(f.ext)
  277. return "ExtElem(%s, %s)" % (rep, ext)
  278. @print_function(ReprPrinter)
  279. def srepr(expr, **settings):
  280. """return expr in repr form"""
  281. return ReprPrinter(settings).doprint(expr)