123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655 |
- """
- Module consolidating common testing functions for checking plotting.
- Currently all plotting tests are marked as slow via
- ``pytestmark = pytest.mark.slow`` at the module level.
- """
- from __future__ import annotations
- import os
- from typing import (
- TYPE_CHECKING,
- Sequence,
- )
- import warnings
- import numpy as np
- from pandas.util._decorators import cache_readonly
- import pandas.util._test_decorators as td
- from pandas.core.dtypes.api import is_list_like
- import pandas as pd
- from pandas import (
- DataFrame,
- Series,
- to_datetime,
- )
- import pandas._testing as tm
- if TYPE_CHECKING:
- from matplotlib.axes import Axes
- @td.skip_if_no_mpl
- class TestPlotBase:
- """
- This is a common base class used for various plotting tests
- """
- def setup_method(self, method):
- import matplotlib as mpl
- from pandas.plotting._matplotlib import compat
- mpl.rcdefaults()
- self.start_date_to_int64 = 812419200000000000
- self.end_date_to_int64 = 819331200000000000
- self.mpl_ge_2_2_3 = compat.mpl_ge_2_2_3()
- self.mpl_ge_3_0_0 = compat.mpl_ge_3_0_0()
- self.mpl_ge_3_1_0 = compat.mpl_ge_3_1_0()
- self.mpl_ge_3_2_0 = compat.mpl_ge_3_2_0()
- self.bp_n_objects = 7
- self.polycollection_factor = 2
- self.default_figsize = (6.4, 4.8)
- self.default_tick_position = "left"
- n = 100
- with tm.RNGContext(42):
- gender = np.random.choice(["Male", "Female"], size=n)
- classroom = np.random.choice(["A", "B", "C"], size=n)
- self.hist_df = DataFrame(
- {
- "gender": gender,
- "classroom": classroom,
- "height": np.random.normal(66, 4, size=n),
- "weight": np.random.normal(161, 32, size=n),
- "category": np.random.randint(4, size=n),
- "datetime": to_datetime(
- np.random.randint(
- self.start_date_to_int64,
- self.end_date_to_int64,
- size=n,
- dtype=np.int64,
- )
- ),
- }
- )
- self.tdf = tm.makeTimeDataFrame()
- self.hexbin_df = DataFrame(
- {
- "A": np.random.uniform(size=20),
- "B": np.random.uniform(size=20),
- "C": np.arange(20) + np.random.uniform(size=20),
- }
- )
- def teardown_method(self, method):
- tm.close()
- @cache_readonly
- def plt(self):
- import matplotlib.pyplot as plt
- return plt
- @cache_readonly
- def colorconverter(self):
- import matplotlib.colors as colors
- return colors.colorConverter
- def _check_legend_labels(self, axes, labels=None, visible=True):
- """
- Check each axes has expected legend labels
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- labels : list-like
- expected legend labels
- visible : bool
- expected legend visibility. labels are checked only when visible is
- True
- """
- if visible and (labels is None):
- raise ValueError("labels must be specified when visible is True")
- axes = self._flatten_visible(axes)
- for ax in axes:
- if visible:
- assert ax.get_legend() is not None
- self._check_text_labels(ax.get_legend().get_texts(), labels)
- else:
- assert ax.get_legend() is None
- def _check_legend_marker(self, ax, expected_markers=None, visible=True):
- """
- Check ax has expected legend markers
- Parameters
- ----------
- ax : matplotlib Axes object
- expected_markers : list-like
- expected legend markers
- visible : bool
- expected legend visibility. labels are checked only when visible is
- True
- """
- if visible and (expected_markers is None):
- raise ValueError("Markers must be specified when visible is True")
- if visible:
- handles, _ = ax.get_legend_handles_labels()
- markers = [handle.get_marker() for handle in handles]
- assert markers == expected_markers
- else:
- assert ax.get_legend() is None
- def _check_data(self, xp, rs):
- """
- Check each axes has identical lines
- Parameters
- ----------
- xp : matplotlib Axes object
- rs : matplotlib Axes object
- """
- xp_lines = xp.get_lines()
- rs_lines = rs.get_lines()
- def check_line(xpl, rsl):
- xpdata = xpl.get_xydata()
- rsdata = rsl.get_xydata()
- tm.assert_almost_equal(xpdata, rsdata)
- assert len(xp_lines) == len(rs_lines)
- [check_line(xpl, rsl) for xpl, rsl in zip(xp_lines, rs_lines)]
- tm.close()
- def _check_visible(self, collections, visible=True):
- """
- Check each artist is visible or not
- Parameters
- ----------
- collections : matplotlib Artist or its list-like
- target Artist or its list or collection
- visible : bool
- expected visibility
- """
- from matplotlib.collections import Collection
- if not isinstance(collections, Collection) and not is_list_like(collections):
- collections = [collections]
- for patch in collections:
- assert patch.get_visible() == visible
- def _check_patches_all_filled(
- self, axes: Axes | Sequence[Axes], filled: bool = True
- ) -> None:
- """
- Check for each artist whether it is filled or not
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- filled : bool
- expected filling
- """
- axes = self._flatten_visible(axes)
- for ax in axes:
- for patch in ax.patches:
- assert patch.fill == filled
- def _get_colors_mapped(self, series, colors):
- unique = series.unique()
- # unique and colors length can be differed
- # depending on slice value
- mapped = dict(zip(unique, colors))
- return [mapped[v] for v in series.values]
- def _check_colors(
- self, collections, linecolors=None, facecolors=None, mapping=None
- ):
- """
- Check each artist has expected line colors and face colors
- Parameters
- ----------
- collections : list-like
- list or collection of target artist
- linecolors : list-like which has the same length as collections
- list of expected line colors
- facecolors : list-like which has the same length as collections
- list of expected face colors
- mapping : Series
- Series used for color grouping key
- used for andrew_curves, parallel_coordinates, radviz test
- """
- from matplotlib.collections import (
- Collection,
- LineCollection,
- PolyCollection,
- )
- from matplotlib.lines import Line2D
- conv = self.colorconverter
- if linecolors is not None:
- if mapping is not None:
- linecolors = self._get_colors_mapped(mapping, linecolors)
- linecolors = linecolors[: len(collections)]
- assert len(collections) == len(linecolors)
- for patch, color in zip(collections, linecolors):
- if isinstance(patch, Line2D):
- result = patch.get_color()
- # Line2D may contains string color expression
- result = conv.to_rgba(result)
- elif isinstance(patch, (PolyCollection, LineCollection)):
- result = tuple(patch.get_edgecolor()[0])
- else:
- result = patch.get_edgecolor()
- expected = conv.to_rgba(color)
- assert result == expected
- if facecolors is not None:
- if mapping is not None:
- facecolors = self._get_colors_mapped(mapping, facecolors)
- facecolors = facecolors[: len(collections)]
- assert len(collections) == len(facecolors)
- for patch, color in zip(collections, facecolors):
- if isinstance(patch, Collection):
- # returned as list of np.array
- result = patch.get_facecolor()[0]
- else:
- result = patch.get_facecolor()
- if isinstance(result, np.ndarray):
- result = tuple(result)
- expected = conv.to_rgba(color)
- assert result == expected
- def _check_text_labels(self, texts, expected):
- """
- Check each text has expected labels
- Parameters
- ----------
- texts : matplotlib Text object, or its list-like
- target text, or its list
- expected : str or list-like which has the same length as texts
- expected text label, or its list
- """
- if not is_list_like(texts):
- assert texts.get_text() == expected
- else:
- labels = [t.get_text() for t in texts]
- assert len(labels) == len(expected)
- for label, e in zip(labels, expected):
- assert label == e
- def _check_ticks_props(
- self, axes, xlabelsize=None, xrot=None, ylabelsize=None, yrot=None
- ):
- """
- Check each axes has expected tick properties
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- xlabelsize : number
- expected xticks font size
- xrot : number
- expected xticks rotation
- ylabelsize : number
- expected yticks font size
- yrot : number
- expected yticks rotation
- """
- from matplotlib.ticker import NullFormatter
- axes = self._flatten_visible(axes)
- for ax in axes:
- if xlabelsize is not None or xrot is not None:
- if isinstance(ax.xaxis.get_minor_formatter(), NullFormatter):
- # If minor ticks has NullFormatter, rot / fontsize are not
- # retained
- labels = ax.get_xticklabels()
- else:
- labels = ax.get_xticklabels() + ax.get_xticklabels(minor=True)
- for label in labels:
- if xlabelsize is not None:
- tm.assert_almost_equal(label.get_fontsize(), xlabelsize)
- if xrot is not None:
- tm.assert_almost_equal(label.get_rotation(), xrot)
- if ylabelsize is not None or yrot is not None:
- if isinstance(ax.yaxis.get_minor_formatter(), NullFormatter):
- labels = ax.get_yticklabels()
- else:
- labels = ax.get_yticklabels() + ax.get_yticklabels(minor=True)
- for label in labels:
- if ylabelsize is not None:
- tm.assert_almost_equal(label.get_fontsize(), ylabelsize)
- if yrot is not None:
- tm.assert_almost_equal(label.get_rotation(), yrot)
- def _check_ax_scales(self, axes, xaxis="linear", yaxis="linear"):
- """
- Check each axes has expected scales
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- xaxis : {'linear', 'log'}
- expected xaxis scale
- yaxis : {'linear', 'log'}
- expected yaxis scale
- """
- axes = self._flatten_visible(axes)
- for ax in axes:
- assert ax.xaxis.get_scale() == xaxis
- assert ax.yaxis.get_scale() == yaxis
- def _check_axes_shape(self, axes, axes_num=None, layout=None, figsize=None):
- """
- Check expected number of axes is drawn in expected layout
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- axes_num : number
- expected number of axes. Unnecessary axes should be set to
- invisible.
- layout : tuple
- expected layout, (expected number of rows , columns)
- figsize : tuple
- expected figsize. default is matplotlib default
- """
- from pandas.plotting._matplotlib.tools import flatten_axes
- if figsize is None:
- figsize = self.default_figsize
- visible_axes = self._flatten_visible(axes)
- if axes_num is not None:
- assert len(visible_axes) == axes_num
- for ax in visible_axes:
- # check something drawn on visible axes
- assert len(ax.get_children()) > 0
- if layout is not None:
- result = self._get_axes_layout(flatten_axes(axes))
- assert result == layout
- tm.assert_numpy_array_equal(
- visible_axes[0].figure.get_size_inches(),
- np.array(figsize, dtype=np.float64),
- )
- def _get_axes_layout(self, axes):
- x_set = set()
- y_set = set()
- for ax in axes:
- # check axes coordinates to estimate layout
- points = ax.get_position().get_points()
- x_set.add(points[0][0])
- y_set.add(points[0][1])
- return (len(y_set), len(x_set))
- def _flatten_visible(self, axes):
- """
- Flatten axes, and filter only visible
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- """
- from pandas.plotting._matplotlib.tools import flatten_axes
- axes = flatten_axes(axes)
- axes = [ax for ax in axes if ax.get_visible()]
- return axes
- def _check_has_errorbars(self, axes, xerr=0, yerr=0):
- """
- Check axes has expected number of errorbars
- Parameters
- ----------
- axes : matplotlib Axes object, or its list-like
- xerr : number
- expected number of x errorbar
- yerr : number
- expected number of y errorbar
- """
- axes = self._flatten_visible(axes)
- for ax in axes:
- containers = ax.containers
- xerr_count = 0
- yerr_count = 0
- for c in containers:
- has_xerr = getattr(c, "has_xerr", False)
- has_yerr = getattr(c, "has_yerr", False)
- if has_xerr:
- xerr_count += 1
- if has_yerr:
- yerr_count += 1
- assert xerr == xerr_count
- assert yerr == yerr_count
- def _check_box_return_type(
- self, returned, return_type, expected_keys=None, check_ax_title=True
- ):
- """
- Check box returned type is correct
- Parameters
- ----------
- returned : object to be tested, returned from boxplot
- return_type : str
- return_type passed to boxplot
- expected_keys : list-like, optional
- group labels in subplot case. If not passed,
- the function checks assuming boxplot uses single ax
- check_ax_title : bool
- Whether to check the ax.title is the same as expected_key
- Intended to be checked by calling from ``boxplot``.
- Normal ``plot`` doesn't attach ``ax.title``, it must be disabled.
- """
- from matplotlib.axes import Axes
- types = {"dict": dict, "axes": Axes, "both": tuple}
- if expected_keys is None:
- # should be fixed when the returning default is changed
- if return_type is None:
- return_type = "dict"
- assert isinstance(returned, types[return_type])
- if return_type == "both":
- assert isinstance(returned.ax, Axes)
- assert isinstance(returned.lines, dict)
- else:
- # should be fixed when the returning default is changed
- if return_type is None:
- for r in self._flatten_visible(returned):
- assert isinstance(r, Axes)
- return
- assert isinstance(returned, Series)
- assert sorted(returned.keys()) == sorted(expected_keys)
- for key, value in returned.items():
- assert isinstance(value, types[return_type])
- # check returned dict has correct mapping
- if return_type == "axes":
- if check_ax_title:
- assert value.get_title() == key
- elif return_type == "both":
- if check_ax_title:
- assert value.ax.get_title() == key
- assert isinstance(value.ax, Axes)
- assert isinstance(value.lines, dict)
- elif return_type == "dict":
- line = value["medians"][0]
- axes = line.axes
- if check_ax_title:
- assert axes.get_title() == key
- else:
- raise AssertionError
- def _check_grid_settings(self, obj, kinds, kws={}):
- # Make sure plot defaults to rcParams['axes.grid'] setting, GH 9792
- import matplotlib as mpl
- def is_grid_on():
- xticks = self.plt.gca().xaxis.get_major_ticks()
- yticks = self.plt.gca().yaxis.get_major_ticks()
- # for mpl 2.2.2, gridOn and gridline.get_visible disagree.
- # for new MPL, they are the same.
- if self.mpl_ge_3_1_0:
- xoff = all(not g.gridline.get_visible() for g in xticks)
- yoff = all(not g.gridline.get_visible() for g in yticks)
- else:
- xoff = all(not g.gridOn for g in xticks)
- yoff = all(not g.gridOn for g in yticks)
- return not (xoff and yoff)
- spndx = 1
- for kind in kinds:
- self.plt.subplot(1, 4 * len(kinds), spndx)
- spndx += 1
- mpl.rc("axes", grid=False)
- obj.plot(kind=kind, **kws)
- assert not is_grid_on()
- self.plt.subplot(1, 4 * len(kinds), spndx)
- spndx += 1
- mpl.rc("axes", grid=True)
- obj.plot(kind=kind, grid=False, **kws)
- assert not is_grid_on()
- if kind != "pie":
- self.plt.subplot(1, 4 * len(kinds), spndx)
- spndx += 1
- mpl.rc("axes", grid=True)
- obj.plot(kind=kind, **kws)
- assert is_grid_on()
- self.plt.subplot(1, 4 * len(kinds), spndx)
- spndx += 1
- mpl.rc("axes", grid=False)
- obj.plot(kind=kind, grid=True, **kws)
- assert is_grid_on()
- def _unpack_cycler(self, rcParams, field="color"):
- """
- Auxiliary function for correctly unpacking cycler after MPL >= 1.5
- """
- return [v[field] for v in rcParams["axes.prop_cycle"]]
- def _check_plot_works(f, filterwarnings="always", default_axes=False, **kwargs):
- """
- Create plot and ensure that plot return object is valid.
- Parameters
- ----------
- f : func
- Plotting function.
- filterwarnings : str
- Warnings filter.
- See https://docs.python.org/3/library/warnings.html#warning-filter
- default_axes : bool, optional
- If False (default):
- - If `ax` not in `kwargs`, then create subplot(211) and plot there
- - Create new subplot(212) and plot there as well
- - Mind special corner case for bootstrap_plot (see `_gen_two_subplots`)
- If True:
- - Simply run plotting function with kwargs provided
- - All required axes instances will be created automatically
- - It is recommended to use it when the plotting function
- creates multiple axes itself. It helps avoid warnings like
- 'UserWarning: To output multiple subplots,
- the figure containing the passed axes is being cleared'
- **kwargs
- Keyword arguments passed to the plotting function.
- Returns
- -------
- Plot object returned by the last plotting.
- """
- import matplotlib.pyplot as plt
- if default_axes:
- gen_plots = _gen_default_plot
- else:
- gen_plots = _gen_two_subplots
- ret = None
- with warnings.catch_warnings():
- warnings.simplefilter(filterwarnings)
- try:
- fig = kwargs.get("figure", plt.gcf())
- plt.clf()
- for ret in gen_plots(f, fig, **kwargs):
- tm.assert_is_valid_plot_return_object(ret)
- with tm.ensure_clean(return_filelike=True) as path:
- plt.savefig(path)
- except Exception as err:
- raise err
- finally:
- tm.close(fig)
- return ret
- def _gen_default_plot(f, fig, **kwargs):
- """
- Create plot in a default way.
- """
- yield f(**kwargs)
- def _gen_two_subplots(f, fig, **kwargs):
- """
- Create plot on two subplots forcefully created.
- """
- if "ax" not in kwargs:
- fig.add_subplot(211)
- yield f(**kwargs)
- if f is pd.plotting.bootstrap_plot:
- assert "ax" not in kwargs
- else:
- kwargs["ax"] = fig.add_subplot(212)
- yield f(**kwargs)
- def curpath():
- pth, _ = os.path.split(os.path.abspath(__file__))
- return pth
|