stackplot.py 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127
  1. """
  2. Stacked area plot for 1D arrays inspired by Douglas Y'barbo's stackoverflow
  3. answer:
  4. https://stackoverflow.com/q/2225995/
  5. (https://stackoverflow.com/users/66549/doug)
  6. """
  7. import itertools
  8. import numpy as np
  9. from matplotlib import _api
  10. __all__ = ['stackplot']
  11. def stackplot(axes, x, *args,
  12. labels=(), colors=None, baseline='zero',
  13. **kwargs):
  14. """
  15. Draw a stacked area plot.
  16. Parameters
  17. ----------
  18. x : (N,) array-like
  19. y : (M, N) array-like
  20. The data is assumed to be unstacked. Each of the following
  21. calls is legal::
  22. stackplot(x, y) # where y has shape (M, N)
  23. stackplot(x, y1, y2, y3) # where y1, y2, y3, y4 have length N
  24. baseline : {'zero', 'sym', 'wiggle', 'weighted_wiggle'}
  25. Method used to calculate the baseline:
  26. - ``'zero'``: Constant zero baseline, i.e. a simple stacked plot.
  27. - ``'sym'``: Symmetric around zero and is sometimes called
  28. 'ThemeRiver'.
  29. - ``'wiggle'``: Minimizes the sum of the squared slopes.
  30. - ``'weighted_wiggle'``: Does the same but weights to account for
  31. size of each layer. It is also called 'Streamgraph'-layout. More
  32. details can be found at http://leebyron.com/streamgraph/.
  33. labels : list of str, optional
  34. A sequence of labels to assign to each data series. If unspecified,
  35. then no labels will be applied to artists.
  36. colors : list of color, optional
  37. A sequence of colors to be cycled through and used to color the stacked
  38. areas. The sequence need not be exactly the same length as the number
  39. of provided *y*, in which case the colors will repeat from the
  40. beginning.
  41. If not specified, the colors from the Axes property cycle will be used.
  42. data : indexable object, optional
  43. DATA_PARAMETER_PLACEHOLDER
  44. **kwargs
  45. All other keyword arguments are passed to `.Axes.fill_between`.
  46. Returns
  47. -------
  48. list of `.PolyCollection`
  49. A list of `.PolyCollection` instances, one for each element in the
  50. stacked area plot.
  51. """
  52. y = np.vstack(args)
  53. labels = iter(labels)
  54. if colors is not None:
  55. colors = itertools.cycle(colors)
  56. else:
  57. colors = (axes._get_lines.get_next_color() for _ in y)
  58. # Assume data passed has not been 'stacked', so stack it here.
  59. # We'll need a float buffer for the upcoming calculations.
  60. stack = np.cumsum(y, axis=0, dtype=np.promote_types(y.dtype, np.float32))
  61. _api.check_in_list(['zero', 'sym', 'wiggle', 'weighted_wiggle'],
  62. baseline=baseline)
  63. if baseline == 'zero':
  64. first_line = 0.
  65. elif baseline == 'sym':
  66. first_line = -np.sum(y, 0) * 0.5
  67. stack += first_line[None, :]
  68. elif baseline == 'wiggle':
  69. m = y.shape[0]
  70. first_line = (y * (m - 0.5 - np.arange(m)[:, None])).sum(0)
  71. first_line /= -m
  72. stack += first_line
  73. elif baseline == 'weighted_wiggle':
  74. total = np.sum(y, 0)
  75. # multiply by 1/total (or zero) to avoid infinities in the division:
  76. inv_total = np.zeros_like(total)
  77. mask = total > 0
  78. inv_total[mask] = 1.0 / total[mask]
  79. increase = np.hstack((y[:, 0:1], np.diff(y)))
  80. below_size = total - stack
  81. below_size += 0.5 * y
  82. move_up = below_size * inv_total
  83. move_up[:, 0] = 0.5
  84. center = (move_up - 0.5) * increase
  85. center = np.cumsum(center.sum(0))
  86. first_line = center - 0.5 * total
  87. stack += first_line
  88. # Color between x = 0 and the first array.
  89. coll = axes.fill_between(x, first_line, stack[0, :],
  90. facecolor=next(colors), label=next(labels, None),
  91. **kwargs)
  92. coll.sticky_edges.y[:] = [0]
  93. r = [coll]
  94. # Color between array i-1 and array i
  95. for i in range(len(y) - 1):
  96. r.append(axes.fill_between(x, stack[i, :], stack[i + 1, :],
  97. facecolor=next(colors),
  98. label=next(labels, None),
  99. **kwargs))
  100. return r