test_figure.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489
  1. from datetime import datetime
  2. from pathlib import Path
  3. import platform
  4. from matplotlib import rcParams
  5. from matplotlib.testing.decorators import image_comparison, check_figures_equal
  6. from matplotlib.axes import Axes
  7. from matplotlib.ticker import AutoMinorLocator, FixedFormatter, ScalarFormatter
  8. import matplotlib.pyplot as plt
  9. import matplotlib.dates as mdates
  10. import matplotlib.gridspec as gridspec
  11. import numpy as np
  12. import pytest
  13. @image_comparison(['figure_align_labels'],
  14. tol={'aarch64': 0.02}.get(platform.machine(), 0.0))
  15. def test_align_labels():
  16. # Check the figure.align_labels() command
  17. fig = plt.figure(tight_layout=True)
  18. gs = gridspec.GridSpec(3, 3)
  19. ax = fig.add_subplot(gs[0, :2])
  20. ax.plot(np.arange(0, 1e6, 1000))
  21. ax.set_ylabel('Ylabel0 0')
  22. ax = fig.add_subplot(gs[0, -1])
  23. ax.plot(np.arange(0, 1e4, 100))
  24. for i in range(3):
  25. ax = fig.add_subplot(gs[1, i])
  26. ax.set_ylabel('YLabel1 %d' % i)
  27. ax.set_xlabel('XLabel1 %d' % i)
  28. if i in [0, 2]:
  29. ax.xaxis.set_label_position("top")
  30. ax.xaxis.tick_top()
  31. if i == 0:
  32. for tick in ax.get_xticklabels():
  33. tick.set_rotation(90)
  34. if i == 2:
  35. ax.yaxis.set_label_position("right")
  36. ax.yaxis.tick_right()
  37. for i in range(3):
  38. ax = fig.add_subplot(gs[2, i])
  39. ax.set_xlabel('XLabel2 %d' % (i))
  40. ax.set_ylabel('YLabel2 %d' % (i))
  41. if i == 2:
  42. ax.plot(np.arange(0, 1e4, 10))
  43. ax.yaxis.set_label_position("right")
  44. ax.yaxis.tick_right()
  45. for tick in ax.get_xticklabels():
  46. tick.set_rotation(90)
  47. fig.align_labels()
  48. def test_figure_label():
  49. # pyplot figure creation, selection and closing with figure label and
  50. # number
  51. plt.close('all')
  52. plt.figure('today')
  53. plt.figure(3)
  54. plt.figure('tomorrow')
  55. plt.figure()
  56. plt.figure(0)
  57. plt.figure(1)
  58. plt.figure(3)
  59. assert plt.get_fignums() == [0, 1, 3, 4, 5]
  60. assert plt.get_figlabels() == ['', 'today', '', 'tomorrow', '']
  61. plt.close(10)
  62. plt.close()
  63. plt.close(5)
  64. plt.close('tomorrow')
  65. assert plt.get_fignums() == [0, 1]
  66. assert plt.get_figlabels() == ['', 'today']
  67. def test_fignum_exists():
  68. # pyplot figure creation, selection and closing with fignum_exists
  69. plt.figure('one')
  70. plt.figure(2)
  71. plt.figure('three')
  72. plt.figure()
  73. assert plt.fignum_exists('one')
  74. assert plt.fignum_exists(2)
  75. assert plt.fignum_exists('three')
  76. assert plt.fignum_exists(4)
  77. plt.close('one')
  78. plt.close(4)
  79. assert not plt.fignum_exists('one')
  80. assert not plt.fignum_exists(4)
  81. def test_clf_keyword():
  82. # test if existing figure is cleared with figure() and subplots()
  83. text1 = 'A fancy plot'
  84. text2 = 'Really fancy!'
  85. fig0 = plt.figure(num=1)
  86. fig0.suptitle(text1)
  87. assert [t.get_text() for t in fig0.texts] == [text1]
  88. fig1 = plt.figure(num=1, clear=False)
  89. fig1.text(0.5, 0.5, text2)
  90. assert fig0 is fig1
  91. assert [t.get_text() for t in fig1.texts] == [text1, text2]
  92. fig2, ax2 = plt.subplots(2, 1, num=1, clear=True)
  93. assert fig0 is fig2
  94. assert [t.get_text() for t in fig2.texts] == []
  95. @image_comparison(['figure_today'])
  96. def test_figure():
  97. # named figure support
  98. fig = plt.figure('today')
  99. ax = fig.add_subplot()
  100. ax.set_title(fig.get_label())
  101. ax.plot(np.arange(5))
  102. # plot red line in a different figure.
  103. plt.figure('tomorrow')
  104. plt.plot([0, 1], [1, 0], 'r')
  105. # Return to the original; make sure the red line is not there.
  106. plt.figure('today')
  107. plt.close('tomorrow')
  108. @image_comparison(['figure_legend'])
  109. def test_figure_legend():
  110. fig, axs = plt.subplots(2)
  111. axs[0].plot([0, 1], [1, 0], label='x', color='g')
  112. axs[0].plot([0, 1], [0, 1], label='y', color='r')
  113. axs[0].plot([0, 1], [0.5, 0.5], label='y', color='k')
  114. axs[1].plot([0, 1], [1, 0], label='_y', color='r')
  115. axs[1].plot([0, 1], [0, 1], label='z', color='b')
  116. fig.legend()
  117. def test_gca():
  118. fig = plt.figure()
  119. ax1 = fig.add_axes([0, 0, 1, 1])
  120. assert fig.gca(projection='rectilinear') is ax1
  121. assert fig.gca() is ax1
  122. ax2 = fig.add_subplot(121, projection='polar')
  123. assert fig.gca() is ax2
  124. assert fig.gca(polar=True) is ax2
  125. ax3 = fig.add_subplot(122)
  126. assert fig.gca() is ax3
  127. # the final request for a polar axes will end up creating one
  128. # with a spec of 111.
  129. with pytest.warns(UserWarning):
  130. # Changing the projection will throw a warning
  131. assert fig.gca(polar=True) is not ax3
  132. assert fig.gca(polar=True) is not ax2
  133. assert fig.gca().get_geometry() == (1, 1, 1)
  134. fig.sca(ax1)
  135. assert fig.gca(projection='rectilinear') is ax1
  136. assert fig.gca() is ax1
  137. def test_add_subplot_invalid():
  138. fig = plt.figure()
  139. with pytest.raises(ValueError):
  140. fig.add_subplot(2, 0, 1)
  141. with pytest.raises(ValueError):
  142. fig.add_subplot(0, 2, 1)
  143. with pytest.raises(ValueError):
  144. fig.add_subplot(2, 2, 0)
  145. with pytest.raises(ValueError):
  146. fig.add_subplot(2, 2, 5)
  147. @image_comparison(['figure_suptitle'])
  148. def test_suptitle():
  149. fig, _ = plt.subplots()
  150. fig.suptitle('hello', color='r')
  151. fig.suptitle('title', color='g', rotation='30')
  152. def test_suptitle_fontproperties():
  153. from matplotlib.font_manager import FontProperties
  154. fig, ax = plt.subplots()
  155. fps = FontProperties(size='large', weight='bold')
  156. txt = fig.suptitle('fontprops title', fontproperties=fps)
  157. assert txt.get_fontsize() == fps.get_size_in_points()
  158. assert txt.get_weight() == fps.get_weight()
  159. @image_comparison(['alpha_background'],
  160. # only test png and svg. The PDF output appears correct,
  161. # but Ghostscript does not preserve the background color.
  162. extensions=['png', 'svg'],
  163. savefig_kwarg={'facecolor': (0, 1, 0.4),
  164. 'edgecolor': 'none'})
  165. def test_alpha():
  166. # We want an image which has a background color and an
  167. # alpha of 0.4.
  168. fig = plt.figure(figsize=[2, 1])
  169. fig.set_facecolor((0, 1, 0.4))
  170. fig.patch.set_alpha(0.4)
  171. import matplotlib.patches as mpatches
  172. fig.patches.append(mpatches.CirclePolygon([20, 20],
  173. radius=15,
  174. alpha=0.6,
  175. facecolor='red'))
  176. def test_too_many_figures():
  177. with pytest.warns(RuntimeWarning):
  178. for i in range(rcParams['figure.max_open_warning'] + 1):
  179. plt.figure()
  180. def test_iterability_axes_argument():
  181. # This is a regression test for matplotlib/matplotlib#3196. If one of the
  182. # arguments returned by _as_mpl_axes defines __getitem__ but is not
  183. # iterable, this would raise an exception. This is because we check
  184. # whether the arguments are iterable, and if so we try and convert them
  185. # to a tuple. However, the ``iterable`` function returns True if
  186. # __getitem__ is present, but some classes can define __getitem__ without
  187. # being iterable. The tuple conversion is now done in a try...except in
  188. # case it fails.
  189. class MyAxes(Axes):
  190. def __init__(self, *args, myclass=None, **kwargs):
  191. return Axes.__init__(self, *args, **kwargs)
  192. class MyClass:
  193. def __getitem__(self, item):
  194. if item != 'a':
  195. raise ValueError("item should be a")
  196. def _as_mpl_axes(self):
  197. return MyAxes, {'myclass': self}
  198. fig = plt.figure()
  199. fig.add_subplot(1, 1, 1, projection=MyClass())
  200. plt.close(fig)
  201. def test_set_fig_size():
  202. fig = plt.figure()
  203. # check figwidth
  204. fig.set_figwidth(5)
  205. assert fig.get_figwidth() == 5
  206. # check figheight
  207. fig.set_figheight(1)
  208. assert fig.get_figheight() == 1
  209. # check using set_size_inches
  210. fig.set_size_inches(2, 4)
  211. assert fig.get_figwidth() == 2
  212. assert fig.get_figheight() == 4
  213. # check using tuple to first argument
  214. fig.set_size_inches((1, 3))
  215. assert fig.get_figwidth() == 1
  216. assert fig.get_figheight() == 3
  217. def test_axes_remove():
  218. fig, axs = plt.subplots(2, 2)
  219. axs[-1, -1].remove()
  220. for ax in axs.ravel()[:-1]:
  221. assert ax in fig.axes
  222. assert axs[-1, -1] not in fig.axes
  223. assert len(fig.axes) == 3
  224. def test_figaspect():
  225. w, h = plt.figaspect(np.float64(2) / np.float64(1))
  226. assert h / w == 2
  227. w, h = plt.figaspect(2)
  228. assert h / w == 2
  229. w, h = plt.figaspect(np.zeros((1, 2)))
  230. assert h / w == 0.5
  231. w, h = plt.figaspect(np.zeros((2, 2)))
  232. assert h / w == 1
  233. @pytest.mark.parametrize('which', [None, 'both', 'major', 'minor'])
  234. def test_autofmt_xdate(which):
  235. date = ['3 Jan 2013', '4 Jan 2013', '5 Jan 2013', '6 Jan 2013',
  236. '7 Jan 2013', '8 Jan 2013', '9 Jan 2013', '10 Jan 2013',
  237. '11 Jan 2013', '12 Jan 2013', '13 Jan 2013', '14 Jan 2013']
  238. time = ['16:44:00', '16:45:00', '16:46:00', '16:47:00', '16:48:00',
  239. '16:49:00', '16:51:00', '16:52:00', '16:53:00', '16:55:00',
  240. '16:56:00', '16:57:00']
  241. angle = 60
  242. minors = [1, 2, 3, 4, 5, 6, 7]
  243. x = mdates.datestr2num(date)
  244. y = mdates.datestr2num(time)
  245. fig, ax = plt.subplots()
  246. ax.plot(x, y)
  247. ax.yaxis_date()
  248. ax.xaxis_date()
  249. ax.xaxis.set_minor_locator(AutoMinorLocator(2))
  250. ax.xaxis.set_minor_formatter(FixedFormatter(minors))
  251. fig.autofmt_xdate(0.2, angle, 'right', which)
  252. if which in ('both', 'major', None):
  253. for label in fig.axes[0].get_xticklabels(False, 'major'):
  254. assert int(label.get_rotation()) == angle
  255. if which in ('both', 'minor'):
  256. for label in fig.axes[0].get_xticklabels(True, 'minor'):
  257. assert int(label.get_rotation()) == angle
  258. @pytest.mark.style('default')
  259. def test_change_dpi():
  260. fig = plt.figure(figsize=(4, 4))
  261. fig.canvas.draw()
  262. assert fig.canvas.renderer.height == 400
  263. assert fig.canvas.renderer.width == 400
  264. fig.dpi = 50
  265. fig.canvas.draw()
  266. assert fig.canvas.renderer.height == 200
  267. assert fig.canvas.renderer.width == 200
  268. @pytest.mark.parametrize('width, height', [
  269. (1, np.nan),
  270. (0, 1),
  271. (-1, 1),
  272. (np.inf, 1)
  273. ])
  274. def test_invalid_figure_size(width, height):
  275. with pytest.raises(ValueError):
  276. plt.figure(figsize=(width, height))
  277. fig = plt.figure()
  278. with pytest.raises(ValueError):
  279. fig.set_size_inches(width, height)
  280. def test_invalid_figure_add_axes():
  281. fig = plt.figure()
  282. with pytest.raises(ValueError):
  283. fig.add_axes((.1, .1, .5, np.nan))
  284. def test_subplots_shareax_loglabels():
  285. fig, axs = plt.subplots(2, 2, sharex=True, sharey=True, squeeze=False)
  286. for ax in axs.flat:
  287. ax.plot([10, 20, 30], [10, 20, 30])
  288. ax.set_yscale("log")
  289. ax.set_xscale("log")
  290. for ax in axs[0, :]:
  291. assert 0 == len(ax.xaxis.get_ticklabels(which='both'))
  292. for ax in axs[1, :]:
  293. assert 0 < len(ax.xaxis.get_ticklabels(which='both'))
  294. for ax in axs[:, 1]:
  295. assert 0 == len(ax.yaxis.get_ticklabels(which='both'))
  296. for ax in axs[:, 0]:
  297. assert 0 < len(ax.yaxis.get_ticklabels(which='both'))
  298. def test_savefig():
  299. fig = plt.figure()
  300. msg = r"savefig\(\) takes 2 positional arguments but 3 were given"
  301. with pytest.raises(TypeError, match=msg):
  302. fig.savefig("fname1.png", "fname2.png")
  303. def test_figure_repr():
  304. fig = plt.figure(figsize=(10, 20), dpi=10)
  305. assert repr(fig) == "<Figure size 100x200 with 0 Axes>"
  306. def test_warn_cl_plus_tl():
  307. fig, ax = plt.subplots(constrained_layout=True)
  308. with pytest.warns(UserWarning):
  309. # this should warn,
  310. fig.subplots_adjust(top=0.8)
  311. assert not(fig.get_constrained_layout())
  312. @check_figures_equal(extensions=["png", "pdf"])
  313. def test_add_artist(fig_test, fig_ref):
  314. fig_test.set_dpi(100)
  315. fig_ref.set_dpi(100)
  316. fig_test.subplots()
  317. l1 = plt.Line2D([.2, .7], [.7, .7], gid='l1')
  318. l2 = plt.Line2D([.2, .7], [.8, .8], gid='l2')
  319. r1 = plt.Circle((20, 20), 100, transform=None, gid='C1')
  320. r2 = plt.Circle((.7, .5), .05, gid='C2')
  321. r3 = plt.Circle((4.5, .8), .55, transform=fig_test.dpi_scale_trans,
  322. facecolor='crimson', gid='C3')
  323. for a in [l1, l2, r1, r2, r3]:
  324. fig_test.add_artist(a)
  325. l2.remove()
  326. ax2 = fig_ref.subplots()
  327. l1 = plt.Line2D([.2, .7], [.7, .7], transform=fig_ref.transFigure,
  328. gid='l1', zorder=21)
  329. r1 = plt.Circle((20, 20), 100, transform=None, clip_on=False, zorder=20,
  330. gid='C1')
  331. r2 = plt.Circle((.7, .5), .05, transform=fig_ref.transFigure, gid='C2',
  332. zorder=20)
  333. r3 = plt.Circle((4.5, .8), .55, transform=fig_ref.dpi_scale_trans,
  334. facecolor='crimson', clip_on=False, zorder=20, gid='C3')
  335. for a in [l1, r1, r2, r3]:
  336. ax2.add_artist(a)
  337. @pytest.mark.parametrize("fmt", ["png", "pdf", "ps", "eps", "svg"])
  338. def test_fspath(fmt, tmpdir):
  339. out = Path(tmpdir, "test.{}".format(fmt))
  340. plt.savefig(out)
  341. with out.open("rb") as file:
  342. # All the supported formats include the format name (case-insensitive)
  343. # in the first 100 bytes.
  344. assert fmt.encode("ascii") in file.read(100).lower()
  345. def test_tightbbox():
  346. fig, ax = plt.subplots()
  347. ax.set_xlim(0, 1)
  348. t = ax.text(1., 0.5, 'This dangles over end')
  349. renderer = fig.canvas.get_renderer()
  350. x1Nom0 = 9.035 # inches
  351. assert abs(t.get_tightbbox(renderer).x1 - x1Nom0 * fig.dpi) < 2
  352. assert abs(ax.get_tightbbox(renderer).x1 - x1Nom0 * fig.dpi) < 2
  353. assert abs(fig.get_tightbbox(renderer).x1 - x1Nom0) < 0.05
  354. assert abs(fig.get_tightbbox(renderer).x0 - 0.679) < 0.05
  355. # now exclude t from the tight bbox so now the bbox is quite a bit
  356. # smaller
  357. t.set_in_layout(False)
  358. x1Nom = 7.333
  359. assert abs(ax.get_tightbbox(renderer).x1 - x1Nom * fig.dpi) < 2
  360. assert abs(fig.get_tightbbox(renderer).x1 - x1Nom) < 0.05
  361. t.set_in_layout(True)
  362. x1Nom = 7.333
  363. assert abs(ax.get_tightbbox(renderer).x1 - x1Nom0 * fig.dpi) < 2
  364. # test bbox_extra_artists method...
  365. assert abs(ax.get_tightbbox(renderer, bbox_extra_artists=[]).x1
  366. - x1Nom * fig.dpi) < 2
  367. def test_axes_removal():
  368. # Check that units can set the formatter after an Axes removal
  369. fig, axs = plt.subplots(1, 2, sharex=True)
  370. axs[1].remove()
  371. axs[0].plot([datetime(2000, 1, 1), datetime(2000, 2, 1)], [0, 1])
  372. assert isinstance(axs[0].xaxis.get_major_formatter(),
  373. mdates.AutoDateFormatter)
  374. # Check that manually setting the formatter, then removing Axes keeps
  375. # the set formatter.
  376. fig, axs = plt.subplots(1, 2, sharex=True)
  377. axs[1].xaxis.set_major_formatter(ScalarFormatter())
  378. axs[1].remove()
  379. axs[0].plot([datetime(2000, 1, 1), datetime(2000, 2, 1)], [0, 1])
  380. assert isinstance(axs[0].xaxis.get_major_formatter(),
  381. ScalarFormatter)
  382. def test_removed_axis():
  383. # Simple smoke test to make sure removing a shared axis works
  384. fig, axs = plt.subplots(2, sharex=True)
  385. axs[0].remove()
  386. fig.canvas.draw()