mpl_renderer.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532
  1. from __future__ import annotations
  2. from collections.abc import Sequence
  3. import io
  4. from typing import TYPE_CHECKING, Any, cast
  5. import matplotlib.collections as mcollections
  6. import matplotlib.pyplot as plt
  7. import numpy as np
  8. from contourpy import FillType, LineType
  9. from contourpy.convert import convert_filled, convert_lines
  10. from contourpy.enum_util import as_fill_type, as_line_type
  11. from contourpy.util.mpl_util import filled_to_mpl_paths, lines_to_mpl_paths
  12. from contourpy.util.renderer import Renderer
  13. if TYPE_CHECKING:
  14. from matplotlib.axes import Axes
  15. from matplotlib.figure import Figure
  16. from numpy.typing import ArrayLike
  17. import contourpy._contourpy as cpy
  18. class MplRenderer(Renderer):
  19. """Utility renderer using Matplotlib to render a grid of plots over the same (x, y) range.
  20. Args:
  21. nrows (int, optional): Number of rows of plots, default ``1``.
  22. ncols (int, optional): Number of columns of plots, default ``1``.
  23. figsize (tuple(float, float), optional): Figure size in inches, default ``(9, 9)``.
  24. show_frame (bool, optional): Whether to show frame and axes ticks, default ``True``.
  25. backend (str, optional): Matplotlib backend to use or ``None`` for default backend.
  26. Default ``None``.
  27. gridspec_kw (dict, optional): Gridspec keyword arguments to pass to ``plt.subplots``,
  28. default None.
  29. """
  30. _axes: Sequence[Axes]
  31. _fig: Figure
  32. _want_tight: bool
  33. def __init__(
  34. self,
  35. nrows: int = 1,
  36. ncols: int = 1,
  37. figsize: tuple[float, float] = (9, 9),
  38. show_frame: bool = True,
  39. backend: str | None = None,
  40. gridspec_kw: dict[str, Any] | None = None,
  41. ) -> None:
  42. if backend is not None:
  43. import matplotlib
  44. matplotlib.use(backend)
  45. kwargs: dict[str, Any] = dict(figsize=figsize, squeeze=False, sharex=True, sharey=True)
  46. if gridspec_kw is not None:
  47. kwargs["gridspec_kw"] = gridspec_kw
  48. else:
  49. kwargs["subplot_kw"] = dict(aspect="equal")
  50. self._fig, axes = plt.subplots(nrows, ncols, **kwargs)
  51. self._axes = axes.flatten()
  52. if not show_frame:
  53. for ax in self._axes:
  54. ax.axis("off")
  55. self._want_tight = True
  56. def __del__(self) -> None:
  57. if hasattr(self, "_fig"):
  58. plt.close(self._fig)
  59. def _autoscale(self) -> None:
  60. # Using axes._need_autoscale attribute if need to autoscale before rendering after adding
  61. # lines/filled. Only want to autoscale once per axes regardless of how many lines/filled
  62. # added.
  63. for ax in self._axes:
  64. if getattr(ax, "_need_autoscale", False):
  65. ax.autoscale_view(tight=True)
  66. ax._need_autoscale = False # type: ignore[attr-defined]
  67. if self._want_tight and len(self._axes) > 1:
  68. self._fig.tight_layout()
  69. def _get_ax(self, ax: Axes | int) -> Axes:
  70. if isinstance(ax, int):
  71. ax = self._axes[ax]
  72. return ax
  73. def filled(
  74. self,
  75. filled: cpy.FillReturn,
  76. fill_type: FillType | str,
  77. ax: Axes | int = 0,
  78. color: str = "C0",
  79. alpha: float = 0.7,
  80. ) -> None:
  81. """Plot filled contours on a single Axes.
  82. Args:
  83. filled (sequence of arrays): Filled contour data as returned by
  84. :func:`~contourpy.ContourGenerator.filled`.
  85. fill_type (FillType or str): Type of ``filled`` data as returned by
  86. :attr:`~contourpy.ContourGenerator.fill_type`, or string equivalent
  87. ax (int or Maplotlib Axes, optional): Which axes to plot on, default ``0``.
  88. color (str, optional): Color to plot with. May be a string color or the letter ``"C"``
  89. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  90. ``tab10`` colormap. Default ``"C0"``.
  91. alpha (float, optional): Opacity to plot with, default ``0.7``.
  92. """
  93. fill_type = as_fill_type(fill_type)
  94. ax = self._get_ax(ax)
  95. paths = filled_to_mpl_paths(filled, fill_type)
  96. collection = mcollections.PathCollection(
  97. paths, facecolors=color, edgecolors="none", lw=0, alpha=alpha)
  98. ax.add_collection(collection)
  99. ax._need_autoscale = True # type: ignore[attr-defined]
  100. def grid(
  101. self,
  102. x: ArrayLike,
  103. y: ArrayLike,
  104. ax: Axes | int = 0,
  105. color: str = "black",
  106. alpha: float = 0.1,
  107. point_color: str | None = None,
  108. quad_as_tri_alpha: float = 0,
  109. ) -> None:
  110. """Plot quad grid lines on a single Axes.
  111. Args:
  112. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  113. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  114. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  115. color (str, optional): Color to plot grid lines, default ``"black"``.
  116. alpha (float, optional): Opacity to plot lines with, default ``0.1``.
  117. point_color (str, optional): Color to plot grid points or ``None`` if grid points
  118. should not be plotted, default ``None``.
  119. quad_as_tri_alpha (float, optional): Opacity to plot ``quad_as_tri`` grid, default 0.
  120. Colors may be a string color or the letter ``"C"`` followed by an integer in the range
  121. ``"C0"`` to ``"C9"`` to use a color from the ``tab10`` colormap.
  122. Warning:
  123. ``quad_as_tri_alpha > 0`` plots all quads as though they are unmasked.
  124. """
  125. ax = self._get_ax(ax)
  126. x, y = self._grid_as_2d(x, y)
  127. kwargs: dict[str, Any] = dict(color=color, alpha=alpha)
  128. ax.plot(x, y, x.T, y.T, **kwargs)
  129. if quad_as_tri_alpha > 0:
  130. # Assumes no quad mask.
  131. xmid = 0.25*(x[:-1, :-1] + x[1:, :-1] + x[:-1, 1:] + x[1:, 1:])
  132. ymid = 0.25*(y[:-1, :-1] + y[1:, :-1] + y[:-1, 1:] + y[1:, 1:])
  133. kwargs["alpha"] = quad_as_tri_alpha
  134. ax.plot(
  135. np.stack((x[:-1, :-1], xmid, x[1:, 1:])).reshape((3, -1)),
  136. np.stack((y[:-1, :-1], ymid, y[1:, 1:])).reshape((3, -1)),
  137. np.stack((x[1:, :-1], xmid, x[:-1, 1:])).reshape((3, -1)),
  138. np.stack((y[1:, :-1], ymid, y[:-1, 1:])).reshape((3, -1)),
  139. **kwargs)
  140. if point_color is not None:
  141. ax.plot(x, y, color=point_color, alpha=alpha, marker="o", lw=0)
  142. ax._need_autoscale = True # type: ignore[attr-defined]
  143. def lines(
  144. self,
  145. lines: cpy.LineReturn,
  146. line_type: LineType | str,
  147. ax: Axes | int = 0,
  148. color: str = "C0",
  149. alpha: float = 1.0,
  150. linewidth: float = 1,
  151. ) -> None:
  152. """Plot contour lines on a single Axes.
  153. Args:
  154. lines (sequence of arrays): Contour line data as returned by
  155. :func:`~contourpy.ContourGenerator.lines`.
  156. line_type (LineType or str): Type of ``lines`` data as returned by
  157. :attr:`~contourpy.ContourGenerator.line_type`, or string equivalent.
  158. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  159. color (str, optional): Color to plot lines. May be a string color or the letter ``"C"``
  160. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  161. ``tab10`` colormap. Default ``"C0"``.
  162. alpha (float, optional): Opacity to plot lines with, default ``1.0``.
  163. linewidth (float, optional): Width of lines, default ``1``.
  164. """
  165. line_type = as_line_type(line_type)
  166. ax = self._get_ax(ax)
  167. paths = lines_to_mpl_paths(lines, line_type)
  168. collection = mcollections.PathCollection(
  169. paths, facecolors="none", edgecolors=color, lw=linewidth, alpha=alpha)
  170. ax.add_collection(collection)
  171. ax._need_autoscale = True # type: ignore[attr-defined]
  172. def mask(
  173. self,
  174. x: ArrayLike,
  175. y: ArrayLike,
  176. z: ArrayLike | np.ma.MaskedArray[Any, Any],
  177. ax: Axes | int = 0,
  178. color: str = "black",
  179. ) -> None:
  180. """Plot masked out grid points as circles on a single Axes.
  181. Args:
  182. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  183. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  184. z (masked array of shape (ny, nx): z-values.
  185. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  186. color (str, optional): Circle color, default ``"black"``.
  187. """
  188. mask = np.ma.getmask(z) # type: ignore[no-untyped-call]
  189. if mask is np.ma.nomask:
  190. return
  191. ax = self._get_ax(ax)
  192. x, y = self._grid_as_2d(x, y)
  193. ax.plot(x[mask], y[mask], "o", c=color)
  194. def save(self, filename: str, transparent: bool = False) -> None:
  195. """Save plots to SVG or PNG file.
  196. Args:
  197. filename (str): Filename to save to.
  198. transparent (bool, optional): Whether background should be transparent, default
  199. ``False``.
  200. """
  201. self._autoscale()
  202. self._fig.savefig(filename, transparent=transparent)
  203. def save_to_buffer(self) -> io.BytesIO:
  204. """Save plots to an ``io.BytesIO`` buffer.
  205. Return:
  206. BytesIO: PNG image buffer.
  207. """
  208. self._autoscale()
  209. buf = io.BytesIO()
  210. self._fig.savefig(buf, format="png")
  211. buf.seek(0)
  212. return buf
  213. def show(self) -> None:
  214. """Show plots in an interactive window, in the usual Matplotlib manner.
  215. """
  216. self._autoscale()
  217. plt.show()
  218. def title(self, title: str, ax: Axes | int = 0, color: str | None = None) -> None:
  219. """Set the title of a single Axes.
  220. Args:
  221. title (str): Title text.
  222. ax (int or Matplotlib Axes, optional): Which Axes to set the title of, default ``0``.
  223. color (str, optional): Color to set title. May be a string color or the letter ``"C"``
  224. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  225. ``tab10`` colormap. Default is ``None`` which uses Matplotlib's default title color
  226. that depends on the stylesheet in use.
  227. """
  228. if color:
  229. self._get_ax(ax).set_title(title, color=color)
  230. else:
  231. self._get_ax(ax).set_title(title)
  232. def z_values(
  233. self,
  234. x: ArrayLike,
  235. y: ArrayLike,
  236. z: ArrayLike,
  237. ax: Axes | int = 0,
  238. color: str = "green",
  239. fmt: str = ".1f",
  240. quad_as_tri: bool = False,
  241. ) -> None:
  242. """Show ``z`` values on a single Axes.
  243. Args:
  244. x (array-like of shape (ny, nx) or (nx,)): The x-coordinates of the grid points.
  245. y (array-like of shape (ny, nx) or (ny,)): The y-coordinates of the grid points.
  246. z (array-like of shape (ny, nx): z-values.
  247. ax (int or Matplotlib Axes, optional): Which Axes to plot on, default ``0``.
  248. color (str, optional): Color of added text. May be a string color or the letter ``"C"``
  249. followed by an integer in the range ``"C0"`` to ``"C9"`` to use a color from the
  250. ``tab10`` colormap. Default ``"green"``.
  251. fmt (str, optional): Format to display z-values, default ``".1f"``.
  252. quad_as_tri (bool, optional): Whether to show z-values at the ``quad_as_tri`` centers
  253. of quads.
  254. Warning:
  255. ``quad_as_tri=True`` shows z-values for all quads, even if masked.
  256. """
  257. ax = self._get_ax(ax)
  258. x, y = self._grid_as_2d(x, y)
  259. z = np.asarray(z)
  260. ny, nx = z.shape
  261. for j in range(ny):
  262. for i in range(nx):
  263. ax.text(x[j, i], y[j, i], f"{z[j, i]:{fmt}}", ha="center", va="center",
  264. color=color, clip_on=True)
  265. if quad_as_tri:
  266. for j in range(ny-1):
  267. for i in range(nx-1):
  268. xx = np.mean(x[j:j+2, i:i+2])
  269. yy = np.mean(y[j:j+2, i:i+2])
  270. zz = np.mean(z[j:j+2, i:i+2])
  271. ax.text(xx, yy, f"{zz:{fmt}}", ha="center", va="center", color=color,
  272. clip_on=True)
  273. class MplTestRenderer(MplRenderer):
  274. """Test renderer implemented using Matplotlib.
  275. No whitespace around plots and no spines/ticks displayed.
  276. Uses Agg backend, so can only save to file/buffer, cannot call ``show()``.
  277. """
  278. def __init__(
  279. self,
  280. nrows: int = 1,
  281. ncols: int = 1,
  282. figsize: tuple[float, float] = (9, 9),
  283. ) -> None:
  284. gridspec = {
  285. "left": 0.01,
  286. "right": 0.99,
  287. "top": 0.99,
  288. "bottom": 0.01,
  289. "wspace": 0.01,
  290. "hspace": 0.01,
  291. }
  292. super().__init__(
  293. nrows, ncols, figsize, show_frame=True, backend="Agg", gridspec_kw=gridspec,
  294. )
  295. for ax in self._axes:
  296. ax.set_xmargin(0.0)
  297. ax.set_ymargin(0.0)
  298. ax.set_xticks([])
  299. ax.set_yticks([])
  300. self._want_tight = False
  301. class MplDebugRenderer(MplRenderer):
  302. """Debug renderer implemented using Matplotlib.
  303. Extends ``MplRenderer`` to add extra information to help in debugging such as markers, arrows,
  304. text, etc.
  305. """
  306. def __init__(
  307. self,
  308. nrows: int = 1,
  309. ncols: int = 1,
  310. figsize: tuple[float, float] = (9, 9),
  311. show_frame: bool = True,
  312. ) -> None:
  313. super().__init__(nrows, ncols, figsize, show_frame)
  314. def _arrow(
  315. self,
  316. ax: Axes,
  317. line_start: cpy.CoordinateArray,
  318. line_end: cpy.CoordinateArray,
  319. color: str,
  320. alpha: float,
  321. arrow_size: float,
  322. ) -> None:
  323. mid = 0.5*(line_start + line_end)
  324. along = line_end - line_start
  325. along /= np.sqrt(np.dot(along, along)) # Unit vector.
  326. right = np.asarray((along[1], -along[0]))
  327. arrow = np.stack((
  328. mid - (along*0.5 - right)*arrow_size,
  329. mid + along*0.5*arrow_size,
  330. mid - (along*0.5 + right)*arrow_size,
  331. ))
  332. ax.plot(arrow[:, 0], arrow[:, 1], "-", c=color, alpha=alpha)
  333. def filled(
  334. self,
  335. filled: cpy.FillReturn,
  336. fill_type: FillType | str,
  337. ax: Axes | int = 0,
  338. color: str = "C1",
  339. alpha: float = 0.7,
  340. line_color: str = "C0",
  341. line_alpha: float = 0.7,
  342. point_color: str = "C0",
  343. start_point_color: str = "red",
  344. arrow_size: float = 0.1,
  345. ) -> None:
  346. fill_type = as_fill_type(fill_type)
  347. super().filled(filled, fill_type, ax, color, alpha)
  348. if line_color is None and point_color is None:
  349. return
  350. ax = self._get_ax(ax)
  351. filled = convert_filled(filled, fill_type, FillType.ChunkCombinedOffset)
  352. # Lines.
  353. if line_color is not None:
  354. for points, offsets in zip(*filled):
  355. if points is None:
  356. continue
  357. for start, end in zip(offsets[:-1], offsets[1:]):
  358. xys = points[start:end]
  359. ax.plot(xys[:, 0], xys[:, 1], c=line_color, alpha=line_alpha)
  360. if arrow_size > 0.0:
  361. n = len(xys)
  362. for i in range(n-1):
  363. self._arrow(ax, xys[i], xys[i+1], line_color, line_alpha, arrow_size)
  364. # Points.
  365. if point_color is not None:
  366. for points, offsets in zip(*filled):
  367. if points is None:
  368. continue
  369. mask = np.ones(offsets[-1], dtype=bool)
  370. mask[offsets[1:]-1] = False # Exclude end points.
  371. if start_point_color is not None:
  372. start_indices = offsets[:-1]
  373. mask[start_indices] = False # Exclude start points.
  374. ax.plot(
  375. points[:, 0][mask], points[:, 1][mask], "o", c=point_color, alpha=line_alpha)
  376. if start_point_color is not None:
  377. ax.plot(points[:, 0][start_indices], points[:, 1][start_indices], "o",
  378. c=start_point_color, alpha=line_alpha)
  379. def lines(
  380. self,
  381. lines: cpy.LineReturn,
  382. line_type: LineType | str,
  383. ax: Axes | int = 0,
  384. color: str = "C0",
  385. alpha: float = 1.0,
  386. linewidth: float = 1,
  387. point_color: str = "C0",
  388. start_point_color: str = "red",
  389. arrow_size: float = 0.1,
  390. ) -> None:
  391. line_type = as_line_type(line_type)
  392. super().lines(lines, line_type, ax, color, alpha, linewidth)
  393. if arrow_size == 0.0 and point_color is None:
  394. return
  395. ax = self._get_ax(ax)
  396. separate_lines = convert_lines(lines, line_type, LineType.Separate)
  397. if TYPE_CHECKING:
  398. separate_lines = cast(cpy.LineReturn_Separate, separate_lines)
  399. if arrow_size > 0.0:
  400. for line in separate_lines:
  401. for i in range(len(line)-1):
  402. self._arrow(ax, line[i], line[i+1], color, alpha, arrow_size)
  403. if point_color is not None:
  404. for line in separate_lines:
  405. start_index = 0
  406. end_index = len(line)
  407. if start_point_color is not None:
  408. ax.plot(line[0, 0], line[0, 1], "o", c=start_point_color, alpha=alpha)
  409. start_index = 1
  410. if line[0][0] == line[-1][0] and line[0][1] == line[-1][1]:
  411. end_index -= 1
  412. ax.plot(line[start_index:end_index, 0], line[start_index:end_index, 1], "o",
  413. c=color, alpha=alpha)
  414. def point_numbers(
  415. self,
  416. x: ArrayLike,
  417. y: ArrayLike,
  418. z: ArrayLike,
  419. ax: Axes | int = 0,
  420. color: str = "red",
  421. ) -> None:
  422. ax = self._get_ax(ax)
  423. x, y = self._grid_as_2d(x, y)
  424. z = np.asarray(z)
  425. ny, nx = z.shape
  426. for j in range(ny):
  427. for i in range(nx):
  428. quad = i + j*nx
  429. ax.text(x[j, i], y[j, i], str(quad), ha="right", va="top", color=color,
  430. clip_on=True)
  431. def quad_numbers(
  432. self,
  433. x: ArrayLike,
  434. y: ArrayLike,
  435. z: ArrayLike,
  436. ax: Axes | int = 0,
  437. color: str = "blue",
  438. ) -> None:
  439. ax = self._get_ax(ax)
  440. x, y = self._grid_as_2d(x, y)
  441. z = np.asarray(z)
  442. ny, nx = z.shape
  443. for j in range(1, ny):
  444. for i in range(1, nx):
  445. quad = i + j*nx
  446. xmid = x[j-1:j+1, i-1:i+1].mean()
  447. ymid = y[j-1:j+1, i-1:i+1].mean()
  448. ax.text(xmid, ymid, str(quad), ha="center", va="center", color=color, clip_on=True)
  449. def z_levels(
  450. self,
  451. x: ArrayLike,
  452. y: ArrayLike,
  453. z: ArrayLike,
  454. lower_level: float,
  455. upper_level: float | None = None,
  456. ax: Axes | int = 0,
  457. color: str = "green",
  458. ) -> None:
  459. ax = self._get_ax(ax)
  460. x, y = self._grid_as_2d(x, y)
  461. z = np.asarray(z)
  462. ny, nx = z.shape
  463. for j in range(ny):
  464. for i in range(nx):
  465. zz = z[j, i]
  466. if upper_level is not None and zz > upper_level:
  467. z_level = 2
  468. elif zz > lower_level:
  469. z_level = 1
  470. else:
  471. z_level = 0
  472. ax.text(x[j, i], y[j, i], str(z_level), ha="left", va="bottom", color=color,
  473. clip_on=True)