test_subplots.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285
  1. import itertools
  2. import numpy as np
  3. import pytest
  4. from matplotlib.axes import Axes, SubplotBase
  5. import matplotlib.pyplot as plt
  6. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  7. def check_shared(axs, x_shared, y_shared):
  8. """
  9. x_shared and y_shared are n x n boolean matrices; entry (i, j) indicates
  10. whether the x (or y) axes of subplots i and j should be shared.
  11. """
  12. for (i1, ax1), (i2, ax2), (i3, (name, shared)) in itertools.product(
  13. enumerate(axs),
  14. enumerate(axs),
  15. enumerate(zip("xy", [x_shared, y_shared]))):
  16. if i2 <= i1:
  17. continue
  18. assert axs[0]._shared_axes[name].joined(ax1, ax2) == shared[i1, i2], \
  19. "axes %i and %i incorrectly %ssharing %s axis" % (
  20. i1, i2, "not " if shared[i1, i2] else "", name)
  21. def check_ticklabel_visible(axs, x_visible, y_visible):
  22. """Check that the x and y ticklabel visibility is as specified."""
  23. for i, (ax, vx, vy) in enumerate(zip(axs, x_visible, y_visible)):
  24. for l in ax.get_xticklabels() + [ax.xaxis.offsetText]:
  25. assert l.get_visible() == vx, \
  26. f"Visibility of x axis #{i} is incorrectly {vx}"
  27. for l in ax.get_yticklabels() + [ax.yaxis.offsetText]:
  28. assert l.get_visible() == vy, \
  29. f"Visibility of y axis #{i} is incorrectly {vy}"
  30. # axis label "visibility" is toggled by label_outer by resetting the
  31. # label to empty, but it can also be empty to start with.
  32. if not vx:
  33. assert ax.get_xlabel() == ""
  34. if not vy:
  35. assert ax.get_ylabel() == ""
  36. def check_tick1_visible(axs, x_visible, y_visible):
  37. """
  38. Check that the x and y tick visibility is as specified.
  39. Note: This only checks the tick1line, i.e. bottom / left ticks.
  40. """
  41. for ax, visible, in zip(axs, x_visible):
  42. for tick in ax.xaxis.get_major_ticks():
  43. assert tick.tick1line.get_visible() == visible
  44. for ax, y_visible, in zip(axs, y_visible):
  45. for tick in ax.yaxis.get_major_ticks():
  46. assert tick.tick1line.get_visible() == visible
  47. def test_shared():
  48. rdim = (4, 4, 2)
  49. share = {
  50. 'all': np.ones(rdim[:2], dtype=bool),
  51. 'none': np.zeros(rdim[:2], dtype=bool),
  52. 'row': np.array([
  53. [False, True, False, False],
  54. [True, False, False, False],
  55. [False, False, False, True],
  56. [False, False, True, False]]),
  57. 'col': np.array([
  58. [False, False, True, False],
  59. [False, False, False, True],
  60. [True, False, False, False],
  61. [False, True, False, False]]),
  62. }
  63. visible = {
  64. 'x': {
  65. 'all': [False, False, True, True],
  66. 'col': [False, False, True, True],
  67. 'row': [True] * 4,
  68. 'none': [True] * 4,
  69. False: [True] * 4,
  70. True: [False, False, True, True],
  71. },
  72. 'y': {
  73. 'all': [True, False, True, False],
  74. 'col': [True] * 4,
  75. 'row': [True, False, True, False],
  76. 'none': [True] * 4,
  77. False: [True] * 4,
  78. True: [True, False, True, False],
  79. },
  80. }
  81. share[False] = share['none']
  82. share[True] = share['all']
  83. # test default
  84. f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2)
  85. axs = [a1, a2, a3, a4]
  86. check_shared(axs, share['none'], share['none'])
  87. plt.close(f)
  88. # test all option combinations
  89. ops = [False, True, 'all', 'none', 'row', 'col', 0, 1]
  90. for xo in ops:
  91. for yo in ops:
  92. f, ((a1, a2), (a3, a4)) = plt.subplots(2, 2, sharex=xo, sharey=yo)
  93. axs = [a1, a2, a3, a4]
  94. check_shared(axs, share[xo], share[yo])
  95. check_ticklabel_visible(axs, visible['x'][xo], visible['y'][yo])
  96. plt.close(f)
  97. @pytest.mark.parametrize('remove_ticks', [True, False])
  98. def test_label_outer(remove_ticks):
  99. f, axs = plt.subplots(2, 2, sharex=True, sharey=True)
  100. for ax in axs.flat:
  101. ax.set(xlabel="foo", ylabel="bar")
  102. ax.label_outer(remove_inner_ticks=remove_ticks)
  103. check_ticklabel_visible(
  104. axs.flat, [False, False, True, True], [True, False, True, False])
  105. if remove_ticks:
  106. check_tick1_visible(
  107. axs.flat, [False, False, True, True], [True, False, True, False])
  108. else:
  109. check_tick1_visible(
  110. axs.flat, [True, True, True, True], [True, True, True, True])
  111. def test_label_outer_span():
  112. fig = plt.figure()
  113. gs = fig.add_gridspec(3, 3)
  114. # +---+---+---+
  115. # | 1 | |
  116. # +---+---+---+
  117. # | | | 3 |
  118. # + 2 +---+---+
  119. # | | 4 | |
  120. # +---+---+---+
  121. a1 = fig.add_subplot(gs[0, 0:2])
  122. a2 = fig.add_subplot(gs[1:3, 0])
  123. a3 = fig.add_subplot(gs[1, 2])
  124. a4 = fig.add_subplot(gs[2, 1])
  125. for ax in fig.axes:
  126. ax.label_outer()
  127. check_ticklabel_visible(
  128. fig.axes, [False, True, False, True], [True, True, False, False])
  129. def test_label_outer_non_gridspec():
  130. ax = plt.axes((0, 0, 1, 1))
  131. ax.label_outer() # Does nothing.
  132. check_ticklabel_visible([ax], [True], [True])
  133. def test_shared_and_moved():
  134. # test if sharey is on, but then tick_left is called that labels don't
  135. # re-appear. Seaborn does this just to be sure yaxis is on left...
  136. f, (a1, a2) = plt.subplots(1, 2, sharey=True)
  137. check_ticklabel_visible([a2], [True], [False])
  138. a2.yaxis.tick_left()
  139. check_ticklabel_visible([a2], [True], [False])
  140. f, (a1, a2) = plt.subplots(2, 1, sharex=True)
  141. check_ticklabel_visible([a1], [False], [True])
  142. a2.xaxis.tick_bottom()
  143. check_ticklabel_visible([a1], [False], [True])
  144. def test_exceptions():
  145. # TODO should this test more options?
  146. with pytest.raises(ValueError):
  147. plt.subplots(2, 2, sharex='blah')
  148. with pytest.raises(ValueError):
  149. plt.subplots(2, 2, sharey='blah')
  150. @image_comparison(['subplots_offset_text'])
  151. def test_subplots_offsettext():
  152. x = np.arange(0, 1e10, 1e9)
  153. y = np.arange(0, 100, 10)+1e4
  154. fig, axs = plt.subplots(2, 2, sharex='col', sharey='all')
  155. axs[0, 0].plot(x, x)
  156. axs[1, 0].plot(x, x)
  157. axs[0, 1].plot(y, x)
  158. axs[1, 1].plot(y, x)
  159. @pytest.mark.parametrize("top", [True, False])
  160. @pytest.mark.parametrize("bottom", [True, False])
  161. @pytest.mark.parametrize("left", [True, False])
  162. @pytest.mark.parametrize("right", [True, False])
  163. def test_subplots_hide_ticklabels(top, bottom, left, right):
  164. # Ideally, we would also test offset-text visibility (and remove
  165. # test_subplots_offsettext), but currently, setting rcParams fails to move
  166. # the offset texts as well.
  167. with plt.rc_context({"xtick.labeltop": top, "xtick.labelbottom": bottom,
  168. "ytick.labelleft": left, "ytick.labelright": right}):
  169. axs = plt.figure().subplots(3, 3, sharex=True, sharey=True)
  170. for (i, j), ax in np.ndenumerate(axs):
  171. xtop = ax.xaxis._major_tick_kw["label2On"]
  172. xbottom = ax.xaxis._major_tick_kw["label1On"]
  173. yleft = ax.yaxis._major_tick_kw["label1On"]
  174. yright = ax.yaxis._major_tick_kw["label2On"]
  175. assert xtop == (top and i == 0)
  176. assert xbottom == (bottom and i == 2)
  177. assert yleft == (left and j == 0)
  178. assert yright == (right and j == 2)
  179. @pytest.mark.parametrize("xlabel_position", ["bottom", "top"])
  180. @pytest.mark.parametrize("ylabel_position", ["left", "right"])
  181. def test_subplots_hide_axislabels(xlabel_position, ylabel_position):
  182. axs = plt.figure().subplots(3, 3, sharex=True, sharey=True)
  183. for (i, j), ax in np.ndenumerate(axs):
  184. ax.set(xlabel="foo", ylabel="bar")
  185. ax.xaxis.set_label_position(xlabel_position)
  186. ax.yaxis.set_label_position(ylabel_position)
  187. ax.label_outer()
  188. assert bool(ax.get_xlabel()) == (
  189. xlabel_position == "bottom" and i == 2
  190. or xlabel_position == "top" and i == 0)
  191. assert bool(ax.get_ylabel()) == (
  192. ylabel_position == "left" and j == 0
  193. or ylabel_position == "right" and j == 2)
  194. def test_get_gridspec():
  195. # ahem, pretty trivial, but...
  196. fig, ax = plt.subplots()
  197. assert ax.get_subplotspec().get_gridspec() == ax.get_gridspec()
  198. def test_dont_mutate_kwargs():
  199. subplot_kw = {'sharex': 'all'}
  200. gridspec_kw = {'width_ratios': [1, 2]}
  201. fig, ax = plt.subplots(1, 2, subplot_kw=subplot_kw,
  202. gridspec_kw=gridspec_kw)
  203. assert subplot_kw == {'sharex': 'all'}
  204. assert gridspec_kw == {'width_ratios': [1, 2]}
  205. @pytest.mark.parametrize("width_ratios", [None, [1, 3, 2]])
  206. @pytest.mark.parametrize("height_ratios", [None, [1, 2]])
  207. @check_figures_equal(extensions=['png'])
  208. def test_width_and_height_ratios(fig_test, fig_ref,
  209. height_ratios, width_ratios):
  210. fig_test.subplots(2, 3, height_ratios=height_ratios,
  211. width_ratios=width_ratios)
  212. fig_ref.subplots(2, 3, gridspec_kw={
  213. 'height_ratios': height_ratios,
  214. 'width_ratios': width_ratios})
  215. @pytest.mark.parametrize("width_ratios", [None, [1, 3, 2]])
  216. @pytest.mark.parametrize("height_ratios", [None, [1, 2]])
  217. @check_figures_equal(extensions=['png'])
  218. def test_width_and_height_ratios_mosaic(fig_test, fig_ref,
  219. height_ratios, width_ratios):
  220. mosaic_spec = [['A', 'B', 'B'], ['A', 'C', 'D']]
  221. fig_test.subplot_mosaic(mosaic_spec, height_ratios=height_ratios,
  222. width_ratios=width_ratios)
  223. fig_ref.subplot_mosaic(mosaic_spec, gridspec_kw={
  224. 'height_ratios': height_ratios,
  225. 'width_ratios': width_ratios})
  226. @pytest.mark.parametrize('method,args', [
  227. ('subplots', (2, 3)),
  228. ('subplot_mosaic', ('abc;def', ))
  229. ]
  230. )
  231. def test_ratio_overlapping_kws(method, args):
  232. with pytest.raises(ValueError, match='height_ratios'):
  233. getattr(plt, method)(*args, height_ratios=[1, 2],
  234. gridspec_kw={'height_ratios': [1, 2]})
  235. with pytest.raises(ValueError, match='width_ratios'):
  236. getattr(plt, method)(*args, width_ratios=[1, 2, 3],
  237. gridspec_kw={'width_ratios': [1, 2, 3]})
  238. def test_old_subplot_compat():
  239. fig = plt.figure()
  240. assert isinstance(fig.add_subplot(), SubplotBase)
  241. assert not isinstance(fig.add_axes(rect=[0, 0, 1, 1]), SubplotBase)
  242. with pytest.raises(TypeError):
  243. Axes(fig, [0, 0, 1, 1], rect=[0, 0, 1, 1])