# test_streamplot.py
import platform
import sys

import numpy as np
from numpy.testing import assert_array_almost_equal

import matplotlib.pyplot as plt
from matplotlib.testing.decorators import image_comparison
import matplotlib.transforms as mtransforms
  8. on_win = (sys.platform == 'win32')
  9. on_mac = (sys.platform == 'darwin')
  10. def velocity_field():
  11. Y, X = np.mgrid[-3:3:100j, -3:3:100j]
  12. U = -1 - X**2 + Y
  13. V = 1 + X - Y**2
  14. return X, Y, U, V
  15. def swirl_velocity_field():
  16. x = np.linspace(-3., 3., 100)
  17. y = np.linspace(-3., 3., 100)
  18. X, Y = np.meshgrid(x, y)
  19. a = 0.1
  20. U = np.cos(a) * (-Y) - np.sin(a) * X
  21. V = np.sin(a) * (-Y) + np.cos(a) * X
  22. return x, y, U, V
  23. @image_comparison(['streamplot_startpoints'], remove_text=True, style='mpl20')
  24. def test_startpoints():
  25. X, Y, U, V = velocity_field()
  26. start_x = np.linspace(X.min(), X.max(), 10)
  27. start_y = np.linspace(Y.min(), Y.max(), 10)
  28. start_points = np.column_stack([start_x, start_y])
  29. plt.streamplot(X, Y, U, V, start_points=start_points)
  30. plt.plot(start_x, start_y, 'ok')
  31. @image_comparison(['streamplot_colormap'],
  32. tol=.04, remove_text=True, style='mpl20')
  33. def test_colormap():
  34. X, Y, U, V = velocity_field()
  35. plt.streamplot(X, Y, U, V, color=U, density=0.6, linewidth=2,
  36. cmap=plt.cm.autumn)
  37. plt.colorbar()
  38. @image_comparison(['streamplot_linewidth'], remove_text=True, style='mpl20',
  39. tol={'aarch64': 0.02}.get(platform.machine(), 0.0))
  40. def test_linewidth():
  41. X, Y, U, V = velocity_field()
  42. speed = np.hypot(U, V)
  43. lw = 5 * speed / speed.max()
  44. # Compatibility for old test image
  45. df = 25 / 30
  46. ax = plt.figure().subplots()
  47. ax.set(xlim=(-3.0, 2.9999999999999947),
  48. ylim=(-3.0000000000000004, 2.9999999999999947))
  49. ax.streamplot(X, Y, U, V, density=[0.5 * df, 1. * df], color='k',
  50. linewidth=lw)
  51. @image_comparison(['streamplot_masks_and_nans'],
  52. remove_text=True, style='mpl20', tol=0.04 if on_win else 0)
  53. def test_masks_and_nans():
  54. X, Y, U, V = velocity_field()
  55. mask = np.zeros(U.shape, dtype=bool)
  56. mask[40:60, 40:60] = 1
  57. U[:20, :20] = np.nan
  58. U = np.ma.array(U, mask=mask)
  59. # Compatibility for old test image
  60. ax = plt.figure().subplots()
  61. ax.set(xlim=(-3.0, 2.9999999999999947),
  62. ylim=(-3.0000000000000004, 2.9999999999999947))
  63. with np.errstate(invalid='ignore'):
  64. ax.streamplot(X, Y, U, V, color=U, cmap=plt.cm.Blues)
  65. @image_comparison(['streamplot_maxlength.png'],
  66. remove_text=True, style='mpl20',
  67. tol=0.002 if on_mac else 0)
  68. def test_maxlength():
  69. x, y, U, V = swirl_velocity_field()
  70. ax = plt.figure().subplots()
  71. ax.streamplot(x, y, U, V, maxlength=10., start_points=[[0., 1.5]],
  72. linewidth=2, density=2)
  73. assert ax.get_xlim()[-1] == ax.get_ylim()[-1] == 3
  74. # Compatibility for old test image
  75. ax.set(xlim=(None, 3.2555988021882305), ylim=(None, 3.078326760195413))
  76. @image_comparison(['streamplot_direction.png'],
  77. remove_text=True, style='mpl20')
  78. def test_direction():
  79. x, y, U, V = swirl_velocity_field()
  80. plt.streamplot(x, y, U, V, integration_direction='backward',
  81. maxlength=1.5, start_points=[[1.5, 0.]],
  82. linewidth=2, density=2)
  83. def test_streamplot_limits():
  84. ax = plt.axes()
  85. x = np.linspace(-5, 10, 20)
  86. y = np.linspace(-2, 4, 10)
  87. y, x = np.meshgrid(y, x)
  88. trans = mtransforms.Affine2D().translate(25, 32) + ax.transData
  89. plt.barbs(x, y, np.sin(x), np.cos(y), transform=trans)
  90. # The calculated bounds are approximately the bounds of the original data,
  91. # this is because the entire path is taken into account when updating the
  92. # datalim.
  93. assert_array_almost_equal(ax.dataLim.bounds, (20, 30, 15, 6),
  94. decimal=1)