test_sankey.py 3.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105
  1. import pytest
  2. from numpy.testing import assert_allclose, assert_array_equal
  3. from matplotlib.sankey import Sankey
  4. from matplotlib.testing.decorators import check_figures_equal
  5. def test_sankey():
  6. # lets just create a sankey instance and check the code runs
  7. sankey = Sankey()
  8. sankey.add()
  9. def test_label():
  10. s = Sankey(flows=[0.25], labels=['First'], orientations=[-1])
  11. assert s.diagrams[0].texts[0].get_text() == 'First\n0.25'
  12. def test_format_using_callable():
  13. # test using callable by slightly incrementing above label example
  14. def show_three_decimal_places(value):
  15. return f'{value:.3f}'
  16. s = Sankey(flows=[0.25], labels=['First'], orientations=[-1],
  17. format=show_three_decimal_places)
  18. assert s.diagrams[0].texts[0].get_text() == 'First\n0.250'
  19. @pytest.mark.parametrize('kwargs, msg', (
  20. ({'gap': -1}, "'gap' is negative"),
  21. ({'gap': 1, 'radius': 2}, "'radius' is greater than 'gap'"),
  22. ({'head_angle': -1}, "'head_angle' is negative"),
  23. ({'tolerance': -1}, "'tolerance' is negative"),
  24. ({'flows': [1, -1], 'orientations': [-1, 0, 1]},
  25. r"The shapes of 'flows' \(2,\) and 'orientations'"),
  26. ({'flows': [1, -1], 'labels': ['a', 'b', 'c']},
  27. r"The shapes of 'flows' \(2,\) and 'labels'"),
  28. ))
  29. def test_sankey_errors(kwargs, msg):
  30. with pytest.raises(ValueError, match=msg):
  31. Sankey(**kwargs)
  32. @pytest.mark.parametrize('kwargs, msg', (
  33. ({'trunklength': -1}, "'trunklength' is negative"),
  34. ({'flows': [0.2, 0.3], 'prior': 0}, "The scaled sum of the connected"),
  35. ({'prior': -1}, "The index of the prior diagram is negative"),
  36. ({'prior': 1}, "The index of the prior diagram is 1"),
  37. ({'connect': (-1, 1), 'prior': 0}, "At least one of the connection"),
  38. ({'connect': (2, 1), 'prior': 0}, "The connection index to the source"),
  39. ({'connect': (1, 3), 'prior': 0}, "The connection index to this dia"),
  40. ({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
  41. 'orientations': [2]}, "The value of orientations"),
  42. ({'connect': (1, 1), 'prior': 0, 'flows': [-0.2, 0.2],
  43. 'pathlengths': [2]}, "The lengths of 'flows'"),
  44. ))
  45. def test_sankey_add_errors(kwargs, msg):
  46. sankey = Sankey()
  47. with pytest.raises(ValueError, match=msg):
  48. sankey.add(flows=[0.2, -0.2])
  49. sankey.add(**kwargs)
  50. def test_sankey2():
  51. s = Sankey(flows=[0.25, -0.25, 0.5, -0.5], labels=['Foo'],
  52. orientations=[-1], unit='Bar')
  53. sf = s.finish()
  54. assert_array_equal(sf[0].flows, [0.25, -0.25, 0.5, -0.5])
  55. assert sf[0].angles == [1, 3, 1, 3]
  56. assert all([text.get_text()[0:3] == 'Foo' for text in sf[0].texts])
  57. assert all([text.get_text()[-3:] == 'Bar' for text in sf[0].texts])
  58. assert sf[0].text.get_text() == ''
  59. assert_allclose(sf[0].tips,
  60. [(-1.375, -0.52011255),
  61. (1.375, -0.75506044),
  62. (-0.75, -0.41522509),
  63. (0.75, -0.8599479)])
  64. s = Sankey(flows=[0.25, -0.25, 0, 0.5, -0.5], labels=['Foo'],
  65. orientations=[-1], unit='Bar')
  66. sf = s.finish()
  67. assert_array_equal(sf[0].flows, [0.25, -0.25, 0, 0.5, -0.5])
  68. assert sf[0].angles == [1, 3, None, 1, 3]
  69. assert_allclose(sf[0].tips,
  70. [(-1.375, -0.52011255),
  71. (1.375, -0.75506044),
  72. (0, 0),
  73. (-0.75, -0.41522509),
  74. (0.75, -0.8599479)])
  75. @check_figures_equal(extensions=['png'])
  76. def test_sankey3(fig_test, fig_ref):
  77. ax_test = fig_test.gca()
  78. s_test = Sankey(ax=ax_test, flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
  79. orientations=[1, -1, 1, -1, 0, 0])
  80. s_test.finish()
  81. ax_ref = fig_ref.gca()
  82. s_ref = Sankey(ax=ax_ref)
  83. s_ref.add(flows=[0.25, -0.25, -0.25, 0.25, 0.5, -0.5],
  84. orientations=[1, -1, 1, -1, 0, 0])
  85. s_ref.finish()