odes.py 9.7 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288
  1. from bisect import bisect
  2. from ..libmp.backend import xrange
  3. class ODEMethods(object):
  4. pass
  5. def ode_taylor(ctx, derivs, x0, y0, tol_prec, n):
  6. h = tol = ctx.ldexp(1, -tol_prec)
  7. dim = len(y0)
  8. xs = [x0]
  9. ys = [y0]
  10. x = x0
  11. y = y0
  12. orig = ctx.prec
  13. try:
  14. ctx.prec = orig*(1+n)
  15. # Use n steps with Euler's method to get
  16. # evaluation points for derivatives
  17. for i in range(n):
  18. fxy = derivs(x, y)
  19. y = [y[i]+h*fxy[i] for i in xrange(len(y))]
  20. x += h
  21. xs.append(x)
  22. ys.append(y)
  23. # Compute derivatives
  24. ser = [[] for d in range(dim)]
  25. for j in range(n+1):
  26. s = [0]*dim
  27. b = (-1) ** (j & 1)
  28. k = 1
  29. for i in range(j+1):
  30. for d in range(dim):
  31. s[d] += b * ys[i][d]
  32. b = (b * (j-k+1)) // (-k)
  33. k += 1
  34. scale = h**(-j) / ctx.fac(j)
  35. for d in range(dim):
  36. s[d] = s[d] * scale
  37. ser[d].append(s[d])
  38. finally:
  39. ctx.prec = orig
  40. # Estimate radius for which we can get full accuracy.
  41. # XXX: do this right for zeros
  42. radius = ctx.one
  43. for ts in ser:
  44. if ts[-1]:
  45. radius = min(radius, ctx.nthroot(tol/abs(ts[-1]), n))
  46. radius /= 2 # XXX
  47. return ser, x0+radius
  48. def odefun(ctx, F, x0, y0, tol=None, degree=None, method='taylor', verbose=False):
  49. r"""
  50. Returns a function `y(x) = [y_0(x), y_1(x), \ldots, y_n(x)]`
  51. that is a numerical solution of the `n+1`-dimensional first-order
  52. ordinary differential equation (ODE) system
  53. .. math ::
  54. y_0'(x) = F_0(x, [y_0(x), y_1(x), \ldots, y_n(x)])
  55. y_1'(x) = F_1(x, [y_0(x), y_1(x), \ldots, y_n(x)])
  56. \vdots
  57. y_n'(x) = F_n(x, [y_0(x), y_1(x), \ldots, y_n(x)])
  58. The derivatives are specified by the vector-valued function
  59. *F* that evaluates
  60. `[y_0', \ldots, y_n'] = F(x, [y_0, \ldots, y_n])`.
  61. The initial point `x_0` is specified by the scalar argument *x0*,
  62. and the initial value `y(x_0) = [y_0(x_0), \ldots, y_n(x_0)]` is
  63. specified by the vector argument *y0*.
  64. For convenience, if the system is one-dimensional, you may optionally
  65. provide just a scalar value for *y0*. In this case, *F* should accept
  66. a scalar *y* argument and return a scalar. The solution function
  67. *y* will return scalar values instead of length-1 vectors.
  68. Evaluation of the solution function `y(x)` is permitted
  69. for any `x \ge x_0`.
  70. A high-order ODE can be solved by transforming it into first-order
  71. vector form. This transformation is described in standard texts
  72. on ODEs. Examples will also be given below.
  73. **Options, speed and accuracy**
  74. By default, :func:`~mpmath.odefun` uses a high-order Taylor series
  75. method. For reasonably well-behaved problems, the solution will
  76. be fully accurate to within the working precision. Note that
  77. *F* must be possible to evaluate to very high precision
  78. for the generation of Taylor series to work.
  79. To get a faster but less accurate solution, you can set a large
  80. value for *tol* (which defaults roughly to *eps*). If you just
  81. want to plot the solution or perform a basic simulation,
  82. *tol = 0.01* is likely sufficient.
  83. The *degree* argument controls the degree of the solver (with
  84. *method='taylor'*, this is the degree of the Taylor series
  85. expansion). A higher degree means that a longer step can be taken
  86. before a new local solution must be generated from *F*,
  87. meaning that fewer steps are required to get from `x_0` to a given
  88. `x_1`. On the other hand, a higher degree also means that each
  89. local solution becomes more expensive (i.e., more evaluations of
  90. *F* are required per step, and at higher precision).
  91. The optimal setting therefore involves a tradeoff. Generally,
  92. decreasing the *degree* for Taylor series is likely to give faster
  93. solution at low precision, while increasing is likely to be better
  94. at higher precision.
  95. The function
  96. object returned by :func:`~mpmath.odefun` caches the solutions at all step
  97. points and uses polynomial interpolation between step points.
  98. Therefore, once `y(x_1)` has been evaluated for some `x_1`,
  99. `y(x)` can be evaluated very quickly for any `x_0 \le x \le x_1`.
  100. and continuing the evaluation up to `x_2 > x_1` is also fast.
  101. **Examples of first-order ODEs**
  102. We will solve the standard test problem `y'(x) = y(x), y(0) = 1`
  103. which has explicit solution `y(x) = \exp(x)`::
  104. >>> from mpmath import *
  105. >>> mp.dps = 15; mp.pretty = True
  106. >>> f = odefun(lambda x, y: y, 0, 1)
  107. >>> for x in [0, 1, 2.5]:
  108. ... print((f(x), exp(x)))
  109. ...
  110. (1.0, 1.0)
  111. (2.71828182845905, 2.71828182845905)
  112. (12.1824939607035, 12.1824939607035)
  113. The solution with high precision::
  114. >>> mp.dps = 50
  115. >>> f = odefun(lambda x, y: y, 0, 1)
  116. >>> f(1)
  117. 2.7182818284590452353602874713526624977572470937
  118. >>> exp(1)
  119. 2.7182818284590452353602874713526624977572470937
  120. Using the more general vectorized form, the test problem
  121. can be input as (note that *f* returns a 1-element vector)::
  122. >>> mp.dps = 15
  123. >>> f = odefun(lambda x, y: [y[0]], 0, [1])
  124. >>> f(1)
  125. [2.71828182845905]
  126. :func:`~mpmath.odefun` can solve nonlinear ODEs, which are generally
  127. impossible (and at best difficult) to solve analytically. As
  128. an example of a nonlinear ODE, we will solve `y'(x) = x \sin(y(x))`
  129. for `y(0) = \pi/2`. An exact solution happens to be known
  130. for this problem, and is given by
  131. `y(x) = 2 \tan^{-1}\left(\exp\left(x^2/2\right)\right)`::
  132. >>> f = odefun(lambda x, y: x*sin(y), 0, pi/2)
  133. >>> for x in [2, 5, 10]:
  134. ... print((f(x), 2*atan(exp(mpf(x)**2/2))))
  135. ...
  136. (2.87255666284091, 2.87255666284091)
  137. (3.14158520028345, 3.14158520028345)
  138. (3.14159265358979, 3.14159265358979)
  139. If `F` is independent of `y`, an ODE can be solved using direct
  140. integration. We can therefore obtain a reference solution with
  141. :func:`~mpmath.quad`::
  142. >>> f = lambda x: (1+x**2)/(1+x**3)
  143. >>> g = odefun(lambda x, y: f(x), pi, 0)
  144. >>> g(2*pi)
  145. 0.72128263801696
  146. >>> quad(f, [pi, 2*pi])
  147. 0.72128263801696
  148. **Examples of second-order ODEs**
  149. We will solve the harmonic oscillator equation `y''(x) + y(x) = 0`.
  150. To do this, we introduce the helper functions `y_0 = y, y_1 = y_0'`
  151. whereby the original equation can be written as `y_1' + y_0' = 0`. Put
  152. together, we get the first-order, two-dimensional vector ODE
  153. .. math ::
  154. \begin{cases}
  155. y_0' = y_1 \\
  156. y_1' = -y_0
  157. \end{cases}
  158. To get a well-defined IVP, we need two initial values. With
  159. `y(0) = y_0(0) = 1` and `-y'(0) = y_1(0) = 0`, the problem will of
  160. course be solved by `y(x) = y_0(x) = \cos(x)` and
  161. `-y'(x) = y_1(x) = \sin(x)`. We check this::
  162. >>> f = odefun(lambda x, y: [-y[1], y[0]], 0, [1, 0])
  163. >>> for x in [0, 1, 2.5, 10]:
  164. ... nprint(f(x), 15)
  165. ... nprint([cos(x), sin(x)], 15)
  166. ... print("---")
  167. ...
  168. [1.0, 0.0]
  169. [1.0, 0.0]
  170. ---
  171. [0.54030230586814, 0.841470984807897]
  172. [0.54030230586814, 0.841470984807897]
  173. ---
  174. [-0.801143615546934, 0.598472144103957]
  175. [-0.801143615546934, 0.598472144103957]
  176. ---
  177. [-0.839071529076452, -0.54402111088937]
  178. [-0.839071529076452, -0.54402111088937]
  179. ---
  180. Note that we get both the sine and the cosine solutions
  181. simultaneously.
  182. **TODO**
  183. * Better automatic choice of degree and step size
  184. * Make determination of Taylor series convergence radius
  185. more robust
  186. * Allow solution for `x < x_0`
  187. * Allow solution for complex `x`
  188. * Test for difficult (ill-conditioned) problems
  189. * Implement Runge-Kutta and other algorithms
  190. """
  191. if tol:
  192. tol_prec = int(-ctx.log(tol, 2))+10
  193. else:
  194. tol_prec = ctx.prec+10
  195. degree = degree or (3 + int(3*ctx.dps/2.))
  196. workprec = ctx.prec + 40
  197. try:
  198. len(y0)
  199. return_vector = True
  200. except TypeError:
  201. F_ = F
  202. F = lambda x, y: [F_(x, y[0])]
  203. y0 = [y0]
  204. return_vector = False
  205. ser, xb = ode_taylor(ctx, F, x0, y0, tol_prec, degree)
  206. series_boundaries = [x0, xb]
  207. series_data = [(ser, x0, xb)]
  208. # We will be working with vectors of Taylor series
  209. def mpolyval(ser, a):
  210. return [ctx.polyval(s[::-1], a) for s in ser]
  211. # Find nearest expansion point; compute if necessary
  212. def get_series(x):
  213. if x < x0:
  214. raise ValueError
  215. n = bisect(series_boundaries, x)
  216. if n < len(series_boundaries):
  217. return series_data[n-1]
  218. while 1:
  219. ser, xa, xb = series_data[-1]
  220. if verbose:
  221. print("Computing Taylor series for [%f, %f]" % (xa, xb))
  222. y = mpolyval(ser, xb-xa)
  223. xa = xb
  224. ser, xb = ode_taylor(ctx, F, xb, y, tol_prec, degree)
  225. series_boundaries.append(xb)
  226. series_data.append((ser, xa, xb))
  227. if x <= xb:
  228. return series_data[-1]
  229. # Evaluation function
  230. def interpolant(x):
  231. x = ctx.convert(x)
  232. orig = ctx.prec
  233. try:
  234. ctx.prec = workprec
  235. ser, xa, xb = get_series(x)
  236. y = mpolyval(ser, x-xa)
  237. finally:
  238. ctx.prec = orig
  239. if return_vector:
  240. return [+yk for yk in y]
  241. else:
  242. return +y[0]
  243. return interpolant
  244. ODEMethods.odefun = odefun
  245. if __name__ == "__main__":
  246. import doctest
  247. doctest.testmod()