test_streamplot.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169
  1. import numpy as np
  2. from numpy.testing import assert_array_almost_equal
  3. import pytest
  4. import matplotlib.pyplot as plt
  5. from matplotlib.testing.decorators import image_comparison
  6. import matplotlib.transforms as mtransforms
  def velocity_field():
      """Return a sample ``(X, Y, U, V)`` vector field on a 100x200 grid.

      ``X`` and ``Y`` are built with ``np.mgrid`` and span [-3, 3] in both
      directions (100 rows along y, 200 columns along x). ``U`` and ``V``
      are polynomial velocity components evaluated on that grid.
      """
      grid_y, grid_x = np.mgrid[-3:3:100j, -3:3:200j]
      u_comp = -1 - grid_x**2 + grid_y
      v_comp = 1 + grid_x - grid_y**2
      return (grid_x, grid_y, u_comp, v_comp)
  def swirl_velocity_field():
      """Return a slowly spiralling vector field sampled on a regular grid.

      Returns 1-D coordinate vectors ``x`` (200 points) and ``y`` (100
      points), both spanning [-3, 3], plus 2-D components ``U`` and ``V`` of
      shape (100, 200). The field is the circulation (-Y, X) rotated by
      0.1 rad. That rotation tilts each vector slightly inward, so the
      streamlines spiral.
      """
      xs = np.linspace(-3., 3., 200)
      ys = np.linspace(-3., 3., 100)
      grid_x, grid_y = np.meshgrid(xs, ys)
      angle = 0.1
      cos_a = np.cos(angle)
      sin_a = np.sin(angle)
      # Apply the 2-D rotation matrix R(angle) to the vector (-grid_y, grid_x).
      u_comp = cos_a * (-grid_y) - sin_a * grid_x
      v_comp = sin_a * (-grid_y) + cos_a * grid_x
      return xs, ys, u_comp, v_comp
  @image_comparison(['streamplot_startpoints'], remove_text=True, style='mpl20',
                    extensions=['png'])
  def test_startpoints():
      """Seed streamlines from an explicit 5x5 lattice of start points.

      The seed locations are also drawn as black dots so the image records
      where the streamlines were started from.
      """
      X, Y, U, V = velocity_field()
      seed_xs = np.linspace(X.min(), X.max(), 5)
      seed_ys = np.linspace(Y.min(), Y.max(), 5)
      start_x, start_y = np.meshgrid(seed_xs, seed_ys)
      # streamplot expects an (N, 2) array of (x, y) seed coordinates.
      start_points = np.column_stack([start_x.ravel(), start_y.ravel()])
      plt.streamplot(X, Y, U, V, start_points=start_points)
      plt.plot(start_x, start_y, 'ok')
  @image_comparison(['streamplot_colormap'], remove_text=True, style='mpl20',
                    tol=0.022)
  def test_colormap():
      """Color streamlines by the U component with the 'autumn' colormap."""
      X, Y, U, V = velocity_field()
      plt.streamplot(X, Y, U, V,
                     color=U, cmap=plt.cm.autumn,
                     density=0.6, linewidth=2)
      plt.colorbar()
  @image_comparison(['streamplot_linewidth'], remove_text=True, style='mpl20',
                    tol=0.002)
  def test_linewidth():
      """Scale streamline widths by local speed, normalized to a max of 5."""
      X, Y, U, V = velocity_field()
      magnitude = np.hypot(U, V)
      widths = 5 * magnitude / magnitude.max()
      fig = plt.figure()
      ax = fig.subplots()
      ax.streamplot(X, Y, U, V, density=[0.5, 1], color='k', linewidth=widths)
  @image_comparison(['streamplot_masks_and_nans'],
                    remove_text=True, style='mpl20')
  def test_masks_and_nans():
      """Streamplot a U field that has both a masked block and a NaN block."""
      X, Y, U, V = velocity_field()
      # Mask out a rectangular block in the middle of the grid...
      mask = np.zeros(U.shape, dtype=bool)
      mask[40:60, 80:120] = True
      # ...and fill a corner block with NaNs before wrapping U as masked.
      U[:20, :40] = np.nan
      U = np.ma.array(U, mask=mask)
      fig = plt.figure()
      ax = fig.subplots()
      # Silence numpy "invalid value" floating-point warnings, presumably
      # raised by arithmetic on the NaN entries.
      with np.errstate(invalid='ignore'):
          ax.streamplot(X, Y, U, V, color=U, cmap=plt.cm.Blues)
  @image_comparison(['streamplot_maxlength.png'],
                    remove_text=True, style='mpl20', tol=0.302)
  def test_maxlength():
      """Draw a single seeded swirl streamline with ``maxlength=10``."""
      x, y, U, V = swirl_velocity_field()
      fig = plt.figure()
      ax = fig.subplots()
      ax.streamplot(x, y, U, V, maxlength=10., start_points=[[0., 1.5]],
                    linewidth=2, density=2)
      # The autoscaled upper limits must equal the data's maximum (3).
      x_upper = ax.get_xlim()[-1]
      y_upper = ax.get_ylim()[-1]
      assert x_upper == y_upper == 3
      # Compatibility for old test image: restore the upper limits the
      # reference image was generated with.
      ax.set(xlim=(None, 3.2555988021882305), ylim=(None, 3.078326760195413))
  @image_comparison(['streamplot_maxlength_no_broken.png'],
                    remove_text=True, style='mpl20', tol=0.302)
  def test_maxlength_no_broken():
      """Same as the maxlength case, but with ``broken_streamlines=False``."""
      x, y, U, V = swirl_velocity_field()
      fig = plt.figure()
      ax = fig.subplots()
      ax.streamplot(x, y, U, V, maxlength=10., start_points=[[0., 1.5]],
                    linewidth=2, density=2, broken_streamlines=False)
      # The autoscaled upper limits must equal the data's maximum (3).
      x_upper = ax.get_xlim()[-1]
      y_upper = ax.get_ylim()[-1]
      assert x_upper == y_upper == 3
      # Compatibility for old test image: restore the upper limits the
      # reference image was generated with.
      ax.set(xlim=(None, 3.2555988021882305), ylim=(None, 3.078326760195413))
  @image_comparison(['streamplot_direction.png'],
                    remove_text=True, style='mpl20', tol=0.073)
  def test_direction():
      """Integrate one seeded streamline backward only, capped by maxlength."""
      x, y, U, V = swirl_velocity_field()
      seed = [[1.5, 0.]]
      plt.streamplot(x, y, U, V, integration_direction='backward',
                     maxlength=1.5, start_points=seed,
                     linewidth=2, density=2)
  def test_streamplot_limits():
      """Check that data limits follow a translated transform on the artist.

      NOTE(review): despite its name, this test draws barbs, not a
      streamplot.
      """
      ax = plt.axes()
      xs = np.linspace(-5, 10, 20)
      ys = np.linspace(-2, 4, 10)
      y, x = np.meshgrid(ys, xs)
      shift = mtransforms.Affine2D().translate(25, 32)
      shifted = shift + ax.transData
      plt.barbs(x, y, np.sin(x), np.cos(y), transform=shifted)
      # The data limits should be approximately the bounds of the data after
      # the (25, 32) shift: x starts at -5 + 25 = 20 and spans 15, and y starts
      # at -2 + 32 = 30 and spans 6. The match is only approximate because
      # the entire path is taken into account when updating the datalim.
      assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
                                decimal=1)
  def test_streamplot_grid():
      """Check streamplot's validation of the ``x``/``y`` grid arguments.

      Malformed grids must raise ``ValueError`` with a specific message, and
      a well-formed 2-D grid must be accepted.
      """
      u = np.ones((2, 2))
      v = np.zeros((2, 2))

      # 2-D grids: x needs identical rows and y needs identical columns.
      bad_2d_grids = [
          (np.array([[10, 20], [10, 30]]), np.array([[10, 10], [20, 20]]),
           "The rows of 'x' must be equal"),
          (np.array([[10, 20], [10, 20]]), np.array([[10, 10], [20, 30]]),
           "The columns of 'y' must be equal"),
      ]
      for x, y, message in bad_2d_grids:
          with pytest.raises(ValueError, match=message):
              plt.streamplot(x, y, u, v)

      # A consistent 2-D grid is accepted.
      good_x = np.array([[10, 20], [10, 20]])
      good_y = np.array([[10, 10], [20, 20]])
      plt.streamplot(good_x, good_y, u, v)

      # More than two dimensions is rejected.
      flat_x = np.array([0, 10])
      cube_y = np.array([[[0, 10]]])
      with pytest.raises(ValueError,
                         match="'y' can have at maximum 2 dimensions"):
          plt.streamplot(flat_x, cube_y, u, v)

      # 1-D coordinates must be equally spaced and strictly increasing.
      u = np.ones((3, 3))
      v = np.zeros((3, 3))
      bad_1d_grids = [
          (np.array([0, 10, 20]), np.array([0, 10, 30]),
           "'y' values must be equally spaced"),
          (np.array([0, 20, 40]), np.array([0, 20, 10]),
           "'y' must be strictly increasing"),
      ]
      for x, y, message in bad_1d_grids:
          with pytest.raises(ValueError, match=message):
              plt.streamplot(x, y, u, v)
  def test_streamplot_inputs():
      """Smoke-test unusual but valid inputs; the test passes if nothing raises.

      Random arrays come from a seeded Generator instead of the global
      ``np.random`` state, so any failure can be reproduced exactly.
      """
      rng = np.random.default_rng(19680801)
      # All-NaN velocity components (the "fully-masked" case).
      plt.streamplot(np.arange(3), np.arange(3),
                     np.full((3, 3), np.nan), np.full((3, 3), np.nan),
                     color=rng.random((3, 3)))
      # Array-likes: plain Python ranges instead of ndarrays for x and y.
      plt.streamplot(range(3), range(3),
                     rng.random((3, 3)), rng.random((3, 3)))