axislines.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648
  1. """
  2. Axislines includes modified implementation of the Axes class. The
  3. biggest difference is that the artists responsible for drawing the axis spine,
  4. ticks, ticklabels and axis labels are separated out from mpl's Axis
  5. class. Originally, this change was motivated to support curvilinear
  6. grid. Here are a few reasons that I came up with a new axes class:
  7. * "top" and "bottom" x-axis (or "left" and "right" y-axis) can have
  8. different ticks (tick locations and labels). This is not possible
  9. with the current mpl, although some twin axes trick can help.
  10. * Curvilinear grid.
  11. * angled ticks.
  12. In the new axes class, xaxis and yaxis is set to not visible by
  13. default, and new set of artist (AxisArtist) are defined to draw axis
  14. line, ticks, ticklabels and axis label. Axes.axis attribute serves as
  15. a dictionary of these artists, i.e., ax.axis["left"] is a AxisArtist
  16. instance responsible to draw left y-axis. The default Axes.axis contains
  17. "bottom", "left", "top" and "right".
  18. AxisArtist can be considered as a container artist and
  19. has following children artists which will draw ticks, labels, etc.
  20. * line
  21. * major_ticks, major_ticklabels
  22. * minor_ticks, minor_ticklabels
  23. * offsetText
  24. * label
  25. Note that these are separate artists from Axis class of the
  26. original mpl, thus most of tick-related command in the original mpl
  27. won't work, although some effort has made to work with. For example,
  28. color and markerwidth of the ax.axis["bottom"].major_ticks will follow
  29. those of Axes.xaxis unless explicitly specified.
  30. In addition to AxisArtist, the Axes will have *gridlines* attribute,
  31. which obviously draws grid lines. The gridlines needs to be separated
  32. from the axis as some gridlines can never pass any axis.
  33. """
  34. import numpy as np
  35. from matplotlib import cbook, rcParams
  36. import matplotlib.artist as martist
  37. import matplotlib.axes as maxes
  38. from matplotlib.path import Path
  39. from mpl_toolkits.axes_grid1 import mpl_axes
  40. from .axisline_style import AxislineStyle
  41. from .axis_artist import AxisArtist, GridlinesCollection
  42. class AxisArtistHelper:
  43. """
  44. AxisArtistHelper should define
  45. following method with given APIs. Note that the first axes argument
  46. will be axes attribute of the caller artist.::
  47. # LINE (spinal line?)
  48. def get_line(self, axes):
  49. # path : Path
  50. return path
  51. def get_line_transform(self, axes):
  52. # ...
  53. # trans : transform
  54. return trans
  55. # LABEL
  56. def get_label_pos(self, axes):
  57. # x, y : position
  58. return (x, y), trans
  59. def get_label_offset_transform(self,
  60. axes,
  61. pad_points, fontprops, renderer,
  62. bboxes,
  63. ):
  64. # va : vertical alignment
  65. # ha : horizontal alignment
  66. # a : angle
  67. return trans, va, ha, a
  68. # TICK
  69. def get_tick_transform(self, axes):
  70. return trans
  71. def get_tick_iterators(self, axes):
  72. # iter : iterable object that yields (c, angle, l) where
  73. # c, angle, l is position, tick angle, and label
  74. return iter_major, iter_minor
  75. """
  76. class _Base:
  77. """Base class for axis helper."""
  78. def __init__(self):
  79. self.delta1, self.delta2 = 0.00001, 0.00001
  80. def update_lim(self, axes):
  81. pass
  82. class Fixed(_Base):
  83. """Helper class for a fixed (in the axes coordinate) axis."""
  84. _default_passthru_pt = dict(left=(0, 0),
  85. right=(1, 0),
  86. bottom=(0, 0),
  87. top=(0, 1))
  88. def __init__(self, loc, nth_coord=None):
  89. """
  90. nth_coord = along which coordinate value varies
  91. in 2d, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  92. """
  93. cbook._check_in_list(["left", "right", "bottom", "top"], loc=loc)
  94. self._loc = loc
  95. if nth_coord is None:
  96. if loc in ["left", "right"]:
  97. nth_coord = 1
  98. elif loc in ["bottom", "top"]:
  99. nth_coord = 0
  100. self.nth_coord = nth_coord
  101. super().__init__()
  102. self.passthru_pt = self._default_passthru_pt[loc]
  103. _verts = np.array([[0., 0.],
  104. [1., 1.]])
  105. fixed_coord = 1 - nth_coord
  106. _verts[:, fixed_coord] = self.passthru_pt[fixed_coord]
  107. # axis line in transAxes
  108. self._path = Path(_verts)
  109. def get_nth_coord(self):
  110. return self.nth_coord
  111. # LINE
  112. def get_line(self, axes):
  113. return self._path
  114. def get_line_transform(self, axes):
  115. return axes.transAxes
  116. # LABEL
  117. def get_axislabel_transform(self, axes):
  118. return axes.transAxes
  119. def get_axislabel_pos_angle(self, axes):
  120. """
  121. label reference position in transAxes.
  122. get_label_transform() returns a transform of (transAxes+offset)
  123. """
  124. return dict(left=((0., 0.5), 90), # (position, angle_tangent)
  125. right=((1., 0.5), 90),
  126. bottom=((0.5, 0.), 0),
  127. top=((0.5, 1.), 0))[self._loc]
  128. # TICK
  129. def get_tick_transform(self, axes):
  130. return [axes.get_xaxis_transform(),
  131. axes.get_yaxis_transform()][self.nth_coord]
  132. class Floating(_Base):
  133. def __init__(self, nth_coord, value):
  134. self.nth_coord = nth_coord
  135. self._value = value
  136. super().__init__()
  137. def get_nth_coord(self):
  138. return self.nth_coord
  139. def get_line(self, axes):
  140. raise RuntimeError(
  141. "get_line method should be defined by the derived class")
  142. class AxisArtistHelperRectlinear:
  143. class Fixed(AxisArtistHelper.Fixed):
  144. def __init__(self, axes, loc, nth_coord=None):
  145. """
  146. nth_coord = along which coordinate value varies
  147. in 2d, nth_coord = 0 -> x axis, nth_coord = 1 -> y axis
  148. """
  149. super().__init__(loc, nth_coord)
  150. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  151. # TICK
  152. def get_tick_iterators(self, axes):
  153. """tick_loc, tick_angle, tick_label"""
  154. loc = self._loc
  155. if loc in ["bottom", "top"]:
  156. angle_normal, angle_tangent = 90, 0
  157. else:
  158. angle_normal, angle_tangent = 0, 90
  159. major = self.axis.major
  160. majorLocs = major.locator()
  161. majorLabels = major.formatter.format_ticks(majorLocs)
  162. minor = self.axis.minor
  163. minorLocs = minor.locator()
  164. minorLabels = minor.formatter.format_ticks(minorLocs)
  165. tick_to_axes = self.get_tick_transform(axes) - axes.transAxes
  166. def _f(locs, labels):
  167. for x, l in zip(locs, labels):
  168. c = list(self.passthru_pt) # copy
  169. c[self.nth_coord] = x
  170. # check if the tick point is inside axes
  171. c2 = tick_to_axes.transform(c)
  172. if (0 - self.delta1
  173. <= c2[self.nth_coord]
  174. <= 1 + self.delta2):
  175. yield c, angle_normal, angle_tangent, l
  176. return _f(majorLocs, majorLabels), _f(minorLocs, minorLabels)
  177. class Floating(AxisArtistHelper.Floating):
  178. def __init__(self, axes, nth_coord,
  179. passingthrough_point, axis_direction="bottom"):
  180. super().__init__(nth_coord, passingthrough_point)
  181. self._axis_direction = axis_direction
  182. self.axis = [axes.xaxis, axes.yaxis][self.nth_coord]
  183. def get_line(self, axes):
  184. _verts = np.array([[0., 0.],
  185. [1., 1.]])
  186. fixed_coord = 1 - self.nth_coord
  187. data_to_axes = axes.transData - axes.transAxes
  188. p = data_to_axes.transform([self._value, self._value])
  189. _verts[:, fixed_coord] = p[fixed_coord]
  190. return Path(_verts)
  191. def get_line_transform(self, axes):
  192. return axes.transAxes
  193. def get_axislabel_transform(self, axes):
  194. return axes.transAxes
  195. def get_axislabel_pos_angle(self, axes):
  196. """
  197. label reference position in transAxes.
  198. get_label_transform() returns a transform of (transAxes+offset)
  199. """
  200. angle = [0, 90][self.nth_coord]
  201. _verts = [0.5, 0.5]
  202. fixed_coord = 1 - self.nth_coord
  203. data_to_axes = axes.transData - axes.transAxes
  204. p = data_to_axes.transform([self._value, self._value])
  205. _verts[fixed_coord] = p[fixed_coord]
  206. if 0 <= _verts[fixed_coord] <= 1:
  207. return _verts, angle
  208. else:
  209. return None, None
  210. def get_tick_transform(self, axes):
  211. return axes.transData
  212. def get_tick_iterators(self, axes):
  213. """tick_loc, tick_angle, tick_label"""
  214. if self.nth_coord == 0:
  215. angle_normal, angle_tangent = 90, 0
  216. else:
  217. angle_normal, angle_tangent = 0, 90
  218. major = self.axis.major
  219. majorLocs = major.locator()
  220. majorLabels = major.formatter.format_ticks(majorLocs)
  221. minor = self.axis.minor
  222. minorLocs = minor.locator()
  223. minorLabels = minor.formatter.format_ticks(minorLocs)
  224. data_to_axes = axes.transData - axes.transAxes
  225. def _f(locs, labels):
  226. for x, l in zip(locs, labels):
  227. c = [self._value, self._value]
  228. c[self.nth_coord] = x
  229. c1, c2 = data_to_axes.transform(c)
  230. if (0 <= c1 <= 1 and 0 <= c2 <= 1
  231. and 0 - self.delta1
  232. <= [c1, c2][self.nth_coord]
  233. <= 1 + self.delta2):
  234. yield c, angle_normal, angle_tangent, l
  235. return _f(majorLocs, majorLabels), _f(minorLocs, minorLabels)
  236. class GridHelperBase:
  237. def __init__(self):
  238. self._force_update = True
  239. self._old_limits = None
  240. super().__init__()
  241. def update_lim(self, axes):
  242. x1, x2 = axes.get_xlim()
  243. y1, y2 = axes.get_ylim()
  244. if self._force_update or self._old_limits != (x1, x2, y1, y2):
  245. self._update(x1, x2, y1, y2)
  246. self._force_update = False
  247. self._old_limits = (x1, x2, y1, y2)
  248. def _update(self, x1, x2, y1, y2):
  249. pass
  250. def invalidate(self):
  251. self._force_update = True
  252. def valid(self):
  253. return not self._force_update
  254. def get_gridlines(self, which, axis):
  255. """
  256. Return list of grid lines as a list of paths (list of points).
  257. *which* : "major" or "minor"
  258. *axis* : "both", "x" or "y"
  259. """
  260. return []
  261. def new_gridlines(self, ax):
  262. """
  263. Create and return a new GridlineCollection instance.
  264. *which* : "major" or "minor"
  265. *axis* : "both", "x" or "y"
  266. """
  267. gridlines = GridlinesCollection(None, transform=ax.transData,
  268. colors=rcParams['grid.color'],
  269. linestyles=rcParams['grid.linestyle'],
  270. linewidths=rcParams['grid.linewidth'])
  271. ax._set_artist_props(gridlines)
  272. gridlines.set_grid_helper(self)
  273. ax.axes._set_artist_props(gridlines)
  274. # gridlines.set_clip_path(self.axes.patch)
  275. # set_clip_path need to be deferred after Axes.cla is completed.
  276. # It is done inside the cla.
  277. return gridlines
  278. class GridHelperRectlinear(GridHelperBase):
  279. def __init__(self, axes):
  280. super().__init__()
  281. self.axes = axes
  282. def new_fixed_axis(self, loc,
  283. nth_coord=None,
  284. axis_direction=None,
  285. offset=None,
  286. axes=None,
  287. ):
  288. if axes is None:
  289. cbook._warn_external(
  290. "'new_fixed_axis' explicitly requires the axes keyword.")
  291. axes = self.axes
  292. _helper = AxisArtistHelperRectlinear.Fixed(axes, loc, nth_coord)
  293. if axis_direction is None:
  294. axis_direction = loc
  295. axisline = AxisArtist(axes, _helper, offset=offset,
  296. axis_direction=axis_direction,
  297. )
  298. return axisline
  299. def new_floating_axis(self, nth_coord, value,
  300. axis_direction="bottom",
  301. axes=None,
  302. ):
  303. if axes is None:
  304. cbook._warn_external(
  305. "'new_floating_axis' explicitly requires the axes keyword.")
  306. axes = self.axes
  307. _helper = AxisArtistHelperRectlinear.Floating(
  308. axes, nth_coord, value, axis_direction)
  309. axisline = AxisArtist(axes, _helper)
  310. axisline.line.set_clip_on(True)
  311. axisline.line.set_clip_box(axisline.axes.bbox)
  312. return axisline
  313. def get_gridlines(self, which="major", axis="both"):
  314. """
  315. return list of gridline coordinates in data coordinates.
  316. *which* : "major" or "minor"
  317. *axis* : "both", "x" or "y"
  318. """
  319. gridlines = []
  320. if axis in ["both", "x"]:
  321. locs = []
  322. y1, y2 = self.axes.get_ylim()
  323. if which in ["both", "major"]:
  324. locs.extend(self.axes.xaxis.major.locator())
  325. if which in ["both", "minor"]:
  326. locs.extend(self.axes.xaxis.minor.locator())
  327. for x in locs:
  328. gridlines.append([[x, x], [y1, y2]])
  329. if axis in ["both", "y"]:
  330. x1, x2 = self.axes.get_xlim()
  331. locs = []
  332. if self.axes.yaxis._gridOnMajor:
  333. locs.extend(self.axes.yaxis.major.locator())
  334. if self.axes.yaxis._gridOnMinor:
  335. locs.extend(self.axes.yaxis.minor.locator())
  336. for y in locs:
  337. gridlines.append([[x1, x2], [y, y]])
  338. return gridlines
  339. @cbook.deprecated("3.1")
  340. class SimpleChainedObjects:
  341. def __init__(self, objects):
  342. self._objects = objects
  343. def __getattr__(self, k):
  344. _a = SimpleChainedObjects([getattr(a, k) for a in self._objects])
  345. return _a
  346. def __call__(self, *args, **kwargs):
  347. for m in self._objects:
  348. m(*args, **kwargs)
  349. class Axes(maxes.Axes):
  350. @cbook.deprecated("3.1")
  351. class AxisDict(dict):
  352. def __init__(self, axes):
  353. self.axes = axes
  354. super().__init__()
  355. def __getitem__(self, k):
  356. if isinstance(k, tuple):
  357. return SimpleChainedObjects(
  358. [dict.__getitem__(self, k1) for k1 in k])
  359. elif isinstance(k, slice):
  360. if k == slice(None):
  361. return SimpleChainedObjects(list(self.values()))
  362. else:
  363. raise ValueError("Unsupported slice")
  364. else:
  365. return dict.__getitem__(self, k)
  366. def __call__(self, *args, **kwargs):
  367. return maxes.Axes.axis(self.axes, *args, **kwargs)
  368. def __init__(self, *args, grid_helper=None, **kwargs):
  369. self._axisline_on = True
  370. self._grid_helper = (grid_helper if grid_helper
  371. else GridHelperRectlinear(self))
  372. super().__init__(*args, **kwargs)
  373. self.toggle_axisline(True)
  374. def toggle_axisline(self, b=None):
  375. if b is None:
  376. b = not self._axisline_on
  377. if b:
  378. self._axisline_on = True
  379. for s in self.spines.values():
  380. s.set_visible(False)
  381. self.xaxis.set_visible(False)
  382. self.yaxis.set_visible(False)
  383. else:
  384. self._axisline_on = False
  385. for s in self.spines.values():
  386. s.set_visible(True)
  387. self.xaxis.set_visible(True)
  388. self.yaxis.set_visible(True)
  389. def _init_axis_artists(self, axes=None):
  390. if axes is None:
  391. axes = self
  392. self._axislines = mpl_axes.Axes.AxisDict(self)
  393. new_fixed_axis = self.get_grid_helper().new_fixed_axis
  394. for loc in ["bottom", "top", "left", "right"]:
  395. self._axislines[loc] = new_fixed_axis(loc=loc, axes=axes,
  396. axis_direction=loc)
  397. for axisline in [self._axislines["top"], self._axislines["right"]]:
  398. axisline.label.set_visible(False)
  399. axisline.major_ticklabels.set_visible(False)
  400. axisline.minor_ticklabels.set_visible(False)
  401. @property
  402. def axis(self):
  403. return self._axislines
  404. def new_gridlines(self, grid_helper=None):
  405. """
  406. Create and return a new GridlineCollection instance.
  407. *which* : "major" or "minor"
  408. *axis* : "both", "x" or "y"
  409. """
  410. if grid_helper is None:
  411. grid_helper = self.get_grid_helper()
  412. gridlines = grid_helper.new_gridlines(self)
  413. return gridlines
  414. def _init_gridlines(self, grid_helper=None):
  415. # It is done inside the cla.
  416. self.gridlines = self.new_gridlines(grid_helper)
  417. def cla(self):
  418. # gridlines need to b created before cla() since cla calls grid()
  419. self._init_gridlines()
  420. super().cla()
  421. # the clip_path should be set after Axes.cla() since that's
  422. # when a patch is created.
  423. self.gridlines.set_clip_path(self.axes.patch)
  424. self._init_axis_artists()
  425. def get_grid_helper(self):
  426. return self._grid_helper
  427. def grid(self, b=None, which='major', axis="both", **kwargs):
  428. """
  429. Toggle the gridlines, and optionally set the properties of the lines.
  430. """
  431. # their are some discrepancy between the behavior of grid in
  432. # axes_grid and the original mpl's grid, because axes_grid
  433. # explicitly set the visibility of the gridlines.
  434. super().grid(b, which=which, axis=axis, **kwargs)
  435. if not self._axisline_on:
  436. return
  437. if b is None:
  438. b = (self.axes.xaxis._gridOnMinor
  439. or self.axes.xaxis._gridOnMajor
  440. or self.axes.yaxis._gridOnMinor
  441. or self.axes.yaxis._gridOnMajor)
  442. self.gridlines.set_which(which)
  443. self.gridlines.set_axis(axis)
  444. self.gridlines.set_visible(b)
  445. if len(kwargs):
  446. martist.setp(self.gridlines, **kwargs)
  447. def get_children(self):
  448. if self._axisline_on:
  449. children = [*self._axislines.values(), self.gridlines]
  450. else:
  451. children = []
  452. children.extend(super().get_children())
  453. return children
  454. def invalidate_grid_helper(self):
  455. self._grid_helper.invalidate()
  456. def new_fixed_axis(self, loc, offset=None):
  457. gh = self.get_grid_helper()
  458. axis = gh.new_fixed_axis(loc,
  459. nth_coord=None,
  460. axis_direction=None,
  461. offset=offset,
  462. axes=self,
  463. )
  464. return axis
  465. def new_floating_axis(self, nth_coord, value, axis_direction="bottom"):
  466. gh = self.get_grid_helper()
  467. axis = gh.new_floating_axis(nth_coord, value,
  468. axis_direction=axis_direction,
  469. axes=self)
  470. return axis
  471. Subplot = maxes.subplot_class_factory(Axes)
  472. class AxesZero(Axes):
  473. def _init_axis_artists(self):
  474. super()._init_axis_artists()
  475. new_floating_axis = self._grid_helper.new_floating_axis
  476. xaxis_zero = new_floating_axis(nth_coord=0,
  477. value=0.,
  478. axis_direction="bottom",
  479. axes=self)
  480. xaxis_zero.line.set_clip_path(self.patch)
  481. xaxis_zero.set_visible(False)
  482. self._axislines["xzero"] = xaxis_zero
  483. yaxis_zero = new_floating_axis(nth_coord=1,
  484. value=0.,
  485. axis_direction="left",
  486. axes=self)
  487. yaxis_zero.line.set_clip_path(self.patch)
  488. yaxis_zero.set_visible(False)
  489. self._axislines["yzero"] = yaxis_zero
  490. SubplotZero = maxes.subplot_class_factory(AxesZero)