common.py 21 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655
  1. """
  2. Module consolidating common testing functions for checking plotting.
  3. Currently all plotting tests are marked as slow via
  4. ``pytestmark = pytest.mark.slow`` at the module level.
  5. """
  6. from __future__ import annotations
  7. import os
  8. from typing import (
  9. TYPE_CHECKING,
  10. Sequence,
  11. )
  12. import warnings
  13. import numpy as np
  14. from pandas.util._decorators import cache_readonly
  15. import pandas.util._test_decorators as td
  16. from pandas.core.dtypes.api import is_list_like
  17. import pandas as pd
  18. from pandas import (
  19. DataFrame,
  20. Series,
  21. to_datetime,
  22. )
  23. import pandas._testing as tm
  24. if TYPE_CHECKING:
  25. from matplotlib.axes import Axes
  26. @td.skip_if_no_mpl
  27. class TestPlotBase:
  28. """
  29. This is a common base class used for various plotting tests
  30. """
  31. def setup_method(self, method):
  32. import matplotlib as mpl
  33. from pandas.plotting._matplotlib import compat
  34. mpl.rcdefaults()
  35. self.start_date_to_int64 = 812419200000000000
  36. self.end_date_to_int64 = 819331200000000000
  37. self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
  38. self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
  39. self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
  40. self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()
  41. self.bp_n_objects = 7
  42. self.polycollection_factor = 2
  43. self.default_figsize = (6.4, 4.8)
  44. self.default_tick_position = "left"
  45. n = 100
  46. with tm.RNGContext(42):
  47. gender = np.random.choice(["Male", "Female"], size=n)
  48. classroom = np.random.choice(["A", "B", "C"], size=n)
  49. self.hist_df = DataFrame(
  50. {
  51. "gender": gender,
  52. "classroom": classroom,
  53. "height": np.random.normal(66, 4, size=n),
  54. "weight": np.random.normal(161, 32, size=n),
  55. "category": np.random.randint(4, size=n),
  56. "datetime": to_datetime(
  57. np.random.randint(
  58. self.start_date_to_int64,
  59. self.end_date_to_int64,
  60. size=n,
  61. dtype=np.int64,
  62. )
  63. ),
  64. }
  65. )
  66. self.tdf = tm.makeTimeDataFrame()
  67. self.hexbin_df = DataFrame(
  68. {
  69. "A": np.random.uniform(size=20),
  70. "B": np.random.uniform(size=20),
  71. "C": np.arange(20) + np.random.uniform(size=20),
  72. }
  73. )
  74. def teardown_method(self, method):
  75. tm.close()
  76. @cache_readonly
  77. def plt(self):
  78. import matplotlib.pyplot as plt
  79. return plt
  80. @cache_readonly
  81. def colorconverter(self):
  82. import matplotlib.colors as colors
  83. return colors.colorConverter
  84. def _check_legend_labels(self, axes, labels=None, visible=True):
  85. """
  86. Check each axes has expected legend labels
  87. Parameters
  88. ----------
  89. axes : matplotlib Axes object, or its list-like
  90. labels : list-like
  91. expected legend labels
  92. visible : bool
  93. expected legend visibility. labels are checked only when visible is
  94. True
  95. """
  96. if visible and (labels is None):
  97. raise ValueError("labels must be specified when visible is True")
  98. axes = self._flatten_visible(axes)
  99. for ax in axes:
  100. if visible:
  101. assert ax.get_legend() is not None
  102. self._check_text_labels(ax.get_legend().get_texts(), labels)
  103. else:
  104. assert ax.get_legend() is None
  105. def _check_legend_marker(self, ax, expected_markers=None, visible=True):
  106. """
  107. Check ax has expected legend markers
  108. Parameters
  109. ----------
  110. ax : matplotlib Axes object
  111. expected_markers : list-like
  112. expected legend markers
  113. visible : bool
  114. expected legend visibility. labels are checked only when visible is
  115. True
  116. """
  117. if visible and (expected_markers is None):
  118. raise ValueError("Markers must be specified when visible is True")
  119. if visible:
  120. handles, _ = ax.get_legend_handles_labels()
  121. markers = [handle.get_marker() for handle in handles]
  122. assert markers == expected_markers
  123. else:
  124. assert ax.get_legend() is None
  125. def _check_data(self, xp, rs):
  126. """
  127. Check each axes has identical lines
  128. Parameters
  129. ----------
  130. xp : matplotlib Axes object
  131. rs : matplotlib Axes object
  132. """
  133. xp_lines = xp.get_lines()
  134. rs_lines = rs.get_lines()
  135. def check_line(xpl, rsl):
  136. xpdata = xpl.get_xydata()
  137. rsdata = rsl.get_xydata()
  138. tm.assert_almost_equal(xpdata, rsdata)
  139. assert len(xp_lines) == len(rs_lines)
  140. [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
  141. tm.close()
  142. def _check_visible(self, collections, visible=True):
  143. """
  144. Check each artist is visible or not
  145. Parameters
  146. ----------
  147. collections : matplotlib Artist or its list-like
  148. target Artist or its list or collection
  149. visible : bool
  150. expected visibility
  151. """
  152. from matplotlib.collections import Collection
  153. if not isinstance(collections, Collection) and not is_list_like(collections):
  154. collections = [collections]
  155. for patch in collections:
  156. assert patch.get_visible() == visible
  157. def _check_patches_all_filled(
  158. self, axes: Axes | Sequence[Axes], filled: bool = True
  159. ) -> None:
  160. """
  161. Check for each artist whether it is filled or not
  162. Parameters
  163. ----------
  164. axes : matplotlib Axes object, or its list-like
  165. filled : bool
  166. expected filling
  167. """
  168. axes = self._flatten_visible(axes)
  169. for ax in axes:
  170. for patch in ax.patches:
  171. assert patch.fill == filled
  172. def _get_colors_mapped(self, series, colors):
  173. unique = series.unique()
  174. # unique and colors length can be differed
  175. # depending on slice value
  176. mapped = dict(zip(unique, colors))
  177. return [mapped[v] for v in series.values]
  178. def _check_colors(
  179. self, collections, linecolors=None, facecolors=None, mapping=None
  180. ):
  181. """
  182. Check each artist has expected line colors and face colors
  183. Parameters
  184. ----------
  185. collections : list-like
  186. list or collection of target artist
  187. linecolors : list-like which has the same length as collections
  188. list of expected line colors
  189. facecolors : list-like which has the same length as collections
  190. list of expected face colors
  191. mapping : Series
  192. Series used for color grouping key
  193. used for andrew_curves, parallel_coordinates, radviz test
  194. """
  195. from matplotlib.collections import (
  196. Collection,
  197. LineCollection,
  198. PolyCollection,
  199. )
  200. from matplotlib.lines import Line2D
  201. conv = self.colorconverter
  202. if linecolors is not None:
  203. if mapping is not None:
  204. linecolors = self._get_colors_mapped(mapping, linecolors)
  205. linecolors = linecolors[: len(collections)]
  206. assert len(collections) == len(linecolors)
  207. for patch, color in zip(collections, linecolors):
  208. if isinstance(patch, Line2D):
  209. result = patch.get_color()
  210. # Line2D may contains string color expression
  211. result = conv.to_rgba(result)
  212. elif isinstance(patch, (PolyCollection, LineCollection)):
  213. result = tuple(patch.get_edgecolor()[0])
  214. else:
  215. result = patch.get_edgecolor()
  216. expected = conv.to_rgba(color)
  217. assert result == expected
  218. if facecolors is not None:
  219. if mapping is not None:
  220. facecolors = self._get_colors_mapped(mapping, facecolors)
  221. facecolors = facecolors[: len(collections)]
  222. assert len(collections) == len(facecolors)
  223. for patch, color in zip(collections, facecolors):
  224. if isinstance(patch, Collection):
  225. # returned as list of np.array
  226. result = patch.get_facecolor()[0]
  227. else:
  228. result = patch.get_facecolor()
  229. if isinstance(result, np.ndarray):
  230. result = tuple(result)
  231. expected = conv.to_rgba(color)
  232. assert result == expected
  233. def _check_text_labels(self, texts, expected):
  234. """
  235. Check each text has expected labels
  236. Parameters
  237. ----------
  238. texts : matplotlib Text object, or its list-like
  239. target text, or its list
  240. expected : str or list-like which has the same length as texts
  241. expected text label, or its list
  242. """
  243. if not is_list_like(texts):
  244. assert texts.get_text() == expected
  245. else:
  246. labels = [t.get_text() for t in texts]
  247. assert len(labels) == len(expected)
  248. for label, e in zip(labels, expected):
  249. assert label == e
  250. def _check_ticks_props(
  251. self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
  252. ):
  253. """
  254. Check each axes has expected tick properties
  255. Parameters
  256. ----------
  257. axes : matplotlib Axes object, or its list-like
  258. xlabelsize : number
  259. expected xticks font size
  260. xrot : number
  261. expected xticks rotation
  262. ylabelsize : number
  263. expected yticks font size
  264. yrot : number
  265. expected yticks rotation
  266. """
  267. from matplotlib.ticker import NullFormatter
  268. axes = self._flatten_visible(axes)
  269. for ax in axes:
  270. if xlabelsize is not None or xrot is not None:
  271. if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
  272. # If minor ticks has NullFormatter, rot / fontsize are not
  273. # retained
  274. labels = ax.get_xticklabels()
  275. else:
  276. labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
  277. for label in labels:
  278. if xlabelsize is not None:
  279. tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
  280. if xrot is not None:
  281. tm.assert_almost_equal(label.get_rotation(), xrot)
  282. if ylabelsize is not None or yrot is not None:
  283. if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
  284. labels = ax.get_yticklabels()
  285. else:
  286. labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
  287. for label in labels:
  288. if ylabelsize is not None:
  289. tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
  290. if yrot is not None:
  291. tm.assert_almost_equal(label.get_rotation(), yrot)
  292. def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
  293. """
  294. Check each axes has expected scales
  295. Parameters
  296. ----------
  297. axes : matplotlib Axes object, or its list-like
  298. xaxis : {'linear', 'log'}
  299. expected xaxis scale
  300. yaxis : {'linear', 'log'}
  301. expected yaxis scale
  302. """
  303. axes = self._flatten_visible(axes)
  304. for ax in axes:
  305. assert ax.xaxis.get_scale() == xaxis
  306. assert ax.yaxis.get_scale() == yaxis
  307. def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
  308. """
  309. Check expected number of axes is drawn in expected layout
  310. Parameters
  311. ----------
  312. axes : matplotlib Axes object, or its list-like
  313. axes_num : number
  314. expected number of axes. Unnecessary axes should be set to
  315. invisible.
  316. layout : tuple
  317. expected layout, (expected number of rows , columns)
  318. figsize : tuple
  319. expected figsize. default is matplotlib default
  320. """
  321. from pandas.plotting._matplotlib.tools import flatten_axes
  322. if figsize is None:
  323. figsize = self.default_figsize
  324. visible_axes = self._flatten_visible(axes)
  325. if axes_num is not None:
  326. assert len(visible_axes) == axes_num
  327. for ax in visible_axes:
  328. # check something drawn on visible axes
  329. assert len(ax.get_children()) > 0
  330. if layout is not None:
  331. result = self._get_axes_layout(flatten_axes(axes))
  332. assert result == layout
  333. tm.assert_numpy_array_equal(
  334. visible_axes[0].figure.get_size_inches(),
  335. np.array(figsize, dtype=np.float64),
  336. )
  337. def _get_axes_layout(self, axes):
  338. x_set = set()
  339. y_set = set()
  340. for ax in axes:
  341. # check axes coordinates to estimate layout
  342. points = ax.get_position().get_points()
  343. x_set.add(points[0][0])
  344. y_set.add(points[0][1])
  345. return (len(y_set), len(x_set))
  346. def _flatten_visible(self, axes):
  347. """
  348. Flatten axes, and filter only visible
  349. Parameters
  350. ----------
  351. axes : matplotlib Axes object, or its list-like
  352. """
  353. from pandas.plotting._matplotlib.tools import flatten_axes
  354. axes = flatten_axes(axes)
  355. axes = [ax for ax in axes if ax.get_visible()]
  356. return axes
  357. def _check_has_errorbars(self, axes, xerr=0, yerr=0):
  358. """
  359. Check axes has expected number of errorbars
  360. Parameters
  361. ----------
  362. axes : matplotlib Axes object, or its list-like
  363. xerr : number
  364. expected number of x errorbar
  365. yerr : number
  366. expected number of y errorbar
  367. """
  368. axes = self._flatten_visible(axes)
  369. for ax in axes:
  370. containers = ax.containers
  371. xerr_count = 0
  372. yerr_count = 0
  373. for c in containers:
  374. has_xerr = getattr(c, "has_xerr", False)
  375. has_yerr = getattr(c, "has_yerr", False)
  376. if has_xerr:
  377. xerr_count += 1
  378. if has_yerr:
  379. yerr_count += 1
  380. assert xerr == xerr_count
  381. assert yerr == yerr_count
  382. def _check_box_return_type(
  383. self, returned, return_type, expected_keys=None, check_ax_title=True
  384. ):
  385. """
  386. Check box returned type is correct
  387. Parameters
  388. ----------
  389. returned : object to be tested, returned from boxplot
  390. return_type : str
  391. return_type passed to boxplot
  392. expected_keys : list-like, optional
  393. group labels in subplot case. If not passed,
  394. the function checks assuming boxplot uses single ax
  395. check_ax_title : bool
  396. Whether to check the ax.title is the same as expected_key
  397. Intended to be checked by calling from ``boxplot``.
  398. Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
  399. """
  400. from matplotlib.axes import Axes
  401. types = {"dict": dict, "axes": Axes, "both": tuple}
  402. if expected_keys is None:
  403. # should be fixed when the returning default is changed
  404. if return_type is None:
  405. return_type = "dict"
  406. assert isinstance(returned, types[return_type])
  407. if return_type == "both":
  408. assert isinstance(returned.ax, Axes)
  409. assert isinstance(returned.lines, dict)
  410. else:
  411. # should be fixed when the returning default is changed
  412. if return_type is None:
  413. for r in self._flatten_visible(returned):
  414. assert isinstance(r, Axes)
  415. return
  416. assert isinstance(returned, Series)
  417. assert sorted(returned.keys()) == sorted(expected_keys)
  418. for key, value in returned.items():
  419. assert isinstance(value, types[return_type])
  420. # check returned dict has correct mapping
  421. if return_type == "axes":
  422. if check_ax_title:
  423. assert value.get_title() == key
  424. elif return_type == "both":
  425. if check_ax_title:
  426. assert value.ax.get_title() == key
  427. assert isinstance(value.ax, Axes)
  428. assert isinstance(value.lines, dict)
  429. elif return_type == "dict":
  430. line = value["medians"][0]
  431. axes = line.axes
  432. if check_ax_title:
  433. assert axes.get_title() == key
  434. else:
  435. raise AssertionError
  436. def _check_grid_settings(self, obj, kinds, kws={}):
  437. # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
  438. import matplotlib as mpl
  439. def is_grid_on():
  440. xticks = self.plt.gca().xaxis.get_major_ticks()
  441. yticks = self.plt.gca().yaxis.get_major_ticks()
  442. # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
  443. # for new MPL, they are the same.
  444. if self.mpl_ge_3_1_0:
  445. xoff = all(not g.gridline.get_visible() for g in xticks)
  446. yoff = all(not g.gridline.get_visible() for g in yticks)
  447. else:
  448. xoff = all(not g.gridOn for g in xticks)
  449. yoff = all(not g.gridOn for g in yticks)
  450. return not (xoff and yoff)
  451. spndx = 1
  452. for kind in kinds:
  453. self.plt.subplot(1, 4 * len(kinds), spndx)
  454. spndx += 1
  455. mpl.rc("axes", grid=False)
  456. obj.plot(kind=kind, **kws)
  457. assert not is_grid_on()
  458. self.plt.subplot(1, 4 * len(kinds), spndx)
  459. spndx += 1
  460. mpl.rc("axes", grid=True)
  461. obj.plot(kind=kind, grid=False, **kws)
  462. assert not is_grid_on()
  463. if kind != "pie":
  464. self.plt.subplot(1, 4 * len(kinds), spndx)
  465. spndx += 1
  466. mpl.rc("axes", grid=True)
  467. obj.plot(kind=kind, **kws)
  468. assert is_grid_on()
  469. self.plt.subplot(1, 4 * len(kinds), spndx)
  470. spndx += 1
  471. mpl.rc("axes", grid=False)
  472. obj.plot(kind=kind, grid=True, **kws)
  473. assert is_grid_on()
  474. def _unpack_cycler(self, rcParams, field="color"):
  475. """
  476. Auxiliary function for correctly unpacking cycler after MPL >= 1.5
  477. """
  478. return [v[field] for v in rcParams["axes.prop_cycle"]]
  479. def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
  480. """
  481. Create plot and ensure that plot return object is valid.
  482. Parameters
  483. ----------
  484. f : func
  485. Plotting function.
  486. filterwarnings : str
  487. Warnings filter.
  488. See https://docs.python.org/3/library/warnings.html#warning-filter
  489. default_axes : bool, optional
  490. If False (default):
  491. - If `ax` not in `kwargs`, then create subplot(211) and plot there
  492. - Create new subplot(212) and plot there as well
  493. - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
  494. If True:
  495. - Simply run plotting function with kwargs provided
  496. - All required axes instances will be created automatically
  497. - It is recommended to use it when the plotting function
  498. creates multiple axes itself. It helps avoid warnings like
  499. 'UserWarning: To output multiple subplots,
  500. the figure containing the passed axes is being cleared'
  501. **kwargs
  502. Keyword arguments passed to the plotting function.
  503. Returns
  504. -------
  505. Plot object returned by the last plotting.
  506. """
  507. import matplotlib.pyplot as plt
  508. if default_axes:
  509. gen_plots = _gen_default_plot
  510. else:
  511. gen_plots = _gen_two_subplots
  512. ret = None
  513. with warnings.catch_warnings():
  514. warnings.simplefilter(filterwarnings)
  515. try:
  516. fig = kwargs.get("figure", plt.gcf())
  517. plt.clf()
  518. for ret in gen_plots(f, fig, **kwargs):
  519. tm.assert_is_valid_plot_return_object(ret)
  520. with tm.ensure_clean(return_filelike=True) as path:
  521. plt.savefig(path)
  522. except Exception as err:
  523. raise err
  524. finally:
  525. tm.close(fig)
  526. return ret
  527. def _gen_default_plot(f, fig, **kwargs):
  528. """
  529. Create plot in a default way.
  530. """
  531. yield f(**kwargs)
  532. def _gen_two_subplots(f, fig, **kwargs):
  533. """
  534. Create plot on two subplots forcefully created.
  535. """
  536. if "ax" not in kwargs:
  537. fig.add_subplot(211)
  538. yield f(**kwargs)
  539. if f is pd.plotting.bootstrap_plot:
  540. assert "ax" not in kwargs
  541. else:
  542. kwargs["ax"] = fig.add_subplot(212)
  543. yield f(**kwargs)
  544. def curpath():
  545. pth, _ = os.path.split(os.path.abspath(__file__))
  546. return pth