test_artist.py 17 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564
  1. import io
  2. from itertools import chain
  3. import numpy as np
  4. import pytest
  5. import matplotlib.colors as mcolors
  6. import matplotlib.pyplot as plt
  7. import matplotlib.patches as mpatches
  8. import matplotlib.lines as mlines
  9. import matplotlib.path as mpath
  10. import matplotlib.transforms as mtransforms
  11. import matplotlib.collections as mcollections
  12. import matplotlib.artist as martist
  13. import matplotlib.backend_bases as mbackend_bases
  14. import matplotlib as mpl
  15. from matplotlib.testing.decorators import check_figures_equal, image_comparison
  16. def test_patch_transform_of_none():
  17. # tests the behaviour of patches added to an Axes with various transform
  18. # specifications
  19. ax = plt.axes()
  20. ax.set_xlim(1, 3)
  21. ax.set_ylim(1, 3)
  22. # Draw an ellipse over data coord (2, 2) by specifying device coords.
  23. xy_data = (2, 2)
  24. xy_pix = ax.transData.transform(xy_data)
  25. # Not providing a transform of None puts the ellipse in data coordinates .
  26. e = mpatches.Ellipse(xy_data, width=1, height=1, fc='yellow', alpha=0.5)
  27. ax.add_patch(e)
  28. assert e._transform == ax.transData
  29. # Providing a transform of None puts the ellipse in device coordinates.
  30. e = mpatches.Ellipse(xy_pix, width=120, height=120, fc='coral',
  31. transform=None, alpha=0.5)
  32. assert e.is_transform_set()
  33. ax.add_patch(e)
  34. assert isinstance(e._transform, mtransforms.IdentityTransform)
  35. # Providing an IdentityTransform puts the ellipse in device coordinates.
  36. e = mpatches.Ellipse(xy_pix, width=100, height=100,
  37. transform=mtransforms.IdentityTransform(), alpha=0.5)
  38. ax.add_patch(e)
  39. assert isinstance(e._transform, mtransforms.IdentityTransform)
  40. # Not providing a transform, and then subsequently "get_transform" should
  41. # not mean that "is_transform_set".
  42. e = mpatches.Ellipse(xy_pix, width=120, height=120, fc='coral',
  43. alpha=0.5)
  44. intermediate_transform = e.get_transform()
  45. assert not e.is_transform_set()
  46. ax.add_patch(e)
  47. assert e.get_transform() != intermediate_transform
  48. assert e.is_transform_set()
  49. assert e._transform == ax.transData
  50. def test_collection_transform_of_none():
  51. # tests the behaviour of collections added to an Axes with various
  52. # transform specifications
  53. ax = plt.axes()
  54. ax.set_xlim(1, 3)
  55. ax.set_ylim(1, 3)
  56. # draw an ellipse over data coord (2, 2) by specifying device coords
  57. xy_data = (2, 2)
  58. xy_pix = ax.transData.transform(xy_data)
  59. # not providing a transform of None puts the ellipse in data coordinates
  60. e = mpatches.Ellipse(xy_data, width=1, height=1)
  61. c = mcollections.PatchCollection([e], facecolor='yellow', alpha=0.5)
  62. ax.add_collection(c)
  63. # the collection should be in data coordinates
  64. assert c.get_offset_transform() + c.get_transform() == ax.transData
  65. # providing a transform of None puts the ellipse in device coordinates
  66. e = mpatches.Ellipse(xy_pix, width=120, height=120)
  67. c = mcollections.PatchCollection([e], facecolor='coral',
  68. alpha=0.5)
  69. c.set_transform(None)
  70. ax.add_collection(c)
  71. assert isinstance(c.get_transform(), mtransforms.IdentityTransform)
  72. # providing an IdentityTransform puts the ellipse in device coordinates
  73. e = mpatches.Ellipse(xy_pix, width=100, height=100)
  74. c = mcollections.PatchCollection([e],
  75. transform=mtransforms.IdentityTransform(),
  76. alpha=0.5)
  77. ax.add_collection(c)
  78. assert isinstance(c.get_offset_transform(), mtransforms.IdentityTransform)
  79. @image_comparison(["clip_path_clipping"], remove_text=True)
  80. def test_clipping():
  81. exterior = mpath.Path.unit_rectangle().deepcopy()
  82. exterior.vertices *= 4
  83. exterior.vertices -= 2
  84. interior = mpath.Path.unit_circle().deepcopy()
  85. interior.vertices = interior.vertices[::-1]
  86. clip_path = mpath.Path.make_compound_path(exterior, interior)
  87. star = mpath.Path.unit_regular_star(6).deepcopy()
  88. star.vertices *= 2.6
  89. fig, (ax1, ax2) = plt.subplots(1, 2, sharex=True, sharey=True)
  90. col = mcollections.PathCollection([star], lw=5, edgecolor='blue',
  91. facecolor='red', alpha=0.7, hatch='*')
  92. col.set_clip_path(clip_path, ax1.transData)
  93. ax1.add_collection(col)
  94. patch = mpatches.PathPatch(star, lw=5, edgecolor='blue', facecolor='red',
  95. alpha=0.7, hatch='*')
  96. patch.set_clip_path(clip_path, ax2.transData)
  97. ax2.add_patch(patch)
  98. ax1.set_xlim([-3, 3])
  99. ax1.set_ylim([-3, 3])
  100. @check_figures_equal(extensions=['png'])
  101. def test_clipping_zoom(fig_test, fig_ref):
  102. # This test places the Axes and sets its limits such that the clip path is
  103. # outside the figure entirely. This should not break the clip path.
  104. ax_test = fig_test.add_axes([0, 0, 1, 1])
  105. l, = ax_test.plot([-3, 3], [-3, 3])
  106. # Explicit Path instead of a Rectangle uses clip path processing, instead
  107. # of a clip box optimization.
  108. p = mpath.Path([[0, 0], [1, 0], [1, 1], [0, 1], [0, 0]])
  109. p = mpatches.PathPatch(p, transform=ax_test.transData)
  110. l.set_clip_path(p)
  111. ax_ref = fig_ref.add_axes([0, 0, 1, 1])
  112. ax_ref.plot([-3, 3], [-3, 3])
  113. ax_ref.set(xlim=(0.5, 0.75), ylim=(0.5, 0.75))
  114. ax_test.set(xlim=(0.5, 0.75), ylim=(0.5, 0.75))
  115. def test_cull_markers():
  116. x = np.random.random(20000)
  117. y = np.random.random(20000)
  118. fig, ax = plt.subplots()
  119. ax.plot(x, y, 'k.')
  120. ax.set_xlim(2, 3)
  121. pdf = io.BytesIO()
  122. fig.savefig(pdf, format="pdf")
  123. assert len(pdf.getvalue()) < 8000
  124. svg = io.BytesIO()
  125. fig.savefig(svg, format="svg")
  126. assert len(svg.getvalue()) < 20000
  127. @image_comparison(['hatching'], remove_text=True, style='default')
  128. def test_hatching():
  129. fig, ax = plt.subplots(1, 1)
  130. # Default hatch color.
  131. rect1 = mpatches.Rectangle((0, 0), 3, 4, hatch='/')
  132. ax.add_patch(rect1)
  133. rect2 = mcollections.RegularPolyCollection(
  134. 4, sizes=[16000], offsets=[(1.5, 6.5)], offset_transform=ax.transData,
  135. hatch='/')
  136. ax.add_collection(rect2)
  137. # Ensure edge color is not applied to hatching.
  138. rect3 = mpatches.Rectangle((4, 0), 3, 4, hatch='/', edgecolor='C1')
  139. ax.add_patch(rect3)
  140. rect4 = mcollections.RegularPolyCollection(
  141. 4, sizes=[16000], offsets=[(5.5, 6.5)], offset_transform=ax.transData,
  142. hatch='/', edgecolor='C1')
  143. ax.add_collection(rect4)
  144. ax.set_xlim(0, 7)
  145. ax.set_ylim(0, 9)
  146. def test_remove():
  147. fig, ax = plt.subplots()
  148. im = ax.imshow(np.arange(36).reshape(6, 6))
  149. ln, = ax.plot(range(5))
  150. assert fig.stale
  151. assert ax.stale
  152. fig.canvas.draw()
  153. assert not fig.stale
  154. assert not ax.stale
  155. assert not ln.stale
  156. assert im in ax._mouseover_set
  157. assert ln not in ax._mouseover_set
  158. assert im.axes is ax
  159. im.remove()
  160. ln.remove()
  161. for art in [im, ln]:
  162. assert art.axes is None
  163. assert art.figure is None
  164. assert im not in ax._mouseover_set
  165. assert fig.stale
  166. assert ax.stale
  167. @image_comparison(["default_edges.png"], remove_text=True, style='default')
  168. def test_default_edges():
  169. # Remove this line when this test image is regenerated.
  170. plt.rcParams['text.kerning_factor'] = 6
  171. fig, [[ax1, ax2], [ax3, ax4]] = plt.subplots(2, 2)
  172. ax1.plot(np.arange(10), np.arange(10), 'x',
  173. np.arange(10) + 1, np.arange(10), 'o')
  174. ax2.bar(np.arange(10), np.arange(10), align='edge')
  175. ax3.text(0, 0, "BOX", size=24, bbox=dict(boxstyle='sawtooth'))
  176. ax3.set_xlim((-1, 1))
  177. ax3.set_ylim((-1, 1))
  178. pp1 = mpatches.PathPatch(
  179. mpath.Path([(0, 0), (1, 0), (1, 1), (0, 0)],
  180. [mpath.Path.MOVETO, mpath.Path.CURVE3,
  181. mpath.Path.CURVE3, mpath.Path.CLOSEPOLY]),
  182. fc="none", transform=ax4.transData)
  183. ax4.add_patch(pp1)
  184. def test_properties():
  185. ln = mlines.Line2D([], [])
  186. ln.properties() # Check that no warning is emitted.
  187. def test_setp():
  188. # Check empty list
  189. plt.setp([])
  190. plt.setp([[]])
  191. # Check arbitrary iterables
  192. fig, ax = plt.subplots()
  193. lines1 = ax.plot(range(3))
  194. lines2 = ax.plot(range(3))
  195. martist.setp(chain(lines1, lines2), 'lw', 5)
  196. plt.setp(ax.spines.values(), color='green')
  197. # Check *file* argument
  198. sio = io.StringIO()
  199. plt.setp(lines1, 'zorder', file=sio)
  200. assert sio.getvalue() == ' zorder: float\n'
  201. def test_None_zorder():
  202. fig, ax = plt.subplots()
  203. ln, = ax.plot(range(5), zorder=None)
  204. assert ln.get_zorder() == mlines.Line2D.zorder
  205. ln.set_zorder(123456)
  206. assert ln.get_zorder() == 123456
  207. ln.set_zorder(None)
  208. assert ln.get_zorder() == mlines.Line2D.zorder
  209. @pytest.mark.parametrize('accept_clause, expected', [
  210. ('', 'unknown'),
  211. ("ACCEPTS: [ '-' | '--' | '-.' ]", "[ '-' | '--' | '-.' ]"),
  212. ('ACCEPTS: Some description.', 'Some description.'),
  213. ('.. ACCEPTS: Some description.', 'Some description.'),
  214. ('arg : int', 'int'),
  215. ('*arg : int', 'int'),
  216. ('arg : int\nACCEPTS: Something else.', 'Something else. '),
  217. ])
  218. def test_artist_inspector_get_valid_values(accept_clause, expected):
  219. class TestArtist(martist.Artist):
  220. def set_f(self, arg):
  221. pass
  222. TestArtist.set_f.__doc__ = """
  223. Some text.
  224. %s
  225. """ % accept_clause
  226. valid_values = martist.ArtistInspector(TestArtist).get_valid_values('f')
  227. assert valid_values == expected
  228. def test_artist_inspector_get_aliases():
  229. # test the correct format and type of get_aliases method
  230. ai = martist.ArtistInspector(mlines.Line2D)
  231. aliases = ai.get_aliases()
  232. assert aliases["linewidth"] == {"lw"}
  233. def test_set_alpha():
  234. art = martist.Artist()
  235. with pytest.raises(TypeError, match='^alpha must be numeric or None'):
  236. art.set_alpha('string')
  237. with pytest.raises(TypeError, match='^alpha must be numeric or None'):
  238. art.set_alpha([1, 2, 3])
  239. with pytest.raises(ValueError, match="outside 0-1 range"):
  240. art.set_alpha(1.1)
  241. with pytest.raises(ValueError, match="outside 0-1 range"):
  242. art.set_alpha(np.nan)
  243. def test_set_alpha_for_array():
  244. art = martist.Artist()
  245. with pytest.raises(TypeError, match='^alpha must be numeric or None'):
  246. art._set_alpha_for_array('string')
  247. with pytest.raises(ValueError, match="outside 0-1 range"):
  248. art._set_alpha_for_array(1.1)
  249. with pytest.raises(ValueError, match="outside 0-1 range"):
  250. art._set_alpha_for_array(np.nan)
  251. with pytest.raises(ValueError, match="alpha must be between 0 and 1"):
  252. art._set_alpha_for_array([0.5, 1.1])
  253. with pytest.raises(ValueError, match="alpha must be between 0 and 1"):
  254. art._set_alpha_for_array([0.5, np.nan])
  255. def test_callbacks():
  256. def func(artist):
  257. func.counter += 1
  258. func.counter = 0
  259. art = martist.Artist()
  260. oid = art.add_callback(func)
  261. assert func.counter == 0
  262. art.pchanged() # must call the callback
  263. assert func.counter == 1
  264. art.set_zorder(10) # setting a property must also call the callback
  265. assert func.counter == 2
  266. art.remove_callback(oid)
  267. art.pchanged() # must not call the callback anymore
  268. assert func.counter == 2
  269. def test_set_signature():
  270. """Test autogenerated ``set()`` for Artist subclasses."""
  271. class MyArtist1(martist.Artist):
  272. def set_myparam1(self, val):
  273. pass
  274. assert hasattr(MyArtist1.set, '_autogenerated_signature')
  275. assert 'myparam1' in MyArtist1.set.__doc__
  276. class MyArtist2(MyArtist1):
  277. def set_myparam2(self, val):
  278. pass
  279. assert hasattr(MyArtist2.set, '_autogenerated_signature')
  280. assert 'myparam1' in MyArtist2.set.__doc__
  281. assert 'myparam2' in MyArtist2.set.__doc__
  282. def test_set_is_overwritten():
  283. """set() defined in Artist subclasses should not be overwritten."""
  284. class MyArtist3(martist.Artist):
  285. def set(self, **kwargs):
  286. """Not overwritten."""
  287. assert not hasattr(MyArtist3.set, '_autogenerated_signature')
  288. assert MyArtist3.set.__doc__ == "Not overwritten."
  289. class MyArtist4(MyArtist3):
  290. pass
  291. assert MyArtist4.set is MyArtist3.set
  292. def test_format_cursor_data_BoundaryNorm():
  293. """Test if cursor data is correct when using BoundaryNorm."""
  294. X = np.empty((3, 3))
  295. X[0, 0] = 0.9
  296. X[0, 1] = 0.99
  297. X[0, 2] = 0.999
  298. X[1, 0] = -1
  299. X[1, 1] = 0
  300. X[1, 2] = 1
  301. X[2, 0] = 0.09
  302. X[2, 1] = 0.009
  303. X[2, 2] = 0.0009
  304. # map range -1..1 to 0..256 in 0.1 steps
  305. fig, ax = plt.subplots()
  306. fig.suptitle("-1..1 to 0..256 in 0.1")
  307. norm = mcolors.BoundaryNorm(np.linspace(-1, 1, 20), 256)
  308. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  309. labels_list = [
  310. "[0.9]",
  311. "[1.]",
  312. "[1.]",
  313. "[-1.0]",
  314. "[0.0]",
  315. "[1.0]",
  316. "[0.09]",
  317. "[0.009]",
  318. "[0.0009]",
  319. ]
  320. for v, label in zip(X.flat, labels_list):
  321. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.1))
  322. assert img.format_cursor_data(v) == label
  323. plt.close()
  324. # map range -1..1 to 0..256 in 0.01 steps
  325. fig, ax = plt.subplots()
  326. fig.suptitle("-1..1 to 0..256 in 0.01")
  327. cmap = mpl.colormaps['RdBu_r'].resampled(200)
  328. norm = mcolors.BoundaryNorm(np.linspace(-1, 1, 200), 200)
  329. img = ax.imshow(X, cmap=cmap, norm=norm)
  330. labels_list = [
  331. "[0.90]",
  332. "[0.99]",
  333. "[1.0]",
  334. "[-1.00]",
  335. "[0.00]",
  336. "[1.00]",
  337. "[0.09]",
  338. "[0.009]",
  339. "[0.0009]",
  340. ]
  341. for v, label in zip(X.flat, labels_list):
  342. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.01))
  343. assert img.format_cursor_data(v) == label
  344. plt.close()
  345. # map range -1..1 to 0..256 in 0.01 steps
  346. fig, ax = plt.subplots()
  347. fig.suptitle("-1..1 to 0..256 in 0.001")
  348. cmap = mpl.colormaps['RdBu_r'].resampled(2000)
  349. norm = mcolors.BoundaryNorm(np.linspace(-1, 1, 2000), 2000)
  350. img = ax.imshow(X, cmap=cmap, norm=norm)
  351. labels_list = [
  352. "[0.900]",
  353. "[0.990]",
  354. "[0.999]",
  355. "[-1.000]",
  356. "[0.000]",
  357. "[1.000]",
  358. "[0.090]",
  359. "[0.009]",
  360. "[0.0009]",
  361. ]
  362. for v, label in zip(X.flat, labels_list):
  363. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.001))
  364. assert img.format_cursor_data(v) == label
  365. plt.close()
  366. # different testing data set with
  367. # out of bounds values for 0..1 range
  368. X = np.empty((7, 1))
  369. X[0] = -1.0
  370. X[1] = 0.0
  371. X[2] = 0.1
  372. X[3] = 0.5
  373. X[4] = 0.9
  374. X[5] = 1.0
  375. X[6] = 2.0
  376. labels_list = [
  377. "[-1.0]",
  378. "[0.0]",
  379. "[0.1]",
  380. "[0.5]",
  381. "[0.9]",
  382. "[1.0]",
  383. "[2.0]",
  384. ]
  385. fig, ax = plt.subplots()
  386. fig.suptitle("noclip, neither")
  387. norm = mcolors.BoundaryNorm(
  388. np.linspace(0, 1, 4, endpoint=True), 256, clip=False, extend='neither')
  389. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  390. for v, label in zip(X.flat, labels_list):
  391. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.33))
  392. assert img.format_cursor_data(v) == label
  393. plt.close()
  394. fig, ax = plt.subplots()
  395. fig.suptitle("noclip, min")
  396. norm = mcolors.BoundaryNorm(
  397. np.linspace(0, 1, 4, endpoint=True), 256, clip=False, extend='min')
  398. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  399. for v, label in zip(X.flat, labels_list):
  400. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.33))
  401. assert img.format_cursor_data(v) == label
  402. plt.close()
  403. fig, ax = plt.subplots()
  404. fig.suptitle("noclip, max")
  405. norm = mcolors.BoundaryNorm(
  406. np.linspace(0, 1, 4, endpoint=True), 256, clip=False, extend='max')
  407. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  408. for v, label in zip(X.flat, labels_list):
  409. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.33))
  410. assert img.format_cursor_data(v) == label
  411. plt.close()
  412. fig, ax = plt.subplots()
  413. fig.suptitle("noclip, both")
  414. norm = mcolors.BoundaryNorm(
  415. np.linspace(0, 1, 4, endpoint=True), 256, clip=False, extend='both')
  416. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  417. for v, label in zip(X.flat, labels_list):
  418. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.33))
  419. assert img.format_cursor_data(v) == label
  420. plt.close()
  421. fig, ax = plt.subplots()
  422. fig.suptitle("clip, neither")
  423. norm = mcolors.BoundaryNorm(
  424. np.linspace(0, 1, 4, endpoint=True), 256, clip=True, extend='neither')
  425. img = ax.imshow(X, cmap='RdBu_r', norm=norm)
  426. for v, label in zip(X.flat, labels_list):
  427. # label = "[{:-#.{}g}]".format(v, cbook._g_sig_digits(v, 0.33))
  428. assert img.format_cursor_data(v) == label
  429. plt.close()
  430. def test_auto_no_rasterize():
  431. class Gen1(martist.Artist):
  432. ...
  433. assert 'draw' in Gen1.__dict__
  434. assert Gen1.__dict__['draw'] is Gen1.draw
  435. class Gen2(Gen1):
  436. ...
  437. assert 'draw' not in Gen2.__dict__
  438. assert Gen2.draw is Gen1.draw
  439. def test_draw_wraper_forward_input():
  440. class TestKlass(martist.Artist):
  441. def draw(self, renderer, extra):
  442. return extra
  443. art = TestKlass()
  444. renderer = mbackend_bases.RendererBase()
  445. assert 'aardvark' == art.draw(renderer, 'aardvark')
  446. assert 'aardvark' == art.draw(renderer, extra='aardvark')