calculus.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531
  1. from ..libmp.backend import xrange
# TODO: should use diagonalization-based algorithms
  3. class MatrixCalculusMethods(object):
  4. def _exp_pade(ctx, a):
  5. """
  6. Exponential of a matrix using Pade approximants.
  7. See G. H. Golub, C. F. van Loan 'Matrix Computations',
  8. third Ed., page 572
  9. TODO:
  10. - find a good estimate for q
  11. - reduce the number of matrix multiplications to improve
  12. performance
  13. """
  14. def eps_pade(p):
  15. return ctx.mpf(2)**(3-2*p) * \
  16. ctx.factorial(p)**2/(ctx.factorial(2*p)**2 * (2*p + 1))
  17. q = 4
  18. extraq = 8
  19. while 1:
  20. if eps_pade(q) < ctx.eps:
  21. break
  22. q += 1
  23. q += extraq
  24. j = int(max(1, ctx.mag(ctx.mnorm(a,'inf'))))
  25. extra = q
  26. prec = ctx.prec
  27. ctx.dps += extra + 3
  28. try:
  29. a = a/2**j
  30. na = a.rows
  31. den = ctx.eye(na)
  32. num = ctx.eye(na)
  33. x = ctx.eye(na)
  34. c = ctx.mpf(1)
  35. for k in range(1, q+1):
  36. c *= ctx.mpf(q - k + 1)/((2*q - k + 1) * k)
  37. x = a*x
  38. cx = c*x
  39. num += cx
  40. den += (-1)**k * cx
  41. f = ctx.lu_solve_mat(den, num)
  42. for k in range(j):
  43. f = f*f
  44. finally:
  45. ctx.prec = prec
  46. return f*1
  47. def expm(ctx, A, method='taylor'):
  48. r"""
  49. Computes the matrix exponential of a square matrix `A`, which is defined
  50. by the power series
  51. .. math ::
  52. \exp(A) = I + A + \frac{A^2}{2!} + \frac{A^3}{3!} + \ldots
  53. With method='taylor', the matrix exponential is computed
  54. using the Taylor series. With method='pade', Pade approximants
  55. are used instead.
  56. **Examples**
  57. Basic examples::
  58. >>> from mpmath import *
  59. >>> mp.dps = 15; mp.pretty = True
  60. >>> expm(zeros(3))
  61. [1.0 0.0 0.0]
  62. [0.0 1.0 0.0]
  63. [0.0 0.0 1.0]
  64. >>> expm(eye(3))
  65. [2.71828182845905 0.0 0.0]
  66. [ 0.0 2.71828182845905 0.0]
  67. [ 0.0 0.0 2.71828182845905]
  68. >>> expm([[1,1,0],[1,0,1],[0,1,0]])
  69. [ 3.86814500615414 2.26812870852145 0.841130841230196]
  70. [ 2.26812870852145 2.44114713886289 1.42699786729125]
  71. [0.841130841230196 1.42699786729125 1.6000162976327]
  72. >>> expm([[1,1,0],[1,0,1],[0,1,0]], method='pade')
  73. [ 3.86814500615414 2.26812870852145 0.841130841230196]
  74. [ 2.26812870852145 2.44114713886289 1.42699786729125]
  75. [0.841130841230196 1.42699786729125 1.6000162976327]
  76. >>> expm([[1+j, 0], [1+j,1]])
  77. [(1.46869393991589 + 2.28735528717884j) 0.0]
  78. [ (1.03776739863568 + 3.536943175722j) (2.71828182845905 + 0.0j)]
  79. Matrices with large entries are allowed::
  80. >>> expm(matrix([[1,2],[2,3]])**25)
  81. [5.65024064048415e+2050488462815550 9.14228140091932e+2050488462815550]
  82. [9.14228140091932e+2050488462815550 1.47925220414035e+2050488462815551]
  83. The identity `\exp(A+B) = \exp(A) \exp(B)` does not hold for
  84. noncommuting matrices::
  85. >>> A = hilbert(3)
  86. >>> B = A + eye(3)
  87. >>> chop(mnorm(A*B - B*A))
  88. 0.0
  89. >>> chop(mnorm(expm(A+B) - expm(A)*expm(B)))
  90. 0.0
  91. >>> B = A + ones(3)
  92. >>> mnorm(A*B - B*A)
  93. 1.8
  94. >>> mnorm(expm(A+B) - expm(A)*expm(B))
  95. 42.0927851137247
  96. """
  97. if method == 'pade':
  98. prec = ctx.prec
  99. try:
  100. A = ctx.matrix(A)
  101. ctx.prec += 2*A.rows
  102. res = ctx._exp_pade(A)
  103. finally:
  104. ctx.prec = prec
  105. return res
  106. A = ctx.matrix(A)
  107. prec = ctx.prec
  108. j = int(max(1, ctx.mag(ctx.mnorm(A,'inf'))))
  109. j += int(0.5*prec**0.5)
  110. try:
  111. ctx.prec += 10 + 2*j
  112. tol = +ctx.eps
  113. A = A/2**j
  114. T = A
  115. Y = A**0 + A
  116. k = 2
  117. while 1:
  118. T *= A * (1/ctx.mpf(k))
  119. if ctx.mnorm(T, 'inf') < tol:
  120. break
  121. Y += T
  122. k += 1
  123. for k in xrange(j):
  124. Y = Y*Y
  125. finally:
  126. ctx.prec = prec
  127. Y *= 1
  128. return Y
  129. def cosm(ctx, A):
  130. r"""
  131. Gives the cosine of a square matrix `A`, defined in analogy
  132. with the matrix exponential.
  133. Examples::
  134. >>> from mpmath import *
  135. >>> mp.dps = 15; mp.pretty = True
  136. >>> X = eye(3)
  137. >>> cosm(X)
  138. [0.54030230586814 0.0 0.0]
  139. [ 0.0 0.54030230586814 0.0]
  140. [ 0.0 0.0 0.54030230586814]
  141. >>> X = hilbert(3)
  142. >>> cosm(X)
  143. [ 0.424403834569555 -0.316643413047167 -0.221474945949293]
  144. [-0.316643413047167 0.820646708837824 -0.127183694770039]
  145. [-0.221474945949293 -0.127183694770039 0.909236687217541]
  146. >>> X = matrix([[1+j,-2],[0,-j]])
  147. >>> cosm(X)
  148. [(0.833730025131149 - 0.988897705762865j) (1.07485840848393 - 0.17192140544213j)]
  149. [ 0.0 (1.54308063481524 + 0.0j)]
  150. """
  151. B = 0.5 * (ctx.expm(A*ctx.j) + ctx.expm(A*(-ctx.j)))
  152. if not sum(A.apply(ctx.im).apply(abs)):
  153. B = B.apply(ctx.re)
  154. return B
  155. def sinm(ctx, A):
  156. r"""
  157. Gives the sine of a square matrix `A`, defined in analogy
  158. with the matrix exponential.
  159. Examples::
  160. >>> from mpmath import *
  161. >>> mp.dps = 15; mp.pretty = True
  162. >>> X = eye(3)
  163. >>> sinm(X)
  164. [0.841470984807897 0.0 0.0]
  165. [ 0.0 0.841470984807897 0.0]
  166. [ 0.0 0.0 0.841470984807897]
  167. >>> X = hilbert(3)
  168. >>> sinm(X)
  169. [0.711608512150994 0.339783913247439 0.220742837314741]
  170. [0.339783913247439 0.244113865695532 0.187231271174372]
  171. [0.220742837314741 0.187231271174372 0.155816730769635]
  172. >>> X = matrix([[1+j,-2],[0,-j]])
  173. >>> sinm(X)
  174. [(1.29845758141598 + 0.634963914784736j) (-1.96751511930922 + 0.314700021761367j)]
  175. [ 0.0 (0.0 - 1.1752011936438j)]
  176. """
  177. B = (-0.5j) * (ctx.expm(A*ctx.j) - ctx.expm(A*(-ctx.j)))
  178. if not sum(A.apply(ctx.im).apply(abs)):
  179. B = B.apply(ctx.re)
  180. return B
  181. def _sqrtm_rot(ctx, A, _may_rotate):
  182. # If the iteration fails to converge, cheat by performing
  183. # a rotation by a complex number
  184. u = ctx.j**0.3
  185. return ctx.sqrtm(u*A, _may_rotate) / ctx.sqrt(u)
  186. def sqrtm(ctx, A, _may_rotate=2):
  187. r"""
  188. Computes a square root of the square matrix `A`, i.e. returns
  189. a matrix `B = A^{1/2}` such that `B^2 = A`. The square root
  190. of a matrix, if it exists, is not unique.
  191. **Examples**
  192. Square roots of some simple matrices::
  193. >>> from mpmath import *
  194. >>> mp.dps = 15; mp.pretty = True
  195. >>> sqrtm([[1,0], [0,1]])
  196. [1.0 0.0]
  197. [0.0 1.0]
  198. >>> sqrtm([[0,0], [0,0]])
  199. [0.0 0.0]
  200. [0.0 0.0]
  201. >>> sqrtm([[2,0],[0,1]])
  202. [1.4142135623731 0.0]
  203. [ 0.0 1.0]
  204. >>> sqrtm([[1,1],[1,0]])
  205. [ (0.920442065259926 - 0.21728689675164j) (0.568864481005783 + 0.351577584254143j)]
  206. [(0.568864481005783 + 0.351577584254143j) (0.351577584254143 - 0.568864481005783j)]
  207. >>> sqrtm([[1,0],[0,1]])
  208. [1.0 0.0]
  209. [0.0 1.0]
  210. >>> sqrtm([[-1,0],[0,1]])
  211. [(0.0 - 1.0j) 0.0]
  212. [ 0.0 (1.0 + 0.0j)]
  213. >>> sqrtm([[j,0],[0,j]])
  214. [(0.707106781186547 + 0.707106781186547j) 0.0]
  215. [ 0.0 (0.707106781186547 + 0.707106781186547j)]
  216. A square root of a rotation matrix, giving the corresponding
  217. half-angle rotation matrix::
  218. >>> t1 = 0.75
  219. >>> t2 = t1 * 0.5
  220. >>> A1 = matrix([[cos(t1), -sin(t1)], [sin(t1), cos(t1)]])
  221. >>> A2 = matrix([[cos(t2), -sin(t2)], [sin(t2), cos(t2)]])
  222. >>> sqrtm(A1)
  223. [0.930507621912314 -0.366272529086048]
  224. [0.366272529086048 0.930507621912314]
  225. >>> A2
  226. [0.930507621912314 -0.366272529086048]
  227. [0.366272529086048 0.930507621912314]
  228. The identity `(A^2)^{1/2} = A` does not necessarily hold::
  229. >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
  230. >>> sqrtm(A**2)
  231. [ 4.0 1.0 4.0]
  232. [ 7.0 8.0 9.0]
  233. [10.0 2.0 11.0]
  234. >>> sqrtm(A)**2
  235. [ 4.0 1.0 4.0]
  236. [ 7.0 8.0 9.0]
  237. [10.0 2.0 11.0]
  238. >>> A = matrix([[-4,1,4],[7,-8,9],[10,2,11]])
  239. >>> sqrtm(A**2)
  240. [ 7.43715112194995 -0.324127569985474 1.8481718827526]
  241. [-0.251549715716942 9.32699765900402 2.48221180985147]
  242. [ 4.11609388833616 0.775751877098258 13.017955697342]
  243. >>> chop(sqrtm(A)**2)
  244. [-4.0 1.0 4.0]
  245. [ 7.0 -8.0 9.0]
  246. [10.0 2.0 11.0]
  247. For some matrices, a square root does not exist::
  248. >>> sqrtm([[0,1], [0,0]])
  249. Traceback (most recent call last):
  250. ...
  251. ZeroDivisionError: matrix is numerically singular
  252. Two examples from the documentation for Matlab's ``sqrtm``::
  253. >>> mp.dps = 15; mp.pretty = True
  254. >>> sqrtm([[7,10],[15,22]])
  255. [1.56669890360128 1.74077655955698]
  256. [2.61116483933547 4.17786374293675]
  257. >>>
  258. >>> X = matrix(\
  259. ... [[5,-4,1,0,0],
  260. ... [-4,6,-4,1,0],
  261. ... [1,-4,6,-4,1],
  262. ... [0,1,-4,6,-4],
  263. ... [0,0,1,-4,5]])
  264. >>> Y = matrix(\
  265. ... [[2,-1,-0,-0,-0],
  266. ... [-1,2,-1,0,-0],
  267. ... [0,-1,2,-1,0],
  268. ... [-0,0,-1,2,-1],
  269. ... [-0,-0,-0,-1,2]])
  270. >>> mnorm(sqrtm(X) - Y)
  271. 4.53155328326114e-19
  272. """
  273. A = ctx.matrix(A)
  274. # Trivial
  275. if A*0 == A:
  276. return A
  277. prec = ctx.prec
  278. if _may_rotate:
  279. d = ctx.det(A)
  280. if abs(ctx.im(d)) < 16*ctx.eps and ctx.re(d) < 0:
  281. return ctx._sqrtm_rot(A, _may_rotate-1)
  282. try:
  283. ctx.prec += 10
  284. tol = ctx.eps * 128
  285. Y = A
  286. Z = I = A**0
  287. k = 0
  288. # Denman-Beavers iteration
  289. while 1:
  290. Yprev = Y
  291. try:
  292. Y, Z = 0.5*(Y+ctx.inverse(Z)), 0.5*(Z+ctx.inverse(Y))
  293. except ZeroDivisionError:
  294. if _may_rotate:
  295. Y = ctx._sqrtm_rot(A, _may_rotate-1)
  296. break
  297. else:
  298. raise
  299. mag1 = ctx.mnorm(Y-Yprev, 'inf')
  300. mag2 = ctx.mnorm(Y, 'inf')
  301. if mag1 <= mag2*tol:
  302. break
  303. if _may_rotate and k > 6 and not mag1 < mag2 * 0.001:
  304. return ctx._sqrtm_rot(A, _may_rotate-1)
  305. k += 1
  306. if k > ctx.prec:
  307. raise ctx.NoConvergence
  308. finally:
  309. ctx.prec = prec
  310. Y *= 1
  311. return Y
  312. def logm(ctx, A):
  313. r"""
  314. Computes a logarithm of the square matrix `A`, i.e. returns
  315. a matrix `B = \log(A)` such that `\exp(B) = A`. The logarithm
  316. of a matrix, if it exists, is not unique.
  317. **Examples**
  318. Logarithms of some simple matrices::
  319. >>> from mpmath import *
  320. >>> mp.dps = 15; mp.pretty = True
  321. >>> X = eye(3)
  322. >>> logm(X)
  323. [0.0 0.0 0.0]
  324. [0.0 0.0 0.0]
  325. [0.0 0.0 0.0]
  326. >>> logm(2*X)
  327. [0.693147180559945 0.0 0.0]
  328. [ 0.0 0.693147180559945 0.0]
  329. [ 0.0 0.0 0.693147180559945]
  330. >>> logm(expm(X))
  331. [1.0 0.0 0.0]
  332. [0.0 1.0 0.0]
  333. [0.0 0.0 1.0]
  334. A logarithm of a complex matrix::
  335. >>> X = matrix([[2+j, 1, 3], [1-j, 1-2*j, 1], [-4, -5, j]])
  336. >>> B = logm(X)
  337. >>> nprint(B)
  338. [ (0.808757 + 0.107759j) (2.20752 + 0.202762j) (1.07376 - 0.773874j)]
  339. [ (0.905709 - 0.107795j) (0.0287395 - 0.824993j) (0.111619 + 0.514272j)]
  340. [(-0.930151 + 0.399512j) (-2.06266 - 0.674397j) (0.791552 + 0.519839j)]
  341. >>> chop(expm(B))
  342. [(2.0 + 1.0j) 1.0 3.0]
  343. [(1.0 - 1.0j) (1.0 - 2.0j) 1.0]
  344. [ -4.0 -5.0 (0.0 + 1.0j)]
  345. A matrix `X` close to the identity matrix, for which
  346. `\log(\exp(X)) = \exp(\log(X)) = X` holds::
  347. >>> X = eye(3) + hilbert(3)/4
  348. >>> X
  349. [ 1.25 0.125 0.0833333333333333]
  350. [ 0.125 1.08333333333333 0.0625]
  351. [0.0833333333333333 0.0625 1.05]
  352. >>> logm(expm(X))
  353. [ 1.25 0.125 0.0833333333333333]
  354. [ 0.125 1.08333333333333 0.0625]
  355. [0.0833333333333333 0.0625 1.05]
  356. >>> expm(logm(X))
  357. [ 1.25 0.125 0.0833333333333333]
  358. [ 0.125 1.08333333333333 0.0625]
  359. [0.0833333333333333 0.0625 1.05]
  360. A logarithm of a rotation matrix, giving back the angle of
  361. the rotation::
  362. >>> t = 3.7
  363. >>> A = matrix([[cos(t),sin(t)],[-sin(t),cos(t)]])
  364. >>> chop(logm(A))
  365. [ 0.0 -2.58318530717959]
  366. [2.58318530717959 0.0]
  367. >>> (2*pi-t)
  368. 2.58318530717959
  369. For some matrices, a logarithm does not exist::
  370. >>> logm([[1,0], [0,0]])
  371. Traceback (most recent call last):
  372. ...
  373. ZeroDivisionError: matrix is numerically singular
  374. Logarithm of a matrix with large entries::
  375. >>> logm(hilbert(3) * 10**20).apply(re)
  376. [ 45.5597513593433 1.27721006042799 0.317662687717978]
  377. [ 1.27721006042799 42.5222778973542 2.24003708791604]
  378. [0.317662687717978 2.24003708791604 42.395212822267]
  379. """
  380. A = ctx.matrix(A)
  381. prec = ctx.prec
  382. try:
  383. ctx.prec += 10
  384. tol = ctx.eps * 128
  385. I = A**0
  386. B = A
  387. n = 0
  388. while 1:
  389. B = ctx.sqrtm(B)
  390. n += 1
  391. if ctx.mnorm(B-I, 'inf') < 0.125:
  392. break
  393. T = X = B-I
  394. L = X*0
  395. k = 1
  396. while 1:
  397. if k & 1:
  398. L += T / k
  399. else:
  400. L -= T / k
  401. T *= X
  402. if ctx.mnorm(T, 'inf') < tol:
  403. break
  404. k += 1
  405. if k > ctx.prec:
  406. raise ctx.NoConvergence
  407. finally:
  408. ctx.prec = prec
  409. L *= 2**n
  410. return L
  411. def powm(ctx, A, r):
  412. r"""
  413. Computes `A^r = \exp(A \log r)` for a matrix `A` and complex
  414. number `r`.
  415. **Examples**
  416. Powers and inverse powers of a matrix::
  417. >>> from mpmath import *
  418. >>> mp.dps = 15; mp.pretty = True
  419. >>> A = matrix([[4,1,4],[7,8,9],[10,2,11]])
  420. >>> powm(A, 2)
  421. [ 63.0 20.0 69.0]
  422. [174.0 89.0 199.0]
  423. [164.0 48.0 179.0]
  424. >>> chop(powm(powm(A, 4), 1/4.))
  425. [ 4.0 1.0 4.0]
  426. [ 7.0 8.0 9.0]
  427. [10.0 2.0 11.0]
  428. >>> powm(extraprec(20)(powm)(A, -4), -1/4.)
  429. [ 4.0 1.0 4.0]
  430. [ 7.0 8.0 9.0]
  431. [10.0 2.0 11.0]
  432. >>> chop(powm(powm(A, 1+0.5j), 1/(1+0.5j)))
  433. [ 4.0 1.0 4.0]
  434. [ 7.0 8.0 9.0]
  435. [10.0 2.0 11.0]
  436. >>> powm(extraprec(5)(powm)(A, -1.5), -1/(1.5))
  437. [ 4.0 1.0 4.0]
  438. [ 7.0 8.0 9.0]
  439. [10.0 2.0 11.0]
  440. A Fibonacci-generating matrix::
  441. >>> powm([[1,1],[1,0]], 10)
  442. [89.0 55.0]
  443. [55.0 34.0]
  444. >>> fib(10)
  445. 55.0
  446. >>> powm([[1,1],[1,0]], 6.5)
  447. [(16.5166626964253 - 0.0121089837381789j) (10.2078589271083 + 0.0195927472575932j)]
  448. [(10.2078589271083 + 0.0195927472575932j) (6.30880376931698 - 0.0317017309957721j)]
  449. >>> (phi**6.5 - (1-phi)**6.5)/sqrt(5)
  450. (10.2078589271083 - 0.0195927472575932j)
  451. >>> powm([[1,1],[1,0]], 6.2)
  452. [ (14.3076953002666 - 0.008222855781077j) (8.81733464837593 + 0.0133048601383712j)]
  453. [(8.81733464837593 + 0.0133048601383712j) (5.49036065189071 - 0.0215277159194482j)]
  454. >>> (phi**6.2 - (1-phi)**6.2)/sqrt(5)
  455. (8.81733464837593 - 0.0133048601383712j)
  456. """
  457. A = ctx.matrix(A)
  458. r = ctx.convert(r)
  459. prec = ctx.prec
  460. try:
  461. ctx.prec += 10
  462. if ctx.isint(r):
  463. v = A ** int(r)
  464. elif ctx.isint(r*2):
  465. y = int(r*2)
  466. v = ctx.sqrtm(A) ** y
  467. else:
  468. v = ctx.expm(r*ctx.logm(A))
  469. finally:
  470. ctx.prec = prec
  471. v *= 1
  472. return v