test_matrices.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253
  1. import pytest
  2. import sys
  3. from mpmath import *
  4. def test_matrix_basic():
  5. A1 = matrix(3)
  6. for i in range(3):
  7. A1[i,i] = 1
  8. assert A1 == eye(3)
  9. assert A1 == matrix(A1)
  10. A2 = matrix(3, 2)
  11. assert not A2._matrix__data
  12. A3 = matrix([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
  13. assert list(A3) == list(range(1, 10))
  14. A3[1,1] = 0
  15. assert not (1, 1) in A3._matrix__data
  16. A4 = matrix([[1, 2, 3], [4, 5, 6]])
  17. A5 = matrix([[6, -1], [3, 2], [0, -3]])
  18. assert A4 * A5 == matrix([[12, -6], [39, -12]])
  19. assert A1 * A3 == A3 * A1 == A3
  20. pytest.raises(ValueError, lambda: A2*A2)
  21. l = [[10, 20, 30], [40, 0, 60], [70, 80, 90]]
  22. A6 = matrix(l)
  23. assert A6.tolist() == l
  24. assert A6 == eval(repr(A6))
  25. A6 = fp.matrix(A6)
  26. assert A6 == eval(repr(A6))
  27. assert A6*1j == eval(repr(A6*1j))
  28. assert A3 * 10 == 10 * A3 == A6
  29. assert A2.rows == 3
  30. assert A2.cols == 2
  31. A3.rows = 2
  32. A3.cols = 2
  33. assert len(A3._matrix__data) == 3
  34. assert A4 + A4 == 2*A4
  35. pytest.raises(ValueError, lambda: A4 + A2)
  36. assert sum(A1 - A1) == 0
  37. A7 = matrix([[1, 2], [3, 4], [5, 6], [7, 8]])
  38. x = matrix([10, -10])
  39. assert A7*x == matrix([-10, -10, -10, -10])
  40. A8 = ones(5)
  41. assert sum((A8 + 1) - (2 - zeros(5))) == 0
  42. assert (1 + ones(4)) / 2 - 1 == zeros(4)
  43. assert eye(3)**10 == eye(3)
  44. pytest.raises(ValueError, lambda: A7**2)
  45. A9 = randmatrix(3)
  46. A10 = matrix(A9)
  47. A9[0,0] = -100
  48. assert A9 != A10
  49. assert nstr(A9)
  50. def test_matmul():
  51. """
  52. Test the PEP465 "@" matrix multiplication syntax.
  53. To avoid syntax errors when importing this file in Python 3.4 and below, we have to use exec() - sorry for that.
  54. """
  55. # TODO remove exec() wrapper as soon as we drop support for Python <= 3.4
  56. if sys.hexversion < 0x30500f0:
  57. # we are on Python < 3.5
  58. pytest.skip("'@' (__matmul__) is only supported in Python 3.5 or newer")
  59. A4 = matrix([[1, 2, 3], [4, 5, 6]])
  60. A5 = matrix([[6, -1], [3, 2], [0, -3]])
  61. exec("assert A4 @ A5 == A4 * A5")
  62. def test_matrix_slices():
  63. A = matrix([ [1, 2, 3],
  64. [4, 5 ,6],
  65. [7, 8 ,9]])
  66. V = matrix([1,2,3,4,5])
  67. # Get slice
  68. assert A[:,:] == A
  69. assert A[:,1] == matrix([[2],[5],[8]])
  70. assert A[2,:] == matrix([[7, 8 ,9]])
  71. assert A[1:3,1:3] == matrix([[5,6],[8,9]])
  72. assert V[2:4] == matrix([3,4])
  73. pytest.raises(IndexError, lambda: A[:,1:6])
  74. # Assign slice with matrix
  75. A1 = matrix(3)
  76. A1[:,:] = A
  77. assert A1[:,:] == matrix([[1, 2, 3],
  78. [4, 5 ,6],
  79. [7, 8 ,9]])
  80. A1[0,:] = matrix([[10, 11, 12]])
  81. assert A1 == matrix([ [10, 11, 12],
  82. [4, 5 ,6],
  83. [7, 8 ,9]])
  84. A1[:,2] = matrix([[13], [14], [15]])
  85. assert A1 == matrix([ [10, 11, 13],
  86. [4, 5 ,14],
  87. [7, 8 ,15]])
  88. A1[:2,:2] = matrix([[16, 17], [18 , 19]])
  89. assert A1 == matrix([ [16, 17, 13],
  90. [18, 19 ,14],
  91. [7, 8 ,15]])
  92. V[1:3] = 10
  93. assert V == matrix([1,10,10,4,5])
  94. with pytest.raises(ValueError):
  95. A1[2,:] = A[:,1]
  96. with pytest.raises(IndexError):
  97. A1[2,1:20] = A[:,:]
  98. # Assign slice with scalar
  99. A1[:,2] = 10
  100. assert A1 == matrix([ [16, 17, 10],
  101. [18, 19 ,10],
  102. [7, 8 ,10]])
  103. A1[:,:] = 40
  104. for x in A1:
  105. assert x == 40
  106. def test_matrix_power():
  107. A = matrix([[1, 2], [3, 4]])
  108. assert A**2 == A*A
  109. assert A**3 == A*A*A
  110. assert A**-1 == inverse(A)
  111. assert A**-2 == inverse(A*A)
  112. def test_matrix_transform():
  113. A = matrix([[1, 2], [3, 4], [5, 6]])
  114. assert A.T == A.transpose() == matrix([[1, 3, 5], [2, 4, 6]])
  115. swap_row(A, 1, 2)
  116. assert A == matrix([[1, 2], [5, 6], [3, 4]])
  117. l = [1, 2]
  118. swap_row(l, 0, 1)
  119. assert l == [2, 1]
  120. assert extend(eye(3), [1,2,3]) == matrix([[1,0,0,1],[0,1,0,2],[0,0,1,3]])
  121. def test_matrix_conjugate():
  122. A = matrix([[1 + j, 0], [2, j]])
  123. assert A.conjugate() == matrix([[mpc(1, -1), 0], [2, mpc(0, -1)]])
  124. assert A.transpose_conj() == A.H == matrix([[mpc(1, -1), 2],
  125. [0, mpc(0, -1)]])
  126. def test_matrix_creation():
  127. assert diag([1, 2, 3]) == matrix([[1, 0, 0], [0, 2, 0], [0, 0, 3]])
  128. A1 = ones(2, 3)
  129. assert A1.rows == 2 and A1.cols == 3
  130. for a in A1:
  131. assert a == 1
  132. A2 = zeros(3, 2)
  133. assert A2.rows == 3 and A2.cols == 2
  134. for a in A2:
  135. assert a == 0
  136. assert randmatrix(10) != randmatrix(10)
  137. one = mpf(1)
  138. assert hilbert(3) == matrix([[one, one/2, one/3],
  139. [one/2, one/3, one/4],
  140. [one/3, one/4, one/5]])
  141. def test_norms():
  142. # matrix norms
  143. A = matrix([[1, -2], [-3, -1], [2, 1]])
  144. assert mnorm(A,1) == 6
  145. assert mnorm(A,inf) == 4
  146. assert mnorm(A,'F') == sqrt(20)
  147. # vector norms
  148. assert norm(-3) == 3
  149. x = [1, -2, 7, -12]
  150. assert norm(x, 1) == 22
  151. assert round(norm(x, 2), 10) == 14.0712472795
  152. assert round(norm(x, 10), 10) == 12.0054633727
  153. assert norm(x, inf) == 12
  154. def test_vector():
  155. x = matrix([0, 1, 2, 3, 4])
  156. assert x == matrix([[0], [1], [2], [3], [4]])
  157. assert x[3] == 3
  158. assert len(x._matrix__data) == 4
  159. assert list(x) == list(range(5))
  160. x[0] = -10
  161. x[4] = 0
  162. assert x[0] == -10
  163. assert len(x) == len(x.T) == 5
  164. assert x.T*x == matrix([[114]])
  165. def test_matrix_copy():
  166. A = ones(6)
  167. B = A.copy()
  168. C = +A
  169. assert A == B
  170. assert A == C
  171. B[0,0] = 0
  172. assert A != B
  173. C[0,0] = 42
  174. assert A != C
  175. def test_matrix_numpy():
  176. try:
  177. import numpy
  178. except ImportError:
  179. return
  180. l = [[1, 2], [3, 4], [5, 6]]
  181. a = numpy.array(l)
  182. assert matrix(l) == matrix(a)
  183. def test_interval_matrix_scalar_mult():
  184. """Multiplication of iv.matrix and any scalar type"""
  185. a = mpi(-1, 1)
  186. b = a + a * 2j
  187. c = mpf(42)
  188. d = c + c * 2j
  189. e = 1.234
  190. f = fp.convert(e)
  191. g = e + e * 3j
  192. h = fp.convert(g)
  193. M = iv.ones(1)
  194. for x in [a, b, c, d, e, f, g, h]:
  195. assert x * M == iv.matrix([x])
  196. assert M * x == iv.matrix([x])
  197. @pytest.mark.xfail()
  198. def test_interval_matrix_matrix_mult():
  199. """Multiplication of iv.matrix and other matrix types"""
  200. A = ones(1)
  201. B = fp.ones(1)
  202. M = iv.ones(1)
  203. for X in [A, B, M]:
  204. assert X * M == iv.matrix(X)
  205. assert X * M == X
  206. assert M * X == iv.matrix(X)
  207. assert M * X == X
  208. def test_matrix_conversion_to_iv():
  209. # Test that matrices with foreign datatypes are properly converted
  210. for other_type_eye in [eye(3), fp.eye(3), iv.eye(3)]:
  211. A = iv.matrix(other_type_eye)
  212. B = iv.eye(3)
  213. assert type(A[0,0]) == type(B[0,0])
  214. assert A.tolist() == B.tolist()
  215. def test_interval_matrix_mult_bug():
  216. # regression test for interval matrix multiplication:
  217. # result must be nonzero-width and contain the exact result
  218. x = convert('1.00000000000001') # note: this is implicitly rounded to some near mpf float value
  219. A = matrix([[x]])
  220. B = iv.matrix(A)
  221. C = iv.matrix([[x]])
  222. assert B == C
  223. B = B * B
  224. C = C * C
  225. assert B == C
  226. assert B[0, 0].delta > 1e-16
  227. assert B[0, 0].delta < 3e-16
  228. assert C[0, 0].delta > 1e-16
  229. assert C[0, 0].delta < 3e-16
  230. assert mp.mpf('1.00000000000001998401444325291756783368705994138804689654') in B[0, 0]
  231. assert mp.mpf('1.00000000000001998401444325291756783368705994138804689654') in C[0, 0]
  232. # the following caused an error before the bug was fixed
  233. assert iv.matrix(mp.eye(2)) * (iv.ones(2) + mpi(1, 2)) == iv.matrix([[mpi(2, 3), mpi(2, 3)], [mpi(2, 3), mpi(2, 3)]])