ast_parser.py 2.6 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879
  1. """
This module implements the functionality to take any Python expression as a
string and fix all numbers and other things before evaluating it,
thus

    1/2

returns

    Integer(1)/Integer(2)

We use the ast module for this. It is well documented at docs.python.org.

Some tips to understand how this works: use dump() to get a nice
representation of any node. Then write a string of what you want to get,
e.g. "Integer(1)", parse it, dump it and you'll see that you need to do
"Call(func=Name('Integer', Load()), args=[node], keywords=[])". You do not
need to bother with lineno and col_offset, just call
fix_missing_locations() before returning the node.
  15. """
# NOTE(review): ``Str`` is a deprecated alias of ``Constant`` (since Python
# 3.8); it is kept here only so the existing import list is unchanged --
# confirm whether it can be dropped for the supported Python versions.
from ast import parse, NodeTransformer, Call, Name, Load, \
    fix_missing_locations, Str, Tuple, Constant

from sympy.core.basic import Basic
from sympy.core.sympify import SympifyError
  20. class Transform(NodeTransformer):
  21. def __init__(self, local_dict, global_dict):
  22. NodeTransformer.__init__(self)
  23. self.local_dict = local_dict
  24. self.global_dict = global_dict
  25. def visit_Num(self, node):
  26. if isinstance(node.n, int):
  27. return fix_missing_locations(Call(func=Name('Integer', Load()),
  28. args=[node], keywords=[]))
  29. elif isinstance(node.n, float):
  30. return fix_missing_locations(Call(func=Name('Float', Load()),
  31. args=[node], keywords=[]))
  32. return node
  33. def visit_Name(self, node):
  34. if node.id in self.local_dict:
  35. return node
  36. elif node.id in self.global_dict:
  37. name_obj = self.global_dict[node.id]
  38. if isinstance(name_obj, (Basic, type)) or callable(name_obj):
  39. return node
  40. elif node.id in ['True', 'False']:
  41. return node
  42. return fix_missing_locations(Call(func=Name('Symbol', Load()),
  43. args=[Str(node.id)], keywords=[]))
  44. def visit_Lambda(self, node):
  45. args = [self.visit(arg) for arg in node.args.args]
  46. body = self.visit(node.body)
  47. n = Call(func=Name('Lambda', Load()),
  48. args=[Tuple(args, Load()), body], keywords=[])
  49. return fix_missing_locations(n)
  50. def parse_expr(s, local_dict):
  51. """
  52. Converts the string "s" to a SymPy expression, in local_dict.
  53. It converts all numbers to Integers before feeding it to Python and
  54. automatically creates Symbols.
  55. """
  56. global_dict = {}
  57. exec('from sympy import *', global_dict)
  58. try:
  59. a = parse(s.strip(), mode="eval")
  60. except SyntaxError:
  61. raise SympifyError("Cannot parse %s." % repr(s))
  62. a = Transform(local_dict, global_dict).visit(a)
  63. e = compile(a, "<string>", "eval")
  64. return eval(e, global_dict, local_dict)