conventions.py 2.5 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788
  1. """
  2. A few practical conventions common to all printers.
  3. """
  4. import re
  5. from collections.abc import Iterable
  6. from sympy.core.function import Derivative
  7. _name_with_digits_p = re.compile(r'^([^\W\d_]+)(\d+)$', re.U)
  8. def split_super_sub(text):
  9. """Split a symbol name into a name, superscripts and subscripts
  10. The first part of the symbol name is considered to be its actual
  11. 'name', followed by super- and subscripts. Each superscript is
  12. preceded with a "^" character or by "__". Each subscript is preceded
  13. by a "_" character. The three return values are the actual name, a
  14. list with superscripts and a list with subscripts.
  15. Examples
  16. ========
  17. >>> from sympy.printing.conventions import split_super_sub
  18. >>> split_super_sub('a_x^1')
  19. ('a', ['1'], ['x'])
  20. >>> split_super_sub('var_sub1__sup_sub2')
  21. ('var', ['sup'], ['sub1', 'sub2'])
  22. """
  23. if not text:
  24. return text, [], []
  25. pos = 0
  26. name = None
  27. supers = []
  28. subs = []
  29. while pos < len(text):
  30. start = pos + 1
  31. if text[pos:pos + 2] == "__":
  32. start += 1
  33. pos_hat = text.find("^", start)
  34. if pos_hat < 0:
  35. pos_hat = len(text)
  36. pos_usc = text.find("_", start)
  37. if pos_usc < 0:
  38. pos_usc = len(text)
  39. pos_next = min(pos_hat, pos_usc)
  40. part = text[pos:pos_next]
  41. pos = pos_next
  42. if name is None:
  43. name = part
  44. elif part.startswith("^"):
  45. supers.append(part[1:])
  46. elif part.startswith("__"):
  47. supers.append(part[2:])
  48. elif part.startswith("_"):
  49. subs.append(part[1:])
  50. else:
  51. raise RuntimeError("This should never happen.")
  52. # Make a little exception when a name ends with digits, i.e. treat them
  53. # as a subscript too.
  54. m = _name_with_digits_p.match(name)
  55. if m:
  56. name, sub = m.groups()
  57. subs.insert(0, sub)
  58. return name, supers, subs
  59. def requires_partial(expr):
  60. """Return whether a partial derivative symbol is required for printing
  61. This requires checking how many free variables there are,
  62. filtering out the ones that are integers. Some expressions do not have
  63. free variables. In that case, check its variable list explicitly to
  64. get the context of the expression.
  65. """
  66. if isinstance(expr, Derivative):
  67. return requires_partial(expr.expr)
  68. if not isinstance(expr.free_symbols, Iterable):
  69. return len(set(expr.variables)) > 1
  70. return sum(not s.is_integer for s in expr.free_symbols) > 1