traversal.py 3.1 KB

1234567891011121314151617181920212223242526272829303132333435363738394041424344454647484950515253545556575859606162636465666768697071727374757677787980818283848586878889909192939495
  1. from sympy.core.basic import Basic
  2. from sympy.printing import pprint
  3. import random
  4. def interactive_traversal(expr):
  5. """Traverse a tree asking a user which branch to choose. """
  6. RED, BRED = '\033[0;31m', '\033[1;31m'
  7. GREEN, BGREEN = '\033[0;32m', '\033[1;32m'
  8. YELLOW, BYELLOW = '\033[0;33m', '\033[1;33m' # noqa
  9. BLUE, BBLUE = '\033[0;34m', '\033[1;34m' # noqa
  10. MAGENTA, BMAGENTA = '\033[0;35m', '\033[1;35m'# noqa
  11. CYAN, BCYAN = '\033[0;36m', '\033[1;36m' # noqa
  12. END = '\033[0m'
  13. def cprint(*args):
  14. print("".join(map(str, args)) + END)
  15. def _interactive_traversal(expr, stage):
  16. if stage > 0:
  17. print()
  18. cprint("Current expression (stage ", BYELLOW, stage, END, "):")
  19. print(BCYAN)
  20. pprint(expr)
  21. print(END)
  22. if isinstance(expr, Basic):
  23. if expr.is_Add:
  24. args = expr.as_ordered_terms()
  25. elif expr.is_Mul:
  26. args = expr.as_ordered_factors()
  27. else:
  28. args = expr.args
  29. elif hasattr(expr, "__iter__"):
  30. args = list(expr)
  31. else:
  32. return expr
  33. n_args = len(args)
  34. if not n_args:
  35. return expr
  36. for i, arg in enumerate(args):
  37. cprint(GREEN, "[", BGREEN, i, GREEN, "] ", BLUE, type(arg), END)
  38. pprint(arg)
  39. print()
  40. if n_args == 1:
  41. choices = '0'
  42. else:
  43. choices = '0-%d' % (n_args - 1)
  44. try:
  45. choice = input("Your choice [%s,f,l,r,d,?]: " % choices)
  46. except EOFError:
  47. result = expr
  48. print()
  49. else:
  50. if choice == '?':
  51. cprint(RED, "%s - select subexpression with the given index" %
  52. choices)
  53. cprint(RED, "f - select the first subexpression")
  54. cprint(RED, "l - select the last subexpression")
  55. cprint(RED, "r - select a random subexpression")
  56. cprint(RED, "d - done\n")
  57. result = _interactive_traversal(expr, stage)
  58. elif choice in ('d', ''):
  59. result = expr
  60. elif choice == 'f':
  61. result = _interactive_traversal(args[0], stage + 1)
  62. elif choice == 'l':
  63. result = _interactive_traversal(args[-1], stage + 1)
  64. elif choice == 'r':
  65. result = _interactive_traversal(random.choice(args), stage + 1)
  66. else:
  67. try:
  68. choice = int(choice)
  69. except ValueError:
  70. cprint(BRED,
  71. "Choice must be a number in %s range\n" % choices)
  72. result = _interactive_traversal(expr, stage)
  73. else:
  74. if choice < 0 or choice >= n_args:
  75. cprint(BRED, "Choice must be in %s range\n" % choices)
  76. result = _interactive_traversal(expr, stage)
  77. else:
  78. result = _interactive_traversal(args[choice], stage + 1)
  79. return result
  80. return _interactive_traversal(expr, 0)