algorithms.py 4.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151
  1. from sympy.core.containers import Tuple
  2. from sympy.core.numbers import oo
  3. from sympy.core.relational import (Gt, Lt)
  4. from sympy.core.symbol import (Dummy, Symbol)
  5. from sympy.functions.elementary.complexes import Abs
  6. from sympy.logic.boolalg import And
  7. from sympy.codegen.ast import (
  8. Assignment, AddAugmentedAssignment, CodeBlock, Declaration, FunctionDefinition,
  9. Print, Return, Scope, While, Variable, Pointer, real
  10. )
  11. """ This module collects functions for constructing ASTs representing algorithms. """
  12. def newtons_method(expr, wrt, atol=1e-12, delta=None, debug=False,
  13. itermax=None, counter=None):
  14. """ Generates an AST for Newton-Raphson method (a root-finding algorithm).
  15. Explanation
  16. ===========
  17. Returns an abstract syntax tree (AST) based on ``sympy.codegen.ast`` for Netwon's
  18. method of root-finding.
  19. Parameters
  20. ==========
  21. expr : expression
  22. wrt : Symbol
  23. With respect to, i.e. what is the variable.
  24. atol : number or expr
  25. Absolute tolerance (stopping criterion)
  26. delta : Symbol
  27. Will be a ``Dummy`` if ``None``.
  28. debug : bool
  29. Whether to print convergence information during iterations
  30. itermax : number or expr
  31. Maximum number of iterations.
  32. counter : Symbol
  33. Will be a ``Dummy`` if ``None``.
  34. Examples
  35. ========
  36. >>> from sympy import symbols, cos
  37. >>> from sympy.codegen.ast import Assignment
  38. >>> from sympy.codegen.algorithms import newtons_method
  39. >>> x, dx, atol = symbols('x dx atol')
  40. >>> expr = cos(x) - x**3
  41. >>> algo = newtons_method(expr, x, atol, dx)
  42. >>> algo.has(Assignment(dx, -expr/expr.diff(x)))
  43. True
  44. References
  45. ==========
  46. .. [1] https://en.wikipedia.org/wiki/Newton%27s_method
  47. """
  48. if delta is None:
  49. delta = Dummy()
  50. Wrapper = Scope
  51. name_d = 'delta'
  52. else:
  53. Wrapper = lambda x: x
  54. name_d = delta.name
  55. delta_expr = -expr/expr.diff(wrt)
  56. whl_bdy = [Assignment(delta, delta_expr), AddAugmentedAssignment(wrt, delta)]
  57. if debug:
  58. prnt = Print([wrt, delta], r"{}=%12.5g {}=%12.5g\n".format(wrt.name, name_d))
  59. whl_bdy = [whl_bdy[0], prnt] + whl_bdy[1:]
  60. req = Gt(Abs(delta), atol)
  61. declars = [Declaration(Variable(delta, type=real, value=oo))]
  62. if itermax is not None:
  63. counter = counter or Dummy(integer=True)
  64. v_counter = Variable.deduced(counter, 0)
  65. declars.append(Declaration(v_counter))
  66. whl_bdy.append(AddAugmentedAssignment(counter, 1))
  67. req = And(req, Lt(counter, itermax))
  68. whl = While(req, CodeBlock(*whl_bdy))
  69. blck = declars + [whl]
  70. return Wrapper(CodeBlock(*blck))
  71. def _symbol_of(arg):
  72. if isinstance(arg, Declaration):
  73. arg = arg.variable.symbol
  74. elif isinstance(arg, Variable):
  75. arg = arg.symbol
  76. return arg
  77. def newtons_method_function(expr, wrt, params=None, func_name="newton", attrs=Tuple(), *, delta=None, **kwargs):
  78. """ Generates an AST for a function implementing the Newton-Raphson method.
  79. Parameters
  80. ==========
  81. expr : expression
  82. wrt : Symbol
  83. With respect to, i.e. what is the variable
  84. params : iterable of symbols
  85. Symbols appearing in expr that are taken as constants during the iterations
  86. (these will be accepted as parameters to the generated function).
  87. func_name : str
  88. Name of the generated function.
  89. attrs : Tuple
  90. Attribute instances passed as ``attrs`` to ``FunctionDefinition``.
  91. \\*\\*kwargs :
  92. Keyword arguments passed to :func:`sympy.codegen.algorithms.newtons_method`.
  93. Examples
  94. ========
  95. >>> from sympy import symbols, cos
  96. >>> from sympy.codegen.algorithms import newtons_method_function
  97. >>> from sympy.codegen.pyutils import render_as_module
  98. >>> x = symbols('x')
  99. >>> expr = cos(x) - x**3
  100. >>> func = newtons_method_function(expr, x)
  101. >>> py_mod = render_as_module(func) # source code as string
  102. >>> namespace = {}
  103. >>> exec(py_mod, namespace, namespace)
  104. >>> res = eval('newton(0.5)', namespace)
  105. >>> abs(res - 0.865474033102) < 1e-12
  106. True
  107. See Also
  108. ========
  109. sympy.codegen.algorithms.newtons_method
  110. """
  111. if params is None:
  112. params = (wrt,)
  113. pointer_subs = {p.symbol: Symbol('(*%s)' % p.symbol.name)
  114. for p in params if isinstance(p, Pointer)}
  115. if delta is None:
  116. delta = Symbol('d_' + wrt.name)
  117. if expr.has(delta):
  118. delta = None # will use Dummy
  119. algo = newtons_method(expr, wrt, delta=delta, **kwargs).xreplace(pointer_subs)
  120. if isinstance(algo, Scope):
  121. algo = algo.body
  122. not_in_params = expr.free_symbols.difference({_symbol_of(p) for p in params})
  123. if not_in_params:
  124. raise ValueError("Missing symbols in params: %s" % ', '.join(map(str, not_in_params)))
  125. declars = tuple(Variable(p, real) for p in params)
  126. body = CodeBlock(algo, Return(wrt))
  127. return FunctionDefinition(real, func_name, declars, body, attrs=attrs)