_subplots.py 10.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255
  1. import functools
  2. import uuid
  3. from matplotlib import cbook, docstring
  4. import matplotlib.artist as martist
  5. from matplotlib.axes._axes import Axes
  6. from matplotlib.gridspec import GridSpec, SubplotSpec
  7. import matplotlib._layoutbox as layoutbox
  8. class SubplotBase:
  9. """
  10. Base class for subplots, which are :class:`Axes` instances with
  11. additional methods to facilitate generating and manipulating a set
  12. of :class:`Axes` within a figure.
  13. """
  14. def __init__(self, fig, *args, **kwargs):
  15. """
  16. Parameters
  17. ----------
  18. fig : `matplotlib.figure.Figure`
  19. *args : tuple (*nrows*, *ncols*, *index*) or int
  20. The array of subplots in the figure has dimensions ``(nrows,
  21. ncols)``, and *index* is the index of the subplot being created.
  22. *index* starts at 1 in the upper left corner and increases to the
  23. right.
  24. If *nrows*, *ncols*, and *index* are all single digit numbers, then
  25. *args* can be passed as a single 3-digit number (e.g. 234 for
  26. (2, 3, 4)).
  27. """
  28. self.figure = fig
  29. if len(args) == 1:
  30. if isinstance(args[0], SubplotSpec):
  31. self._subplotspec = args[0]
  32. else:
  33. try:
  34. s = str(int(args[0]))
  35. rows, cols, num = map(int, s)
  36. except ValueError:
  37. raise ValueError('Single argument to subplot must be '
  38. 'a 3-digit integer')
  39. self._subplotspec = GridSpec(rows, cols,
  40. figure=self.figure)[num - 1]
  41. # num - 1 for converting from MATLAB to python indexing
  42. elif len(args) == 3:
  43. rows, cols, num = args
  44. rows = int(rows)
  45. cols = int(cols)
  46. if rows <= 0:
  47. raise ValueError(f'Number of rows must be > 0, not {rows}')
  48. if cols <= 0:
  49. raise ValueError(f'Number of columns must be > 0, not {cols}')
  50. if isinstance(num, tuple) and len(num) == 2:
  51. num = [int(n) for n in num]
  52. self._subplotspec = GridSpec(
  53. rows, cols,
  54. figure=self.figure)[(num[0] - 1):num[1]]
  55. else:
  56. if num < 1 or num > rows*cols:
  57. raise ValueError(
  58. f"num must be 1 <= num <= {rows*cols}, not {num}")
  59. self._subplotspec = GridSpec(
  60. rows, cols, figure=self.figure)[int(num) - 1]
  61. # num - 1 for converting from MATLAB to python indexing
  62. else:
  63. raise ValueError(f'Illegal argument(s) to subplot: {args}')
  64. self.update_params()
  65. # _axes_class is set in the subplot_class_factory
  66. self._axes_class.__init__(self, fig, self.figbox, **kwargs)
  67. # add a layout box to this, for both the full axis, and the poss
  68. # of the axis. We need both because the axes may become smaller
  69. # due to parasitic axes and hence no longer fill the subplotspec.
  70. if self._subplotspec._layoutbox is None:
  71. self._layoutbox = None
  72. self._poslayoutbox = None
  73. else:
  74. name = self._subplotspec._layoutbox.name + '.ax'
  75. name = name + layoutbox.seq_id()
  76. self._layoutbox = layoutbox.LayoutBox(
  77. parent=self._subplotspec._layoutbox,
  78. name=name,
  79. artist=self)
  80. self._poslayoutbox = layoutbox.LayoutBox(
  81. parent=self._layoutbox,
  82. name=self._layoutbox.name+'.pos',
  83. pos=True, subplot=True, artist=self)
  84. def __reduce__(self):
  85. # get the first axes class which does not inherit from a subplotbase
  86. axes_class = next(
  87. c for c in type(self).__mro__
  88. if issubclass(c, Axes) and not issubclass(c, SubplotBase))
  89. return (_picklable_subplot_class_constructor,
  90. (axes_class,),
  91. self.__getstate__())
  92. def get_geometry(self):
  93. """Get the subplot geometry, e.g., (2, 2, 3)."""
  94. rows, cols, num1, num2 = self.get_subplotspec().get_geometry()
  95. return rows, cols, num1 + 1 # for compatibility
  96. # COVERAGE NOTE: Never used internally or from examples
  97. def change_geometry(self, numrows, numcols, num):
  98. """Change subplot geometry, e.g., from (1, 1, 1) to (2, 2, 3)."""
  99. self._subplotspec = GridSpec(numrows, numcols,
  100. figure=self.figure)[num - 1]
  101. self.update_params()
  102. self.set_position(self.figbox)
  103. def get_subplotspec(self):
  104. """get the SubplotSpec instance associated with the subplot"""
  105. return self._subplotspec
  106. def set_subplotspec(self, subplotspec):
  107. """set the SubplotSpec instance associated with the subplot"""
  108. self._subplotspec = subplotspec
  109. def get_gridspec(self):
  110. """get the GridSpec instance associated with the subplot"""
  111. return self._subplotspec.get_gridspec()
  112. def update_params(self):
  113. """update the subplot position from fig.subplotpars"""
  114. self.figbox, _, _, self.numRows, self.numCols = \
  115. self.get_subplotspec().get_position(self.figure,
  116. return_all=True)
  117. @cbook.deprecated("3.2", alternative="ax.get_subplotspec().rowspan.start")
  118. @property
  119. def rowNum(self):
  120. return self.get_subplotspec().rowspan.start
  121. @cbook.deprecated("3.2", alternative="ax.get_subplotspec().colspan.start")
  122. @property
  123. def colNum(self):
  124. return self.get_subplotspec().colspan.start
  125. def is_first_row(self):
  126. return self.get_subplotspec().rowspan.start == 0
  127. def is_last_row(self):
  128. return self.get_subplotspec().rowspan.stop == self.get_gridspec().nrows
  129. def is_first_col(self):
  130. return self.get_subplotspec().colspan.start == 0
  131. def is_last_col(self):
  132. return self.get_subplotspec().colspan.stop == self.get_gridspec().ncols
  133. def label_outer(self):
  134. """
  135. Only show "outer" labels and tick labels.
  136. x-labels are only kept for subplots on the last row; y-labels only for
  137. subplots on the first column.
  138. """
  139. lastrow = self.is_last_row()
  140. firstcol = self.is_first_col()
  141. if not lastrow:
  142. for label in self.get_xticklabels(which="both"):
  143. label.set_visible(False)
  144. self.get_xaxis().get_offset_text().set_visible(False)
  145. self.set_xlabel("")
  146. if not firstcol:
  147. for label in self.get_yticklabels(which="both"):
  148. label.set_visible(False)
  149. self.get_yaxis().get_offset_text().set_visible(False)
  150. self.set_ylabel("")
  151. def _make_twin_axes(self, *args, **kwargs):
  152. """Make a twinx axes of self. This is used for twinx and twiny."""
  153. if 'sharex' in kwargs and 'sharey' in kwargs:
  154. # The following line is added in v2.2 to avoid breaking Seaborn,
  155. # which currently uses this internal API.
  156. if kwargs["sharex"] is not self and kwargs["sharey"] is not self:
  157. raise ValueError("Twinned Axes may share only one axis")
  158. # The dance here with label is to force add_subplot() to create a new
  159. # Axes (by passing in a label never seen before). Note that this does
  160. # not affect plot reactivation by subplot() as twin axes can never be
  161. # reactivated by subplot().
  162. sentinel = str(uuid.uuid4())
  163. real_label = kwargs.pop("label", sentinel)
  164. twin = self.figure.add_subplot(
  165. self.get_subplotspec(), *args, label=sentinel, **kwargs)
  166. if real_label is not sentinel:
  167. twin.set_label(real_label)
  168. self.set_adjustable('datalim')
  169. twin.set_adjustable('datalim')
  170. if self._layoutbox is not None and twin._layoutbox is not None:
  171. # make the layout boxes be explicitly the same
  172. twin._layoutbox.constrain_same(self._layoutbox)
  173. twin._poslayoutbox.constrain_same(self._poslayoutbox)
  174. self._twinned_axes.join(self, twin)
  175. return twin
  176. # this here to support cartopy which was using a private part of the
  177. # API to register their Axes subclasses.
  178. # In 3.1 this should be changed to a dict subclass that warns on use
  179. # In 3.3 to a dict subclass that raises a useful exception on use
  180. # In 3.4 should be removed
  181. # The slow timeline is to give cartopy enough time to get several
  182. # release out before we break them.
  183. _subplot_classes = {}
  184. @functools.lru_cache(None)
  185. def subplot_class_factory(axes_class=None):
  186. """
  187. This makes a new class that inherits from `.SubplotBase` and the
  188. given axes_class (which is assumed to be a subclass of `.axes.Axes`).
  189. This is perhaps a little bit roundabout to make a new class on
  190. the fly like this, but it means that a new Subplot class does
  191. not have to be created for every type of Axes.
  192. """
  193. if axes_class is None:
  194. axes_class = Axes
  195. try:
  196. # Avoid creating two different instances of GeoAxesSubplot...
  197. # Only a temporary backcompat fix. This should be removed in
  198. # 3.4
  199. return next(cls for cls in SubplotBase.__subclasses__()
  200. if cls.__bases__ == (SubplotBase, axes_class))
  201. except StopIteration:
  202. return type("%sSubplot" % axes_class.__name__,
  203. (SubplotBase, axes_class),
  204. {'_axes_class': axes_class})
  205. # This is provided for backward compatibility
  206. Subplot = subplot_class_factory()
  207. def _picklable_subplot_class_constructor(axes_class):
  208. """
  209. This stub class exists to return the appropriate subplot class when called
  210. with an axes class. This is purely to allow pickling of Axes and Subplots.
  211. """
  212. subplot_class = subplot_class_factory(axes_class)
  213. return subplot_class.__new__(subplot_class)
# Register the `martist.kwdoc` text for Axes under the "Axes" key of the
# shared docstring interpolation table, then interpolate the docstring of
# `Axes.__init__`.  NOTE(review): the order presumably matters because that
# docstring refers to the "Axes" entry -- confirm before reordering.
docstring.interpd.update(Axes=martist.kwdoc(Axes))
docstring.dedent_interpd(Axes.__init__)
# The same kwdoc text, also exposed under the "Subplot" key.
docstring.interpd.update(Subplot=martist.kwdoc(Axes))