sankey.py 35 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775776777778779780781782783784785786787788789790791792793794795796797798799800801802803804805806807808809810811812813814
  1. """
  2. Module for creating Sankey diagrams using Matplotlib.
  3. """
  4. import logging
  5. from types import SimpleNamespace
  6. import numpy as np
  7. import matplotlib as mpl
  8. from matplotlib.path import Path
  9. from matplotlib.patches import PathPatch
  10. from matplotlib.transforms import Affine2D
  11. from matplotlib import _docstring
  12. _log = logging.getLogger(__name__)
  13. __author__ = "Kevin L. Davies"
  14. __credits__ = ["Yannick Copin"]
  15. __license__ = "BSD"
  16. __version__ = "2011/09/16"
  17. # Angles [deg/90]
  18. RIGHT = 0
  19. UP = 1
  20. # LEFT = 2
  21. DOWN = 3
  22. class Sankey:
  23. """
  24. Sankey diagram.
  25. Sankey diagrams are a specific type of flow diagram, in which
  26. the width of the arrows is shown proportionally to the flow
  27. quantity. They are typically used to visualize energy or
  28. material or cost transfers between processes.
  29. `Wikipedia (6/1/2011) <https://en.wikipedia.org/wiki/Sankey_diagram>`_
  30. """
  31. def __init__(self, ax=None, scale=1.0, unit='', format='%G', gap=0.25,
  32. radius=0.1, shoulder=0.03, offset=0.15, head_angle=100,
  33. margin=0.4, tolerance=1e-6, **kwargs):
  34. """
  35. Create a new Sankey instance.
  36. The optional arguments listed below are applied to all subdiagrams so
  37. that there is consistent alignment and formatting.
  38. In order to draw a complex Sankey diagram, create an instance of
  39. `Sankey` by calling it without any kwargs::
  40. sankey = Sankey()
  41. Then add simple Sankey sub-diagrams::
  42. sankey.add() # 1
  43. sankey.add() # 2
  44. #...
  45. sankey.add() # n
  46. Finally, create the full diagram::
  47. sankey.finish()
  48. Or, instead, simply daisy-chain those calls::
  49. Sankey().add().add... .add().finish()
  50. Other Parameters
  51. ----------------
  52. ax : `~matplotlib.axes.Axes`
  53. Axes onto which the data should be plotted. If *ax* isn't
  54. provided, new Axes will be created.
  55. scale : float
  56. Scaling factor for the flows. *scale* sizes the width of the paths
  57. in order to maintain proper layout. The same scale is applied to
  58. all subdiagrams. The value should be chosen such that the product
  59. of the scale and the sum of the inputs is approximately 1.0 (and
  60. the product of the scale and the sum of the outputs is
  61. approximately -1.0).
  62. unit : str
  63. The physical unit associated with the flow quantities. If *unit*
  64. is None, then none of the quantities are labeled.
  65. format : str or callable
  66. A Python number formatting string or callable used to label the
  67. flows with their quantities (i.e., a number times a unit, where the
  68. unit is given). If a format string is given, the label will be
  69. ``format % quantity``. If a callable is given, it will be called
  70. with ``quantity`` as an argument.
  71. gap : float
  72. Space between paths that break in/break away to/from the top or
  73. bottom.
  74. radius : float
  75. Inner radius of the vertical paths.
  76. shoulder : float
  77. Size of the shoulders of output arrows.
  78. offset : float
  79. Text offset (from the dip or tip of the arrow).
  80. head_angle : float
  81. Angle, in degrees, of the arrow heads (and negative of the angle of
  82. the tails).
  83. margin : float
  84. Minimum space between Sankey outlines and the edge of the plot
  85. area.
  86. tolerance : float
  87. Acceptable maximum of the magnitude of the sum of flows. The
  88. magnitude of the sum of connected flows cannot be greater than
  89. *tolerance*.
  90. **kwargs
  91. Any additional keyword arguments will be passed to `add`, which
  92. will create the first subdiagram.
  93. See Also
  94. --------
  95. Sankey.add
  96. Sankey.finish
  97. Examples
  98. --------
  99. .. plot:: gallery/specialty_plots/sankey_basics.py
  100. """
  101. # Check the arguments.
  102. if gap < 0:
  103. raise ValueError(
  104. "'gap' is negative, which is not allowed because it would "
  105. "cause the paths to overlap")
  106. if radius > gap:
  107. raise ValueError(
  108. "'radius' is greater than 'gap', which is not allowed because "
  109. "it would cause the paths to overlap")
  110. if head_angle < 0:
  111. raise ValueError(
  112. "'head_angle' is negative, which is not allowed because it "
  113. "would cause inputs to look like outputs and vice versa")
  114. if tolerance < 0:
  115. raise ValueError(
  116. "'tolerance' is negative, but it must be a magnitude")
  117. # Create axes if necessary.
  118. if ax is None:
  119. import matplotlib.pyplot as plt
  120. fig = plt.figure()
  121. ax = fig.add_subplot(1, 1, 1, xticks=[], yticks=[])
  122. self.diagrams = []
  123. # Store the inputs.
  124. self.ax = ax
  125. self.unit = unit
  126. self.format = format
  127. self.scale = scale
  128. self.gap = gap
  129. self.radius = radius
  130. self.shoulder = shoulder
  131. self.offset = offset
  132. self.margin = margin
  133. self.pitch = np.tan(np.pi * (1 - head_angle / 180.0) / 2.0)
  134. self.tolerance = tolerance
  135. # Initialize the vertices of tight box around the diagram(s).
  136. self.extent = np.array((np.inf, -np.inf, np.inf, -np.inf))
  137. # If there are any kwargs, create the first subdiagram.
  138. if len(kwargs):
  139. self.add(**kwargs)
  140. def _arc(self, quadrant=0, cw=True, radius=1, center=(0, 0)):
  141. """
  142. Return the codes and vertices for a rotated, scaled, and translated
  143. 90 degree arc.
  144. Other Parameters
  145. ----------------
  146. quadrant : {0, 1, 2, 3}, default: 0
  147. Uses 0-based indexing (0, 1, 2, or 3).
  148. cw : bool, default: True
  149. If True, the arc vertices are produced clockwise; counter-clockwise
  150. otherwise.
  151. radius : float, default: 1
  152. The radius of the arc.
  153. center : (float, float), default: (0, 0)
  154. (x, y) tuple of the arc's center.
  155. """
  156. # Note: It would be possible to use matplotlib's transforms to rotate,
  157. # scale, and translate the arc, but since the angles are discrete,
  158. # it's just as easy and maybe more efficient to do it here.
  159. ARC_CODES = [Path.LINETO,
  160. Path.CURVE4,
  161. Path.CURVE4,
  162. Path.CURVE4,
  163. Path.CURVE4,
  164. Path.CURVE4,
  165. Path.CURVE4]
  166. # Vertices of a cubic Bezier curve approximating a 90 deg arc
  167. # These can be determined by Path.arc(0, 90).
  168. ARC_VERTICES = np.array([[1.00000000e+00, 0.00000000e+00],
  169. [1.00000000e+00, 2.65114773e-01],
  170. [8.94571235e-01, 5.19642327e-01],
  171. [7.07106781e-01, 7.07106781e-01],
  172. [5.19642327e-01, 8.94571235e-01],
  173. [2.65114773e-01, 1.00000000e+00],
  174. # Insignificant
  175. # [6.12303177e-17, 1.00000000e+00]])
  176. [0.00000000e+00, 1.00000000e+00]])
  177. if quadrant in (0, 2):
  178. if cw:
  179. vertices = ARC_VERTICES
  180. else:
  181. vertices = ARC_VERTICES[:, ::-1] # Swap x and y.
  182. else: # 1, 3
  183. # Negate x.
  184. if cw:
  185. # Swap x and y.
  186. vertices = np.column_stack((-ARC_VERTICES[:, 1],
  187. ARC_VERTICES[:, 0]))
  188. else:
  189. vertices = np.column_stack((-ARC_VERTICES[:, 0],
  190. ARC_VERTICES[:, 1]))
  191. if quadrant > 1:
  192. radius = -radius # Rotate 180 deg.
  193. return list(zip(ARC_CODES, radius * vertices +
  194. np.tile(center, (ARC_VERTICES.shape[0], 1))))
  195. def _add_input(self, path, angle, flow, length):
  196. """
  197. Add an input to a path and return its tip and label locations.
  198. """
  199. if angle is None:
  200. return [0, 0], [0, 0]
  201. else:
  202. x, y = path[-1][1] # Use the last point as a reference.
  203. dipdepth = (flow / 2) * self.pitch
  204. if angle == RIGHT:
  205. x -= length
  206. dip = [x + dipdepth, y + flow / 2.0]
  207. path.extend([(Path.LINETO, [x, y]),
  208. (Path.LINETO, dip),
  209. (Path.LINETO, [x, y + flow]),
  210. (Path.LINETO, [x + self.gap, y + flow])])
  211. label_location = [dip[0] - self.offset, dip[1]]
  212. else: # Vertical
  213. x -= self.gap
  214. if angle == UP:
  215. sign = 1
  216. else:
  217. sign = -1
  218. dip = [x - flow / 2, y - sign * (length - dipdepth)]
  219. if angle == DOWN:
  220. quadrant = 2
  221. else:
  222. quadrant = 1
  223. # Inner arc isn't needed if inner radius is zero
  224. if self.radius:
  225. path.extend(self._arc(quadrant=quadrant,
  226. cw=angle == UP,
  227. radius=self.radius,
  228. center=(x + self.radius,
  229. y - sign * self.radius)))
  230. else:
  231. path.append((Path.LINETO, [x, y]))
  232. path.extend([(Path.LINETO, [x, y - sign * length]),
  233. (Path.LINETO, dip),
  234. (Path.LINETO, [x - flow, y - sign * length])])
  235. path.extend(self._arc(quadrant=quadrant,
  236. cw=angle == DOWN,
  237. radius=flow + self.radius,
  238. center=(x + self.radius,
  239. y - sign * self.radius)))
  240. path.append((Path.LINETO, [x - flow, y + sign * flow]))
  241. label_location = [dip[0], dip[1] - sign * self.offset]
  242. return dip, label_location
  243. def _add_output(self, path, angle, flow, length):
  244. """
  245. Append an output to a path and return its tip and label locations.
  246. .. note:: *flow* is negative for an output.
  247. """
  248. if angle is None:
  249. return [0, 0], [0, 0]
  250. else:
  251. x, y = path[-1][1] # Use the last point as a reference.
  252. tipheight = (self.shoulder - flow / 2) * self.pitch
  253. if angle == RIGHT:
  254. x += length
  255. tip = [x + tipheight, y + flow / 2.0]
  256. path.extend([(Path.LINETO, [x, y]),
  257. (Path.LINETO, [x, y + self.shoulder]),
  258. (Path.LINETO, tip),
  259. (Path.LINETO, [x, y - self.shoulder + flow]),
  260. (Path.LINETO, [x, y + flow]),
  261. (Path.LINETO, [x - self.gap, y + flow])])
  262. label_location = [tip[0] + self.offset, tip[1]]
  263. else: # Vertical
  264. x += self.gap
  265. if angle == UP:
  266. sign, quadrant = 1, 3
  267. else:
  268. sign, quadrant = -1, 0
  269. tip = [x - flow / 2.0, y + sign * (length + tipheight)]
  270. # Inner arc isn't needed if inner radius is zero
  271. if self.radius:
  272. path.extend(self._arc(quadrant=quadrant,
  273. cw=angle == UP,
  274. radius=self.radius,
  275. center=(x - self.radius,
  276. y + sign * self.radius)))
  277. else:
  278. path.append((Path.LINETO, [x, y]))
  279. path.extend([(Path.LINETO, [x, y + sign * length]),
  280. (Path.LINETO, [x - self.shoulder,
  281. y + sign * length]),
  282. (Path.LINETO, tip),
  283. (Path.LINETO, [x + self.shoulder - flow,
  284. y + sign * length]),
  285. (Path.LINETO, [x - flow, y + sign * length])])
  286. path.extend(self._arc(quadrant=quadrant,
  287. cw=angle == DOWN,
  288. radius=self.radius - flow,
  289. center=(x - self.radius,
  290. y + sign * self.radius)))
  291. path.append((Path.LINETO, [x - flow, y + sign * flow]))
  292. label_location = [tip[0], tip[1] + sign * self.offset]
  293. return tip, label_location
  294. def _revert(self, path, first_action=Path.LINETO):
  295. """
  296. A path is not simply reversible by path[::-1] since the code
  297. specifies an action to take from the **previous** point.
  298. """
  299. reverse_path = []
  300. next_code = first_action
  301. for code, position in path[::-1]:
  302. reverse_path.append((next_code, position))
  303. next_code = code
  304. return reverse_path
  305. # This might be more efficient, but it fails because 'tuple' object
  306. # doesn't support item assignment:
  307. # path[1] = path[1][-1:0:-1]
  308. # path[1][0] = first_action
  309. # path[2] = path[2][::-1]
  310. # return path
  311. @_docstring.dedent_interpd
  312. def add(self, patchlabel='', flows=None, orientations=None, labels='',
  313. trunklength=1.0, pathlengths=0.25, prior=None, connect=(0, 0),
  314. rotation=0, **kwargs):
  315. """
  316. Add a simple Sankey diagram with flows at the same hierarchical level.
  317. Parameters
  318. ----------
  319. patchlabel : str
  320. Label to be placed at the center of the diagram.
  321. Note that *label* (not *patchlabel*) can be passed as keyword
  322. argument to create an entry in the legend.
  323. flows : list of float
  324. Array of flow values. By convention, inputs are positive and
  325. outputs are negative.
  326. Flows are placed along the top of the diagram from the inside out
  327. in order of their index within *flows*. They are placed along the
  328. sides of the diagram from the top down and along the bottom from
  329. the outside in.
  330. If the sum of the inputs and outputs is
  331. nonzero, the discrepancy will appear as a cubic Bézier curve along
  332. the top and bottom edges of the trunk.
  333. orientations : list of {-1, 0, 1}
  334. List of orientations of the flows (or a single orientation to be
  335. used for all flows). Valid values are 0 (inputs from
  336. the left, outputs to the right), 1 (from and to the top) or -1
  337. (from and to the bottom).
  338. labels : list of (str or None)
  339. List of labels for the flows (or a single label to be used for all
  340. flows). Each label may be *None* (no label), or a labeling string.
  341. If an entry is a (possibly empty) string, then the quantity for the
  342. corresponding flow will be shown below the string. However, if
  343. the *unit* of the main diagram is None, then quantities are never
  344. shown, regardless of the value of this argument.
  345. trunklength : float
  346. Length between the bases of the input and output groups (in
  347. data-space units).
  348. pathlengths : list of float
  349. List of lengths of the vertical arrows before break-in or after
  350. break-away. If a single value is given, then it will be applied to
  351. the first (inside) paths on the top and bottom, and the length of
  352. all other arrows will be justified accordingly. The *pathlengths*
  353. are not applied to the horizontal inputs and outputs.
  354. prior : int
  355. Index of the prior diagram to which this diagram should be
  356. connected.
  357. connect : (int, int)
  358. A (prior, this) tuple indexing the flow of the prior diagram and
  359. the flow of this diagram which should be connected. If this is the
  360. first diagram or *prior* is *None*, *connect* will be ignored.
  361. rotation : float
  362. Angle of rotation of the diagram in degrees. The interpretation of
  363. the *orientations* argument will be rotated accordingly (e.g., if
  364. *rotation* == 90, an *orientations* entry of 1 means to/from the
  365. left). *rotation* is ignored if this diagram is connected to an
  366. existing one (using *prior* and *connect*).
  367. Returns
  368. -------
  369. Sankey
  370. The current `.Sankey` instance.
  371. Other Parameters
  372. ----------------
  373. **kwargs
  374. Additional keyword arguments set `matplotlib.patches.PathPatch`
  375. properties, listed below. For example, one may want to use
  376. ``fill=False`` or ``label="A legend entry"``.
  377. %(Patch:kwdoc)s
  378. See Also
  379. --------
  380. Sankey.finish
  381. """
  382. # Check and preprocess the arguments.
  383. flows = np.array([1.0, -1.0]) if flows is None else np.array(flows)
  384. n = flows.shape[0] # Number of flows
  385. if rotation is None:
  386. rotation = 0
  387. else:
  388. # In the code below, angles are expressed in deg/90.
  389. rotation /= 90.0
  390. if orientations is None:
  391. orientations = 0
  392. try:
  393. orientations = np.broadcast_to(orientations, n)
  394. except ValueError:
  395. raise ValueError(
  396. f"The shapes of 'flows' {np.shape(flows)} and 'orientations' "
  397. f"{np.shape(orientations)} are incompatible"
  398. ) from None
  399. try:
  400. labels = np.broadcast_to(labels, n)
  401. except ValueError:
  402. raise ValueError(
  403. f"The shapes of 'flows' {np.shape(flows)} and 'labels' "
  404. f"{np.shape(labels)} are incompatible"
  405. ) from None
  406. if trunklength < 0:
  407. raise ValueError(
  408. "'trunklength' is negative, which is not allowed because it "
  409. "would cause poor layout")
  410. if abs(np.sum(flows)) > self.tolerance:
  411. _log.info("The sum of the flows is nonzero (%f; patchlabel=%r); "
  412. "is the system not at steady state?",
  413. np.sum(flows), patchlabel)
  414. scaled_flows = self.scale * flows
  415. gain = sum(max(flow, 0) for flow in scaled_flows)
  416. loss = sum(min(flow, 0) for flow in scaled_flows)
  417. if prior is not None:
  418. if prior < 0:
  419. raise ValueError("The index of the prior diagram is negative")
  420. if min(connect) < 0:
  421. raise ValueError(
  422. "At least one of the connection indices is negative")
  423. if prior >= len(self.diagrams):
  424. raise ValueError(
  425. f"The index of the prior diagram is {prior}, but there "
  426. f"are only {len(self.diagrams)} other diagrams")
  427. if connect[0] >= len(self.diagrams[prior].flows):
  428. raise ValueError(
  429. "The connection index to the source diagram is {}, but "
  430. "that diagram has only {} flows".format(
  431. connect[0], len(self.diagrams[prior].flows)))
  432. if connect[1] >= n:
  433. raise ValueError(
  434. f"The connection index to this diagram is {connect[1]}, "
  435. f"but this diagram has only {n} flows")
  436. if self.diagrams[prior].angles[connect[0]] is None:
  437. raise ValueError(
  438. f"The connection cannot be made, which may occur if the "
  439. f"magnitude of flow {connect[0]} of diagram {prior} is "
  440. f"less than the specified tolerance")
  441. flow_error = (self.diagrams[prior].flows[connect[0]] +
  442. flows[connect[1]])
  443. if abs(flow_error) >= self.tolerance:
  444. raise ValueError(
  445. f"The scaled sum of the connected flows is {flow_error}, "
  446. f"which is not within the tolerance ({self.tolerance})")
  447. # Determine if the flows are inputs.
  448. are_inputs = [None] * n
  449. for i, flow in enumerate(flows):
  450. if flow >= self.tolerance:
  451. are_inputs[i] = True
  452. elif flow <= -self.tolerance:
  453. are_inputs[i] = False
  454. else:
  455. _log.info(
  456. "The magnitude of flow %d (%f) is below the tolerance "
  457. "(%f).\nIt will not be shown, and it cannot be used in a "
  458. "connection.", i, flow, self.tolerance)
  459. # Determine the angles of the arrows (before rotation).
  460. angles = [None] * n
  461. for i, (orient, is_input) in enumerate(zip(orientations, are_inputs)):
  462. if orient == 1:
  463. if is_input:
  464. angles[i] = DOWN
  465. elif is_input is False:
  466. # Be specific since is_input can be None.
  467. angles[i] = UP
  468. elif orient == 0:
  469. if is_input is not None:
  470. angles[i] = RIGHT
  471. else:
  472. if orient != -1:
  473. raise ValueError(
  474. f"The value of orientations[{i}] is {orient}, "
  475. f"but it must be -1, 0, or 1")
  476. if is_input:
  477. angles[i] = UP
  478. elif is_input is False:
  479. angles[i] = DOWN
  480. # Justify the lengths of the paths.
  481. if np.iterable(pathlengths):
  482. if len(pathlengths) != n:
  483. raise ValueError(
  484. f"The lengths of 'flows' ({n}) and 'pathlengths' "
  485. f"({len(pathlengths)}) are incompatible")
  486. else: # Make pathlengths into a list.
  487. urlength = pathlengths
  488. ullength = pathlengths
  489. lrlength = pathlengths
  490. lllength = pathlengths
  491. d = dict(RIGHT=pathlengths)
  492. pathlengths = [d.get(angle, 0) for angle in angles]
  493. # Determine the lengths of the top-side arrows
  494. # from the middle outwards.
  495. for i, (angle, is_input, flow) in enumerate(zip(angles, are_inputs,
  496. scaled_flows)):
  497. if angle == DOWN and is_input:
  498. pathlengths[i] = ullength
  499. ullength += flow
  500. elif angle == UP and is_input is False:
  501. pathlengths[i] = urlength
  502. urlength -= flow # Flow is negative for outputs.
  503. # Determine the lengths of the bottom-side arrows
  504. # from the middle outwards.
  505. for i, (angle, is_input, flow) in enumerate(reversed(list(zip(
  506. angles, are_inputs, scaled_flows)))):
  507. if angle == UP and is_input:
  508. pathlengths[n - i - 1] = lllength
  509. lllength += flow
  510. elif angle == DOWN and is_input is False:
  511. pathlengths[n - i - 1] = lrlength
  512. lrlength -= flow
  513. # Determine the lengths of the left-side arrows
  514. # from the bottom upwards.
  515. has_left_input = False
  516. for i, (angle, is_input, spec) in enumerate(reversed(list(zip(
  517. angles, are_inputs, zip(scaled_flows, pathlengths))))):
  518. if angle == RIGHT:
  519. if is_input:
  520. if has_left_input:
  521. pathlengths[n - i - 1] = 0
  522. else:
  523. has_left_input = True
  524. # Determine the lengths of the right-side arrows
  525. # from the top downwards.
  526. has_right_output = False
  527. for i, (angle, is_input, spec) in enumerate(zip(
  528. angles, are_inputs, list(zip(scaled_flows, pathlengths)))):
  529. if angle == RIGHT:
  530. if is_input is False:
  531. if has_right_output:
  532. pathlengths[i] = 0
  533. else:
  534. has_right_output = True
  535. # Begin the subpaths, and smooth the transition if the sum of the flows
  536. # is nonzero.
  537. urpath = [(Path.MOVETO, [(self.gap - trunklength / 2.0), # Upper right
  538. gain / 2.0]),
  539. (Path.LINETO, [(self.gap - trunklength / 2.0) / 2.0,
  540. gain / 2.0]),
  541. (Path.CURVE4, [(self.gap - trunklength / 2.0) / 8.0,
  542. gain / 2.0]),
  543. (Path.CURVE4, [(trunklength / 2.0 - self.gap) / 8.0,
  544. -loss / 2.0]),
  545. (Path.LINETO, [(trunklength / 2.0 - self.gap) / 2.0,
  546. -loss / 2.0]),
  547. (Path.LINETO, [(trunklength / 2.0 - self.gap),
  548. -loss / 2.0])]
  549. llpath = [(Path.LINETO, [(trunklength / 2.0 - self.gap), # Lower left
  550. loss / 2.0]),
  551. (Path.LINETO, [(trunklength / 2.0 - self.gap) / 2.0,
  552. loss / 2.0]),
  553. (Path.CURVE4, [(trunklength / 2.0 - self.gap) / 8.0,
  554. loss / 2.0]),
  555. (Path.CURVE4, [(self.gap - trunklength / 2.0) / 8.0,
  556. -gain / 2.0]),
  557. (Path.LINETO, [(self.gap - trunklength / 2.0) / 2.0,
  558. -gain / 2.0]),
  559. (Path.LINETO, [(self.gap - trunklength / 2.0),
  560. -gain / 2.0])]
  561. lrpath = [(Path.LINETO, [(trunklength / 2.0 - self.gap), # Lower right
  562. loss / 2.0])]
  563. ulpath = [(Path.LINETO, [self.gap - trunklength / 2.0, # Upper left
  564. gain / 2.0])]
  565. # Add the subpaths and assign the locations of the tips and labels.
  566. tips = np.zeros((n, 2))
  567. label_locations = np.zeros((n, 2))
  568. # Add the top-side inputs and outputs from the middle outwards.
  569. for i, (angle, is_input, spec) in enumerate(zip(
  570. angles, are_inputs, list(zip(scaled_flows, pathlengths)))):
  571. if angle == DOWN and is_input:
  572. tips[i, :], label_locations[i, :] = self._add_input(
  573. ulpath, angle, *spec)
  574. elif angle == UP and is_input is False:
  575. tips[i, :], label_locations[i, :] = self._add_output(
  576. urpath, angle, *spec)
  577. # Add the bottom-side inputs and outputs from the middle outwards.
  578. for i, (angle, is_input, spec) in enumerate(reversed(list(zip(
  579. angles, are_inputs, list(zip(scaled_flows, pathlengths)))))):
  580. if angle == UP and is_input:
  581. tip, label_location = self._add_input(llpath, angle, *spec)
  582. tips[n - i - 1, :] = tip
  583. label_locations[n - i - 1, :] = label_location
  584. elif angle == DOWN and is_input is False:
  585. tip, label_location = self._add_output(lrpath, angle, *spec)
  586. tips[n - i - 1, :] = tip
  587. label_locations[n - i - 1, :] = label_location
  588. # Add the left-side inputs from the bottom upwards.
  589. has_left_input = False
  590. for i, (angle, is_input, spec) in enumerate(reversed(list(zip(
  591. angles, are_inputs, list(zip(scaled_flows, pathlengths)))))):
  592. if angle == RIGHT and is_input:
  593. if not has_left_input:
  594. # Make sure the lower path extends
  595. # at least as far as the upper one.
  596. if llpath[-1][1][0] > ulpath[-1][1][0]:
  597. llpath.append((Path.LINETO, [ulpath[-1][1][0],
  598. llpath[-1][1][1]]))
  599. has_left_input = True
  600. tip, label_location = self._add_input(llpath, angle, *spec)
  601. tips[n - i - 1, :] = tip
  602. label_locations[n - i - 1, :] = label_location
  603. # Add the right-side outputs from the top downwards.
  604. has_right_output = False
  605. for i, (angle, is_input, spec) in enumerate(zip(
  606. angles, are_inputs, list(zip(scaled_flows, pathlengths)))):
  607. if angle == RIGHT and is_input is False:
  608. if not has_right_output:
  609. # Make sure the upper path extends
  610. # at least as far as the lower one.
  611. if urpath[-1][1][0] < lrpath[-1][1][0]:
  612. urpath.append((Path.LINETO, [lrpath[-1][1][0],
  613. urpath[-1][1][1]]))
  614. has_right_output = True
  615. tips[i, :], label_locations[i, :] = self._add_output(
  616. urpath, angle, *spec)
  617. # Trim any hanging vertices.
  618. if not has_left_input:
  619. ulpath.pop()
  620. llpath.pop()
  621. if not has_right_output:
  622. lrpath.pop()
  623. urpath.pop()
  624. # Concatenate the subpaths in the correct order (clockwise from top).
  625. path = (urpath + self._revert(lrpath) + llpath + self._revert(ulpath) +
  626. [(Path.CLOSEPOLY, urpath[0][1])])
  627. # Create a patch with the Sankey outline.
  628. codes, vertices = zip(*path)
  629. vertices = np.array(vertices)
  630. def _get_angle(a, r):
  631. if a is None:
  632. return None
  633. else:
  634. return a + r
  635. if prior is None:
  636. if rotation != 0: # By default, none of this is needed.
  637. angles = [_get_angle(angle, rotation) for angle in angles]
  638. rotate = Affine2D().rotate_deg(rotation * 90).transform_affine
  639. tips = rotate(tips)
  640. label_locations = rotate(label_locations)
  641. vertices = rotate(vertices)
  642. text = self.ax.text(0, 0, s=patchlabel, ha='center', va='center')
  643. else:
  644. rotation = (self.diagrams[prior].angles[connect[0]] -
  645. angles[connect[1]])
  646. angles = [_get_angle(angle, rotation) for angle in angles]
  647. rotate = Affine2D().rotate_deg(rotation * 90).transform_affine
  648. tips = rotate(tips)
  649. offset = self.diagrams[prior].tips[connect[0]] - tips[connect[1]]
  650. translate = Affine2D().translate(*offset).transform_affine
  651. tips = translate(tips)
  652. label_locations = translate(rotate(label_locations))
  653. vertices = translate(rotate(vertices))
  654. kwds = dict(s=patchlabel, ha='center', va='center')
  655. text = self.ax.text(*offset, **kwds)
  656. if mpl.rcParams['_internal.classic_mode']:
  657. fc = kwargs.pop('fc', kwargs.pop('facecolor', '#bfd1d4'))
  658. lw = kwargs.pop('lw', kwargs.pop('linewidth', 0.5))
  659. else:
  660. fc = kwargs.pop('fc', kwargs.pop('facecolor', None))
  661. lw = kwargs.pop('lw', kwargs.pop('linewidth', None))
  662. if fc is None:
  663. fc = self.ax._get_patches_for_fill.get_next_color()
  664. patch = PathPatch(Path(vertices, codes), fc=fc, lw=lw, **kwargs)
  665. self.ax.add_patch(patch)
  666. # Add the path labels.
  667. texts = []
  668. for number, angle, label, location in zip(flows, angles, labels,
  669. label_locations):
  670. if label is None or angle is None:
  671. label = ''
  672. elif self.unit is not None:
  673. if isinstance(self.format, str):
  674. quantity = self.format % abs(number) + self.unit
  675. elif callable(self.format):
  676. quantity = self.format(number)
  677. else:
  678. raise TypeError(
  679. 'format must be callable or a format string')
  680. if label != '':
  681. label += "\n"
  682. label += quantity
  683. texts.append(self.ax.text(x=location[0], y=location[1],
  684. s=label,
  685. ha='center', va='center'))
  686. # Text objects are placed even they are empty (as long as the magnitude
  687. # of the corresponding flow is larger than the tolerance) in case the
  688. # user wants to provide labels later.
  689. # Expand the size of the diagram if necessary.
  690. self.extent = (min(np.min(vertices[:, 0]),
  691. np.min(label_locations[:, 0]),
  692. self.extent[0]),
  693. max(np.max(vertices[:, 0]),
  694. np.max(label_locations[:, 0]),
  695. self.extent[1]),
  696. min(np.min(vertices[:, 1]),
  697. np.min(label_locations[:, 1]),
  698. self.extent[2]),
  699. max(np.max(vertices[:, 1]),
  700. np.max(label_locations[:, 1]),
  701. self.extent[3]))
  702. # Include both vertices _and_ label locations in the extents; there are
  703. # where either could determine the margins (e.g., arrow shoulders).
  704. # Add this diagram as a subdiagram.
  705. self.diagrams.append(
  706. SimpleNamespace(patch=patch, flows=flows, angles=angles, tips=tips,
  707. text=text, texts=texts))
  708. # Allow a daisy-chained call structure (see docstring for the class).
  709. return self
  710. def finish(self):
  711. """
  712. Adjust the axes and return a list of information about the Sankey
  713. subdiagram(s).
  714. Returns a list of subdiagrams with the following fields:
  715. ======== =============================================================
  716. Field Description
  717. ======== =============================================================
  718. *patch* Sankey outline (a `~matplotlib.patches.PathPatch`).
  719. *flows* Flow values (positive for input, negative for output).
  720. *angles* List of angles of the arrows [deg/90].
  721. For example, if the diagram has not been rotated,
  722. an input to the top side has an angle of 3 (DOWN),
  723. and an output from the top side has an angle of 1 (UP).
  724. If a flow has been skipped (because its magnitude is less
  725. than *tolerance*), then its angle will be *None*.
  726. *tips* (N, 2)-array of the (x, y) positions of the tips (or "dips")
  727. of the flow paths.
  728. If the magnitude of a flow is less the *tolerance* of this
  729. `Sankey` instance, the flow is skipped and its tip will be at
  730. the center of the diagram.
  731. *text* `.Text` instance for the diagram label.
  732. *texts* List of `.Text` instances for the flow labels.
  733. ======== =============================================================
  734. See Also
  735. --------
  736. Sankey.add
  737. """
  738. self.ax.axis([self.extent[0] - self.margin,
  739. self.extent[1] + self.margin,
  740. self.extent[2] - self.margin,
  741. self.extent[3] + self.margin])
  742. self.ax.set_aspect('equal', adjustable='datalim')
  743. return self.diagrams