test_scale.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295
  1. import copy
  2. import matplotlib.pyplot as plt
  3. from matplotlib.scale import (
  4. AsinhScale, AsinhTransform,
  5. LogTransform, InvertedLogTransform,
  6. SymmetricalLogTransform)
  7. import matplotlib.scale as mscale
  8. from matplotlib.ticker import AsinhLocator, LogFormatterSciNotation
  9. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  10. import numpy as np
  11. from numpy.testing import assert_allclose
  12. import io
  13. import pytest
  14. @check_figures_equal()
  15. def test_log_scales(fig_test, fig_ref):
  16. ax_test = fig_test.add_subplot(122, yscale='log', xscale='symlog')
  17. ax_test.axvline(24.1)
  18. ax_test.axhline(24.1)
  19. xlim = ax_test.get_xlim()
  20. ylim = ax_test.get_ylim()
  21. ax_ref = fig_ref.add_subplot(122, yscale='log', xscale='symlog')
  22. ax_ref.set(xlim=xlim, ylim=ylim)
  23. ax_ref.plot([24.1, 24.1], ylim, 'b')
  24. ax_ref.plot(xlim, [24.1, 24.1], 'b')
  25. def test_symlog_mask_nan():
  26. # Use a transform round-trip to verify that the forward and inverse
  27. # transforms work, and that they respect nans and/or masking.
  28. slt = SymmetricalLogTransform(10, 2, 1)
  29. slti = slt.inverted()
  30. x = np.arange(-1.5, 5, 0.5)
  31. out = slti.transform_non_affine(slt.transform_non_affine(x))
  32. assert_allclose(out, x)
  33. assert type(out) is type(x)
  34. x[4] = np.nan
  35. out = slti.transform_non_affine(slt.transform_non_affine(x))
  36. assert_allclose(out, x)
  37. assert type(out) is type(x)
  38. x = np.ma.array(x)
  39. out = slti.transform_non_affine(slt.transform_non_affine(x))
  40. assert_allclose(out, x)
  41. assert type(out) is type(x)
  42. x[3] = np.ma.masked
  43. out = slti.transform_non_affine(slt.transform_non_affine(x))
  44. assert_allclose(out, x)
  45. assert type(out) is type(x)
  46. @image_comparison(['logit_scales.png'], remove_text=True)
  47. def test_logit_scales():
  48. fig, ax = plt.subplots()
  49. # Typical extinction curve for logit
  50. x = np.array([0.001, 0.003, 0.01, 0.03, 0.1, 0.2, 0.3, 0.4, 0.5,
  51. 0.6, 0.7, 0.8, 0.9, 0.97, 0.99, 0.997, 0.999])
  52. y = 1.0 / x
  53. ax.plot(x, y)
  54. ax.set_xscale('logit')
  55. ax.grid(True)
  56. bbox = ax.get_tightbbox(fig.canvas.get_renderer())
  57. assert np.isfinite(bbox.x0)
  58. assert np.isfinite(bbox.y0)
  59. def test_log_scatter():
  60. """Issue #1799"""
  61. fig, ax = plt.subplots(1)
  62. x = np.arange(10)
  63. y = np.arange(10) - 1
  64. ax.scatter(x, y)
  65. buf = io.BytesIO()
  66. fig.savefig(buf, format='pdf')
  67. buf = io.BytesIO()
  68. fig.savefig(buf, format='eps')
  69. buf = io.BytesIO()
  70. fig.savefig(buf, format='svg')
  71. def test_logscale_subs():
  72. fig, ax = plt.subplots()
  73. ax.set_yscale('log', subs=np.array([2, 3, 4]))
  74. # force draw
  75. fig.canvas.draw()
  76. @image_comparison(['logscale_mask.png'], remove_text=True)
  77. def test_logscale_mask():
  78. # Check that zero values are masked correctly on log scales.
  79. # See github issue 8045
  80. xs = np.linspace(0, 50, 1001)
  81. fig, ax = plt.subplots()
  82. ax.plot(np.exp(-xs**2))
  83. fig.canvas.draw()
  84. ax.set(yscale="log")
  85. def test_extra_kwargs_raise():
  86. fig, ax = plt.subplots()
  87. for scale in ['linear', 'log', 'symlog']:
  88. with pytest.raises(TypeError):
  89. ax.set_yscale(scale, foo='mask')
  90. def test_logscale_invert_transform():
  91. fig, ax = plt.subplots()
  92. ax.set_yscale('log')
  93. # get transformation from data to axes
  94. tform = (ax.transAxes + ax.transData.inverted()).inverted()
  95. # direct test of log transform inversion
  96. inverted_transform = LogTransform(base=2).inverted()
  97. assert isinstance(inverted_transform, InvertedLogTransform)
  98. assert inverted_transform.base == 2
  99. def test_logscale_transform_repr():
  100. fig, ax = plt.subplots()
  101. ax.set_yscale('log')
  102. repr(ax.transData)
  103. repr(LogTransform(10, nonpositive='clip'))
  104. @image_comparison(['logscale_nonpos_values.png'],
  105. remove_text=True, tol=0.02, style='mpl20')
  106. def test_logscale_nonpos_values():
  107. np.random.seed(19680801)
  108. xs = np.random.normal(size=int(1e3))
  109. fig, ((ax1, ax2), (ax3, ax4)) = plt.subplots(2, 2)
  110. ax1.hist(xs, range=(-5, 5), bins=10)
  111. ax1.set_yscale('log')
  112. ax2.hist(xs, range=(-5, 5), bins=10)
  113. ax2.set_yscale('log', nonpositive='mask')
  114. xdata = np.arange(0, 10, 0.01)
  115. ydata = np.exp(-xdata)
  116. edata = 0.2*(10-xdata)*np.cos(5*xdata)*np.exp(-xdata)
  117. ax3.fill_between(xdata, ydata - edata, ydata + edata)
  118. ax3.set_yscale('log')
  119. x = np.logspace(-1, 1)
  120. y = x ** 3
  121. yerr = x**2
  122. ax4.errorbar(x, y, yerr=yerr)
  123. ax4.set_yscale('log')
  124. ax4.set_xscale('log')
  125. def test_invalid_log_lims():
  126. # Check that invalid log scale limits are ignored
  127. fig, ax = plt.subplots()
  128. ax.scatter(range(0, 4), range(0, 4))
  129. ax.set_xscale('log')
  130. original_xlim = ax.get_xlim()
  131. with pytest.warns(UserWarning):
  132. ax.set_xlim(left=0)
  133. assert ax.get_xlim() == original_xlim
  134. with pytest.warns(UserWarning):
  135. ax.set_xlim(right=-1)
  136. assert ax.get_xlim() == original_xlim
  137. ax.set_yscale('log')
  138. original_ylim = ax.get_ylim()
  139. with pytest.warns(UserWarning):
  140. ax.set_ylim(bottom=0)
  141. assert ax.get_ylim() == original_ylim
  142. with pytest.warns(UserWarning):
  143. ax.set_ylim(top=-1)
  144. assert ax.get_ylim() == original_ylim
  145. @image_comparison(['function_scales.png'], remove_text=True, style='mpl20')
  146. def test_function_scale():
  147. def inverse(x):
  148. return x**2
  149. def forward(x):
  150. return x**(1/2)
  151. fig, ax = plt.subplots()
  152. x = np.arange(1, 1000)
  153. ax.plot(x, x)
  154. ax.set_xscale('function', functions=(forward, inverse))
  155. ax.set_xlim(1, 1000)
  156. def test_pass_scale():
  157. # test passing a scale object works...
  158. fig, ax = plt.subplots()
  159. scale = mscale.LogScale(axis=None)
  160. ax.set_xscale(scale)
  161. scale = mscale.LogScale(axis=None)
  162. ax.set_yscale(scale)
  163. assert ax.xaxis.get_scale() == 'log'
  164. assert ax.yaxis.get_scale() == 'log'
  165. def test_scale_deepcopy():
  166. sc = mscale.LogScale(axis='x', base=10)
  167. sc2 = copy.deepcopy(sc)
  168. assert str(sc.get_transform()) == str(sc2.get_transform())
  169. assert sc._transform is not sc2._transform
  170. class TestAsinhScale:
  171. def test_transforms(self):
  172. a0 = 17.0
  173. a = np.linspace(-50, 50, 100)
  174. forward = AsinhTransform(a0)
  175. inverse = forward.inverted()
  176. invinv = inverse.inverted()
  177. a_forward = forward.transform_non_affine(a)
  178. a_inverted = inverse.transform_non_affine(a_forward)
  179. assert_allclose(a_inverted, a)
  180. a_invinv = invinv.transform_non_affine(a)
  181. assert_allclose(a_invinv, a0 * np.arcsinh(a / a0))
  182. def test_init(self):
  183. fig, ax = plt.subplots()
  184. s = AsinhScale(axis=None, linear_width=23.0)
  185. assert s.linear_width == 23
  186. assert s._base == 10
  187. assert s._subs == (2, 5)
  188. tx = s.get_transform()
  189. assert isinstance(tx, AsinhTransform)
  190. assert tx.linear_width == s.linear_width
  191. def test_base_init(self):
  192. fig, ax = plt.subplots()
  193. s3 = AsinhScale(axis=None, base=3)
  194. assert s3._base == 3
  195. assert s3._subs == (2,)
  196. s7 = AsinhScale(axis=None, base=7, subs=(2, 4))
  197. assert s7._base == 7
  198. assert s7._subs == (2, 4)
  199. def test_fmtloc(self):
  200. class DummyAxis:
  201. def __init__(self):
  202. self.fields = {}
  203. def set(self, **kwargs):
  204. self.fields.update(**kwargs)
  205. def set_major_formatter(self, f):
  206. self.fields['major_formatter'] = f
  207. ax0 = DummyAxis()
  208. s0 = AsinhScale(axis=ax0, base=0)
  209. s0.set_default_locators_and_formatters(ax0)
  210. assert isinstance(ax0.fields['major_locator'], AsinhLocator)
  211. assert isinstance(ax0.fields['major_formatter'], str)
  212. ax5 = DummyAxis()
  213. s7 = AsinhScale(axis=ax5, base=5)
  214. s7.set_default_locators_and_formatters(ax5)
  215. assert isinstance(ax5.fields['major_locator'], AsinhLocator)
  216. assert isinstance(ax5.fields['major_formatter'],
  217. LogFormatterSciNotation)
  218. def test_bad_scale(self):
  219. fig, ax = plt.subplots()
  220. with pytest.raises(ValueError):
  221. AsinhScale(axis=None, linear_width=0)
  222. with pytest.raises(ValueError):
  223. AsinhScale(axis=None, linear_width=-1)
  224. s0 = AsinhScale(axis=None, )
  225. s1 = AsinhScale(axis=None, linear_width=3.0)