experimental_lambdify.py 22 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643
""" Rewrite of lambdify - this is not stable at all.

It is for internal use in the new plotting module.
It may (will! see the Q'n'A in the source) be rewritten.

It is completely self-contained; in particular, it does not use lambdarepr.

It does not aim to replace the current lambdify. Most importantly, it will
never support anything other than SymPy expressions (no Matrices,
dictionaries and so on).
"""
  9. import re
  10. from sympy.core.numbers import (I, NumberSymbol, oo, zoo)
  11. from sympy.core.symbol import Symbol
  12. from sympy.utilities.iterables import numbered_symbols
  13. # We parse the expression string into a tree that identifies functions. Then
  14. # we translate the names of the functions and we translate also some strings
  15. # that are not names of functions (all this according to translation
  16. # dictionaries).
  17. # If the translation goes to another module (like numpy) the
  18. # module is imported and 'func' is translated to 'module.func'.
  19. # If a function can not be translated, the inner nodes of that part of the
  20. # tree are not translated. So if we have Integral(sqrt(x)), sqrt is not
  21. # translated to np.sqrt and the Integral does not crash.
  22. # A namespace for all this is generated by crawling the (func, args) tree of
  23. # the expression. The creation of this namespace involves many ugly
  24. # workarounds.
# The namespace consists of all the names needed for the SymPy expression and
# all the names of the modules used for translation. Those modules are imported
# only as a name (import numpy as np) in order to keep the namespace small and
# manageable.
# Please, if there is a bug, do not try to fix it here! Rewrite this by using
# the method proposed in the last Q'n'A below. That way the new function will
# work just as well, be just as simple, but it won't need any new workarounds.
# If you insist on fixing it here, look at the workarounds in the function
# sympy_expression_namespace and in Lambdifier.__init__.
  34. # Q: Why are you not using Python abstract syntax tree?
  35. # A: Because it is more complicated and not much more powerful in this case.
  36. # Q: What if I have Symbol('sin') or g=Function('f')?
  37. # A: You will break the algorithm. We should use srepr to defend against this?
  38. # The problem with Symbol('sin') is that it will be printed as 'sin'. The
  39. # parser will distinguish it from the function 'sin' because functions are
  40. # detected thanks to the opening parenthesis, but the lambda expression won't
  41. # understand the difference if we have also the sin function.
  42. # The solution (complicated) is to use srepr and maybe ast.
  43. # The problem with the g=Function('f') is that it will be printed as 'f' but in
  44. # the global namespace we have only 'g'. But as the same printer is used in the
  45. # constructor of the namespace there will be no problem.
# Q: What if some of the printers are not printing as expected?
# A: The algorithm won't work. You must use srepr for those cases. But even
# srepr may not print well. All problems with printers should be considered
# bugs.
# Q: What about _imp_ functions?
# A: Those are taken care of by evalf. A special-case treatment would be
# faster, but it's not worth the code complexity.
  53. # Q: Will ast fix all possible problems?
  54. # A: No. You will always have to use some printer. Even srepr may not work in
  55. # some cases. But if the printer does not work, that should be considered a
  56. # bug.
# Q: Is there some way to fix all possible problems?
# A: Probably by constructing the strings ourselves by traversing the (func,
# args) tree and creating the namespace at the same time. That actually sounds
# good.
  61. from sympy.external import import_module
  62. import warnings
  63. #TODO debugging output
  64. class vectorized_lambdify:
  65. """ Return a sufficiently smart, vectorized and lambdified function.
  66. Returns only reals.
  67. Explanation
  68. ===========
  69. This function uses experimental_lambdify to created a lambdified
  70. expression ready to be used with numpy. Many of the functions in SymPy
  71. are not implemented in numpy so in some cases we resort to Python cmath or
  72. even to evalf.
  73. The following translations are tried:
  74. only numpy complex
  75. - on errors raised by SymPy trying to work with ndarray:
  76. only Python cmath and then vectorize complex128
  77. When using Python cmath there is no need for evalf or float/complex
  78. because Python cmath calls those.
  79. This function never tries to mix numpy directly with evalf because numpy
  80. does not understand SymPy Float. If this is needed one can use the
  81. float_wrap_evalf/complex_wrap_evalf options of experimental_lambdify or
  82. better one can be explicit about the dtypes that numpy works with.
  83. Check numpy bug http://projects.scipy.org/numpy/ticket/1013 to know what
  84. types of errors to expect.
  85. """
  86. def __init__(self, args, expr):
  87. self.args = args
  88. self.expr = expr
  89. self.np = import_module('numpy')
  90. self.lambda_func_1 = experimental_lambdify(
  91. args, expr, use_np=True)
  92. self.vector_func_1 = self.lambda_func_1
  93. self.lambda_func_2 = experimental_lambdify(
  94. args, expr, use_python_cmath=True)
  95. self.vector_func_2 = self.np.vectorize(
  96. self.lambda_func_2, otypes=[complex])
  97. self.vector_func = self.vector_func_1
  98. self.failure = False
  99. def __call__(self, *args):
  100. np = self.np
  101. try:
  102. temp_args = (np.array(a, dtype=complex) for a in args)
  103. results = self.vector_func(*temp_args)
  104. results = np.ma.masked_where(
  105. np.abs(results.imag) > 1e-7 * np.abs(results),
  106. results.real, copy=False)
  107. return results
  108. except ValueError:
  109. if self.failure:
  110. raise
  111. self.failure = True
  112. self.vector_func = self.vector_func_2
  113. warnings.warn(
  114. 'The evaluation of the expression is problematic. '
  115. 'We are trying a failback method that may still work. '
  116. 'Please report this as a bug.')
  117. return self.__call__(*args)
  118. class lambdify:
  119. """Returns the lambdified function.
  120. Explanation
  121. ===========
  122. This function uses experimental_lambdify to create a lambdified
  123. expression. It uses cmath to lambdify the expression. If the function
  124. is not implemented in Python cmath, Python cmath calls evalf on those
  125. functions.
  126. """
  127. def __init__(self, args, expr):
  128. self.args = args
  129. self.expr = expr
  130. self.lambda_func_1 = experimental_lambdify(
  131. args, expr, use_python_cmath=True, use_evalf=True)
  132. self.lambda_func_2 = experimental_lambdify(
  133. args, expr, use_python_math=True, use_evalf=True)
  134. self.lambda_func_3 = experimental_lambdify(
  135. args, expr, use_evalf=True, complex_wrap_evalf=True)
  136. self.lambda_func = self.lambda_func_1
  137. self.failure = False
  138. def __call__(self, args):
  139. try:
  140. #The result can be sympy.Float. Hence wrap it with complex type.
  141. result = complex(self.lambda_func(args))
  142. if abs(result.imag) > 1e-7 * abs(result):
  143. return None
  144. return result.real
  145. except (ZeroDivisionError, OverflowError):
  146. return None
  147. except TypeError as e:
  148. if self.failure:
  149. raise e
  150. if self.lambda_func == self.lambda_func_1:
  151. self.lambda_func = self.lambda_func_2
  152. return self.__call__(args)
  153. self.failure = True
  154. self.lambda_func = self.lambda_func_3
  155. warnings.warn(
  156. 'The evaluation of the expression is problematic. '
  157. 'We are trying a failback method that may still work. '
  158. 'Please report this as a bug.', stacklevel=2)
  159. return self.__call__(args)
  160. def experimental_lambdify(*args, **kwargs):
  161. l = Lambdifier(*args, **kwargs)
  162. return l
  163. class Lambdifier:
  164. def __init__(self, args, expr, print_lambda=False, use_evalf=False,
  165. float_wrap_evalf=False, complex_wrap_evalf=False,
  166. use_np=False, use_python_math=False, use_python_cmath=False,
  167. use_interval=False):
  168. self.print_lambda = print_lambda
  169. self.use_evalf = use_evalf
  170. self.float_wrap_evalf = float_wrap_evalf
  171. self.complex_wrap_evalf = complex_wrap_evalf
  172. self.use_np = use_np
  173. self.use_python_math = use_python_math
  174. self.use_python_cmath = use_python_cmath
  175. self.use_interval = use_interval
  176. # Constructing the argument string
  177. # - check
  178. if not all(isinstance(a, Symbol) for a in args):
  179. raise ValueError('The arguments must be Symbols.')
  180. # - use numbered symbols
  181. syms = numbered_symbols(exclude=expr.free_symbols)
  182. newargs = [next(syms) for _ in args]
  183. expr = expr.xreplace(dict(zip(args, newargs)))
  184. argstr = ', '.join([str(a) for a in newargs])
  185. del syms, newargs, args
  186. # Constructing the translation dictionaries and making the translation
  187. self.dict_str = self.get_dict_str()
  188. self.dict_fun = self.get_dict_fun()
  189. exprstr = str(expr)
  190. newexpr = self.tree2str_translate(self.str2tree(exprstr))
  191. # Constructing the namespaces
  192. namespace = {}
  193. namespace.update(self.sympy_atoms_namespace(expr))
  194. namespace.update(self.sympy_expression_namespace(expr))
  195. # XXX Workaround
  196. # Ugly workaround because Pow(a,Half) prints as sqrt(a)
  197. # and sympy_expression_namespace can not catch it.
  198. from sympy.functions.elementary.miscellaneous import sqrt
  199. namespace.update({'sqrt': sqrt})
  200. namespace.update({'Eq': lambda x, y: x == y})
  201. namespace.update({'Ne': lambda x, y: x != y})
  202. # End workaround.
  203. if use_python_math:
  204. namespace.update({'math': __import__('math')})
  205. if use_python_cmath:
  206. namespace.update({'cmath': __import__('cmath')})
  207. if use_np:
  208. try:
  209. namespace.update({'np': __import__('numpy')})
  210. except ImportError:
  211. raise ImportError(
  212. 'experimental_lambdify failed to import numpy.')
  213. if use_interval:
  214. namespace.update({'imath': __import__(
  215. 'sympy.plotting.intervalmath', fromlist=['intervalmath'])})
  216. namespace.update({'math': __import__('math')})
  217. # Construct the lambda
  218. if self.print_lambda:
  219. print(newexpr)
  220. eval_str = 'lambda %s : ( %s )' % (argstr, newexpr)
  221. self.eval_str = eval_str
  222. exec("MYNEWLAMBDA = %s" % eval_str, namespace)
  223. self.lambda_func = namespace['MYNEWLAMBDA']
  224. def __call__(self, *args, **kwargs):
  225. return self.lambda_func(*args, **kwargs)
  226. ##############################################################################
  227. # Dicts for translating from SymPy to other modules
  228. ##############################################################################
  229. ###
  230. # builtins
  231. ###
  232. # Functions with different names in builtins
  233. builtin_functions_different = {
  234. 'Min': 'min',
  235. 'Max': 'max',
  236. 'Abs': 'abs',
  237. }
  238. # Strings that should be translated
  239. builtin_not_functions = {
  240. 'I': '1j',
  241. # 'oo': '1e400',
  242. }
  243. ###
  244. # numpy
  245. ###
  246. # Functions that are the same in numpy
  247. numpy_functions_same = [
  248. 'sin', 'cos', 'tan', 'sinh', 'cosh', 'tanh', 'exp', 'log',
  249. 'sqrt', 'floor', 'conjugate',
  250. ]
  251. # Functions with different names in numpy
  252. numpy_functions_different = {
  253. "acos": "arccos",
  254. "acosh": "arccosh",
  255. "arg": "angle",
  256. "asin": "arcsin",
  257. "asinh": "arcsinh",
  258. "atan": "arctan",
  259. "atan2": "arctan2",
  260. "atanh": "arctanh",
  261. "ceiling": "ceil",
  262. "im": "imag",
  263. "ln": "log",
  264. "Max": "amax",
  265. "Min": "amin",
  266. "re": "real",
  267. "Abs": "abs",
  268. }
  269. # Strings that should be translated
  270. numpy_not_functions = {
  271. 'pi': 'np.pi',
  272. 'oo': 'np.inf',
  273. 'E': 'np.e',
  274. }
  275. ###
  276. # Python math
  277. ###
  278. # Functions that are the same in math
  279. math_functions_same = [
  280. 'sin', 'cos', 'tan', 'asin', 'acos', 'atan', 'atan2',
  281. 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
  282. 'exp', 'log', 'erf', 'sqrt', 'floor', 'factorial', 'gamma',
  283. ]
  284. # Functions with different names in math
  285. math_functions_different = {
  286. 'ceiling': 'ceil',
  287. 'ln': 'log',
  288. 'loggamma': 'lgamma'
  289. }
  290. # Strings that should be translated
  291. math_not_functions = {
  292. 'pi': 'math.pi',
  293. 'E': 'math.e',
  294. }
  295. ###
  296. # Python cmath
  297. ###
  298. # Functions that are the same in cmath
  299. cmath_functions_same = [
  300. 'sin', 'cos', 'tan', 'asin', 'acos', 'atan',
  301. 'sinh', 'cosh', 'tanh', 'asinh', 'acosh', 'atanh',
  302. 'exp', 'log', 'sqrt',
  303. ]
  304. # Functions with different names in cmath
  305. cmath_functions_different = {
  306. 'ln': 'log',
  307. 'arg': 'phase',
  308. }
  309. # Strings that should be translated
  310. cmath_not_functions = {
  311. 'pi': 'cmath.pi',
  312. 'E': 'cmath.e',
  313. }
  314. ###
  315. # intervalmath
  316. ###
  317. interval_not_functions = {
  318. 'pi': 'math.pi',
  319. 'E': 'math.e'
  320. }
  321. interval_functions_same = [
  322. 'sin', 'cos', 'exp', 'tan', 'atan', 'log',
  323. 'sqrt', 'cosh', 'sinh', 'tanh', 'floor',
  324. 'acos', 'asin', 'acosh', 'asinh', 'atanh',
  325. 'Abs', 'And', 'Or'
  326. ]
  327. interval_functions_different = {
  328. 'Min': 'imin',
  329. 'Max': 'imax',
  330. 'ceiling': 'ceil',
  331. }
  332. ###
  333. # mpmath, etc
  334. ###
  335. #TODO
  336. ###
  337. # Create the final ordered tuples of dictionaries
  338. ###
  339. # For strings
  340. def get_dict_str(self):
  341. dict_str = dict(self.builtin_not_functions)
  342. if self.use_np:
  343. dict_str.update(self.numpy_not_functions)
  344. if self.use_python_math:
  345. dict_str.update(self.math_not_functions)
  346. if self.use_python_cmath:
  347. dict_str.update(self.cmath_not_functions)
  348. if self.use_interval:
  349. dict_str.update(self.interval_not_functions)
  350. return dict_str
  351. # For functions
  352. def get_dict_fun(self):
  353. dict_fun = dict(self.builtin_functions_different)
  354. if self.use_np:
  355. for s in self.numpy_functions_same:
  356. dict_fun[s] = 'np.' + s
  357. for k, v in self.numpy_functions_different.items():
  358. dict_fun[k] = 'np.' + v
  359. if self.use_python_math:
  360. for s in self.math_functions_same:
  361. dict_fun[s] = 'math.' + s
  362. for k, v in self.math_functions_different.items():
  363. dict_fun[k] = 'math.' + v
  364. if self.use_python_cmath:
  365. for s in self.cmath_functions_same:
  366. dict_fun[s] = 'cmath.' + s
  367. for k, v in self.cmath_functions_different.items():
  368. dict_fun[k] = 'cmath.' + v
  369. if self.use_interval:
  370. for s in self.interval_functions_same:
  371. dict_fun[s] = 'imath.' + s
  372. for k, v in self.interval_functions_different.items():
  373. dict_fun[k] = 'imath.' + v
  374. return dict_fun
  375. ##############################################################################
  376. # The translator functions, tree parsers, etc.
  377. ##############################################################################
  378. def str2tree(self, exprstr):
  379. """Converts an expression string to a tree.
  380. Explanation
  381. ===========
  382. Functions are represented by ('func_name(', tree_of_arguments).
  383. Other expressions are (head_string, mid_tree, tail_str).
  384. Expressions that do not contain functions are directly returned.
  385. Examples
  386. ========
  387. >>> from sympy.abc import x, y, z
  388. >>> from sympy import Integral, sin
  389. >>> from sympy.plotting.experimental_lambdify import Lambdifier
  390. >>> str2tree = Lambdifier([x], x).str2tree
  391. >>> str2tree(str(Integral(x, (x, 1, y))))
  392. ('', ('Integral(', 'x, (x, 1, y)'), ')')
  393. >>> str2tree(str(x+y))
  394. 'x + y'
  395. >>> str2tree(str(x+y*sin(z)+1))
  396. ('x + y*', ('sin(', 'z'), ') + 1')
  397. >>> str2tree('sin(y*(y + 1.1) + (sin(y)))')
  398. ('', ('sin(', ('y*(y + 1.1) + (', ('sin(', 'y'), '))')), ')')
  399. """
  400. #matches the first 'function_name('
  401. first_par = re.search(r'(\w+\()', exprstr)
  402. if first_par is None:
  403. return exprstr
  404. else:
  405. start = first_par.start()
  406. end = first_par.end()
  407. head = exprstr[:start]
  408. func = exprstr[start:end]
  409. tail = exprstr[end:]
  410. count = 0
  411. for i, c in enumerate(tail):
  412. if c == '(':
  413. count += 1
  414. elif c == ')':
  415. count -= 1
  416. if count == -1:
  417. break
  418. func_tail = self.str2tree(tail[:i])
  419. tail = self.str2tree(tail[i:])
  420. return (head, (func, func_tail), tail)
  421. @classmethod
  422. def tree2str(cls, tree):
  423. """Converts a tree to string without translations.
  424. Examples
  425. ========
  426. >>> from sympy.abc import x, y, z
  427. >>> from sympy import sin
  428. >>> from sympy.plotting.experimental_lambdify import Lambdifier
  429. >>> str2tree = Lambdifier([x], x).str2tree
  430. >>> tree2str = Lambdifier([x], x).tree2str
  431. >>> tree2str(str2tree(str(x+y*sin(z)+1)))
  432. 'x + y*sin(z) + 1'
  433. """
  434. if isinstance(tree, str):
  435. return tree
  436. else:
  437. return ''.join(map(cls.tree2str, tree))
  438. def tree2str_translate(self, tree):
  439. """Converts a tree to string with translations.
  440. Explanation
  441. ===========
  442. Function names are translated by translate_func.
  443. Other strings are translated by translate_str.
  444. """
  445. if isinstance(tree, str):
  446. return self.translate_str(tree)
  447. elif isinstance(tree, tuple) and len(tree) == 2:
  448. return self.translate_func(tree[0][:-1], tree[1])
  449. else:
  450. return ''.join([self.tree2str_translate(t) for t in tree])
  451. def translate_str(self, estr):
  452. """Translate substrings of estr using in order the dictionaries in
  453. dict_tuple_str."""
  454. for pattern, repl in self.dict_str.items():
  455. estr = re.sub(pattern, repl, estr)
  456. return estr
  457. def translate_func(self, func_name, argtree):
  458. """Translate function names and the tree of arguments.
  459. Explanation
  460. ===========
  461. If the function name is not in the dictionaries of dict_tuple_fun then the
  462. function is surrounded by a float((...).evalf()).
  463. The use of float is necessary as np.<function>(sympy.Float(..)) raises an
  464. error."""
  465. if func_name in self.dict_fun:
  466. new_name = self.dict_fun[func_name]
  467. argstr = self.tree2str_translate(argtree)
  468. return new_name + '(' + argstr
  469. elif func_name in ['Eq', 'Ne']:
  470. op = {'Eq': '==', 'Ne': '!='}
  471. return "(lambda x, y: x {} y)({}".format(op[func_name], self.tree2str_translate(argtree))
  472. else:
  473. template = '(%s(%s)).evalf(' if self.use_evalf else '%s(%s'
  474. if self.float_wrap_evalf:
  475. template = 'float(%s)' % template
  476. elif self.complex_wrap_evalf:
  477. template = 'complex(%s)' % template
  478. # Wrapping should only happen on the outermost expression, which
  479. # is the only thing we know will be a number.
  480. float_wrap_evalf = self.float_wrap_evalf
  481. complex_wrap_evalf = self.complex_wrap_evalf
  482. self.float_wrap_evalf = False
  483. self.complex_wrap_evalf = False
  484. ret = template % (func_name, self.tree2str_translate(argtree))
  485. self.float_wrap_evalf = float_wrap_evalf
  486. self.complex_wrap_evalf = complex_wrap_evalf
  487. return ret
  488. ##############################################################################
  489. # The namespace constructors
  490. ##############################################################################
  491. @classmethod
  492. def sympy_expression_namespace(cls, expr):
  493. """Traverses the (func, args) tree of an expression and creates a SymPy
  494. namespace. All other modules are imported only as a module name. That way
  495. the namespace is not polluted and rests quite small. It probably causes much
  496. more variable lookups and so it takes more time, but there are no tests on
  497. that for the moment."""
  498. if expr is None:
  499. return {}
  500. else:
  501. funcname = str(expr.func)
  502. # XXX Workaround
  503. # Here we add an ugly workaround because str(func(x))
  504. # is not always the same as str(func). Eg
  505. # >>> str(Integral(x))
  506. # "Integral(x)"
  507. # >>> str(Integral)
  508. # "<class 'sympy.integrals.integrals.Integral'>"
  509. # >>> str(sqrt(x))
  510. # "sqrt(x)"
  511. # >>> str(sqrt)
  512. # "<function sqrt at 0x3d92de8>"
  513. # >>> str(sin(x))
  514. # "sin(x)"
  515. # >>> str(sin)
  516. # "sin"
  517. # Either one of those can be used but not all at the same time.
  518. # The code considers the sin example as the right one.
  519. regexlist = [
  520. r'<class \'sympy[\w.]*?.([\w]*)\'>$',
  521. # the example Integral
  522. r'<function ([\w]*) at 0x[\w]*>$', # the example sqrt
  523. ]
  524. for r in regexlist:
  525. m = re.match(r, funcname)
  526. if m is not None:
  527. funcname = m.groups()[0]
  528. # End of the workaround
  529. # XXX debug: print funcname
  530. args_dict = {}
  531. for a in expr.args:
  532. if (isinstance(a, Symbol) or
  533. isinstance(a, NumberSymbol) or
  534. a in [I, zoo, oo]):
  535. continue
  536. else:
  537. args_dict.update(cls.sympy_expression_namespace(a))
  538. args_dict.update({funcname: expr.func})
  539. return args_dict
  540. @staticmethod
  541. def sympy_atoms_namespace(expr):
  542. """For no real reason this function is separated from
  543. sympy_expression_namespace. It can be moved to it."""
  544. atoms = expr.atoms(Symbol, NumberSymbol, I, zoo, oo)
  545. d = {}
  546. for a in atoms:
  547. # XXX debug: print 'atom:' + str(a)
  548. d[str(a)] = a
  549. return d