test_units.py 5.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177
  1. from datetime import datetime
  2. import platform
  3. from unittest.mock import MagicMock
  4. import matplotlib.pyplot as plt
  5. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  6. import matplotlib.units as munits
  7. import numpy as np
  8. import pytest
  9. # Basic class that wraps numpy array and has units
  10. class Quantity:
  11. def __init__(self, data, units):
  12. self.magnitude = data
  13. self.units = units
  14. def to(self, new_units):
  15. factors = {('hours', 'seconds'): 3600, ('minutes', 'hours'): 1 / 60,
  16. ('minutes', 'seconds'): 60, ('feet', 'miles'): 1 / 5280.,
  17. ('feet', 'inches'): 12, ('miles', 'inches'): 12 * 5280}
  18. if self.units != new_units:
  19. mult = factors[self.units, new_units]
  20. return Quantity(mult * self.magnitude, new_units)
  21. else:
  22. return Quantity(self.magnitude, self.units)
  23. def __getattr__(self, attr):
  24. return getattr(self.magnitude, attr)
  25. def __getitem__(self, item):
  26. if np.iterable(self.magnitude):
  27. return Quantity(self.magnitude[item], self.units)
  28. else:
  29. return Quantity(self.magnitude, self.units)
  30. def __array__(self):
  31. return np.asarray(self.magnitude)
  32. @pytest.fixture
  33. def quantity_converter():
  34. # Create an instance of the conversion interface and
  35. # mock so we can check methods called
  36. qc = munits.ConversionInterface()
  37. def convert(value, unit, axis):
  38. if hasattr(value, 'units'):
  39. return value.to(unit).magnitude
  40. elif np.iterable(value):
  41. try:
  42. return [v.to(unit).magnitude for v in value]
  43. except AttributeError:
  44. return [Quantity(v, axis.get_units()).to(unit).magnitude
  45. for v in value]
  46. else:
  47. return Quantity(value, axis.get_units()).to(unit).magnitude
  48. def default_units(value, axis):
  49. if hasattr(value, 'units'):
  50. return value.units
  51. elif np.iterable(value):
  52. for v in value:
  53. if hasattr(v, 'units'):
  54. return v.units
  55. return None
  56. qc.convert = MagicMock(side_effect=convert)
  57. qc.axisinfo = MagicMock(side_effect=lambda u, a: munits.AxisInfo(label=u))
  58. qc.default_units = MagicMock(side_effect=default_units)
  59. return qc
  60. # Tests that the conversion machinery works properly for classes that
  61. # work as a facade over numpy arrays (like pint)
  62. @image_comparison(['plot_pint.png'], remove_text=False, style='mpl20',
  63. tol={'aarch64': 0.02}.get(platform.machine(), 0.0))
  64. def test_numpy_facade(quantity_converter):
  65. # use former defaults to match existing baseline image
  66. plt.rcParams['axes.formatter.limits'] = -7, 7
  67. # Register the class
  68. munits.registry[Quantity] = quantity_converter
  69. # Simple test
  70. y = Quantity(np.linspace(0, 30), 'miles')
  71. x = Quantity(np.linspace(0, 5), 'hours')
  72. fig, ax = plt.subplots()
  73. fig.subplots_adjust(left=0.15) # Make space for label
  74. ax.plot(x, y, 'tab:blue')
  75. ax.axhline(Quantity(26400, 'feet'), color='tab:red')
  76. ax.axvline(Quantity(120, 'minutes'), color='tab:green')
  77. ax.yaxis.set_units('inches')
  78. ax.xaxis.set_units('seconds')
  79. assert quantity_converter.convert.called
  80. assert quantity_converter.axisinfo.called
  81. assert quantity_converter.default_units.called
  82. # Tests gh-8908
  83. @image_comparison(['plot_masked_units.png'], remove_text=True, style='mpl20',
  84. tol={'aarch64': 0.02}.get(platform.machine(), 0.0))
  85. def test_plot_masked_units():
  86. data = np.linspace(-5, 5)
  87. data_masked = np.ma.array(data, mask=(data > -2) & (data < 2))
  88. data_masked_units = Quantity(data_masked, 'meters')
  89. fig, ax = plt.subplots()
  90. ax.plot(data_masked_units)
  91. def test_empty_set_limits_with_units(quantity_converter):
  92. # Register the class
  93. munits.registry[Quantity] = quantity_converter
  94. fig, ax = plt.subplots()
  95. ax.set_xlim(Quantity(-1, 'meters'), Quantity(6, 'meters'))
  96. ax.set_ylim(Quantity(-1, 'hours'), Quantity(16, 'hours'))
  97. @image_comparison(['jpl_bar_units.png'],
  98. savefig_kwarg={'dpi': 120}, style='mpl20')
  99. def test_jpl_bar_units():
  100. import matplotlib.testing.jpl_units as units
  101. units.register()
  102. day = units.Duration("ET", 24.0 * 60.0 * 60.0)
  103. x = [0*units.km, 1*units.km, 2*units.km]
  104. w = [1*day, 2*day, 3*day]
  105. b = units.Epoch("ET", dt=datetime(2009, 4, 25))
  106. fig, ax = plt.subplots()
  107. ax.bar(x, w, bottom=b)
  108. ax.set_ylim([b-1*day, b+w[-1]+1*day])
  109. @image_comparison(['jpl_barh_units.png'],
  110. savefig_kwarg={'dpi': 120}, style='mpl20')
  111. def test_jpl_barh_units():
  112. import matplotlib.testing.jpl_units as units
  113. units.register()
  114. day = units.Duration("ET", 24.0 * 60.0 * 60.0)
  115. x = [0*units.km, 1*units.km, 2*units.km]
  116. w = [1*day, 2*day, 3*day]
  117. b = units.Epoch("ET", dt=datetime(2009, 4, 25))
  118. fig, ax = plt.subplots()
  119. ax.barh(x, w, left=b)
  120. ax.set_xlim([b-1*day, b+w[-1]+1*day])
  121. def test_empty_arrays():
  122. # Check that plotting an empty array with a dtype works
  123. plt.scatter(np.array([], dtype='datetime64[ns]'), np.array([]))
  124. def test_scatter_element0_masked():
  125. times = np.arange('2005-02', '2005-03', dtype='datetime64[D]')
  126. y = np.arange(len(times), dtype='float')
  127. y[0] = np.nan
  128. fig, ax = plt.subplots()
  129. ax.scatter(times, y)
  130. fig.canvas.draw()
  131. @check_figures_equal(extensions=["png"])
  132. def test_subclass(fig_test, fig_ref):
  133. class subdate(datetime):
  134. pass
  135. fig_test.subplots().plot(subdate(2000, 1, 1), 0, "o")
  136. fig_ref.subplots().plot(datetime(2000, 1, 1), 0, "o")