dot.py 8.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294
  1. from sympy.core.basic import Basic
  2. from sympy.core.expr import Expr
  3. from sympy.core.symbol import Symbol
  4. from sympy.core.numbers import Integer, Rational, Float
  5. from sympy.printing.repr import srepr
  6. __all__ = ['dotprint']
  7. default_styles = (
  8. (Basic, {'color': 'blue', 'shape': 'ellipse'}),
  9. (Expr, {'color': 'black'})
  10. )
  11. slotClasses = (Symbol, Integer, Rational, Float)
  12. def purestr(x, with_args=False):
  13. """A string that follows ```obj = type(obj)(*obj.args)``` exactly.
  14. Parameters
  15. ==========
  16. with_args : boolean, optional
  17. If ``True``, there will be a second argument for the return
  18. value, which is a tuple containing ``purestr`` applied to each
  19. of the subnodes.
  20. If ``False``, there will not be a second argument for the
  21. return.
  22. Default is ``False``
  23. Examples
  24. ========
  25. >>> from sympy import Float, Symbol, MatrixSymbol
  26. >>> from sympy import Integer # noqa: F401
  27. >>> from sympy.core.symbol import Str # noqa: F401
  28. >>> from sympy.printing.dot import purestr
  29. Applying ``purestr`` for basic symbolic object:
  30. >>> code = purestr(Symbol('x'))
  31. >>> code
  32. "Symbol('x')"
  33. >>> eval(code) == Symbol('x')
  34. True
  35. For basic numeric object:
  36. >>> purestr(Float(2))
  37. "Float('2.0', precision=53)"
  38. For matrix symbol:
  39. >>> code = purestr(MatrixSymbol('x', 2, 2))
  40. >>> code
  41. "MatrixSymbol(Str('x'), Integer(2), Integer(2))"
  42. >>> eval(code) == MatrixSymbol('x', 2, 2)
  43. True
  44. With ``with_args=True``:
  45. >>> purestr(Float(2), with_args=True)
  46. ("Float('2.0', precision=53)", ())
  47. >>> purestr(MatrixSymbol('x', 2, 2), with_args=True)
  48. ("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
  49. ("Str('x')", 'Integer(2)', 'Integer(2)'))
  50. """
  51. sargs = ()
  52. if not isinstance(x, Basic):
  53. rv = str(x)
  54. elif not x.args:
  55. rv = srepr(x)
  56. else:
  57. args = x.args
  58. sargs = tuple(map(purestr, args))
  59. rv = "%s(%s)"%(type(x).__name__, ', '.join(sargs))
  60. if with_args:
  61. rv = rv, sargs
  62. return rv
  63. def styleof(expr, styles=default_styles):
  64. """ Merge style dictionaries in order
  65. Examples
  66. ========
  67. >>> from sympy import Symbol, Basic, Expr, S
  68. >>> from sympy.printing.dot import styleof
  69. >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
  70. ... (Expr, {'color': 'black'})]
  71. >>> styleof(Basic(S(1)), styles)
  72. {'color': 'blue', 'shape': 'ellipse'}
  73. >>> x = Symbol('x')
  74. >>> styleof(x + 1, styles) # this is an Expr
  75. {'color': 'black', 'shape': 'ellipse'}
  76. """
  77. style = dict()
  78. for typ, sty in styles:
  79. if isinstance(expr, typ):
  80. style.update(sty)
  81. return style
  82. def attrprint(d, delimiter=', '):
  83. """ Print a dictionary of attributes
  84. Examples
  85. ========
  86. >>> from sympy.printing.dot import attrprint
  87. >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
  88. "color"="blue", "shape"="ellipse"
  89. """
  90. return delimiter.join('"%s"="%s"'%item for item in sorted(d.items()))
  91. def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
  92. """ String defining a node
  93. Examples
  94. ========
  95. >>> from sympy.printing.dot import dotnode
  96. >>> from sympy.abc import x
  97. >>> print(dotnode(x))
  98. "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
  99. """
  100. style = styleof(expr, styles)
  101. if isinstance(expr, Basic) and not expr.is_Atom:
  102. label = str(expr.__class__.__name__)
  103. else:
  104. label = labelfunc(expr)
  105. style['label'] = label
  106. expr_str = purestr(expr)
  107. if repeat:
  108. expr_str += '_%s' % str(pos)
  109. return '"%s" [%s];' % (expr_str, attrprint(style))
  110. def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
  111. """ List of strings for all expr->expr.arg pairs
  112. See the docstring of dotprint for explanations of the options.
  113. Examples
  114. ========
  115. >>> from sympy.printing.dot import dotedges
  116. >>> from sympy.abc import x
  117. >>> for e in dotedges(x+2):
  118. ... print(e)
  119. "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
  120. "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
  121. """
  122. if atom(expr):
  123. return []
  124. else:
  125. expr_str, arg_strs = purestr(expr, with_args=True)
  126. if repeat:
  127. expr_str += '_%s' % str(pos)
  128. arg_strs = ['%s_%s' % (a, str(pos + (i,)))
  129. for i, a in enumerate(arg_strs)]
  130. return ['"%s" -> "%s";' % (expr_str, a) for a in arg_strs]
  131. template = \
  132. """digraph{
  133. # Graph style
  134. %(graphstyle)s
  135. #########
  136. # Nodes #
  137. #########
  138. %(nodes)s
  139. #########
  140. # Edges #
  141. #########
  142. %(edges)s
  143. }"""
  144. _graphstyle = {'rankdir': 'TD', 'ordering': 'out'}
  145. def dotprint(expr,
  146. styles=default_styles, atom=lambda x: not isinstance(x, Basic),
  147. maxdepth=None, repeat=True, labelfunc=str, **kwargs):
  148. """DOT description of a SymPy expression tree
  149. Parameters
  150. ==========
  151. styles : list of lists composed of (Class, mapping), optional
  152. Styles for different classes.
  153. The default is
  154. .. code-block:: python
  155. (
  156. (Basic, {'color': 'blue', 'shape': 'ellipse'}),
  157. (Expr, {'color': 'black'})
  158. )
  159. atom : function, optional
  160. Function used to determine if an arg is an atom.
  161. A good choice is ``lambda x: not x.args``.
  162. The default is ``lambda x: not isinstance(x, Basic)``.
  163. maxdepth : integer, optional
  164. The maximum depth.
  165. The default is ``None``, meaning no limit.
  166. repeat : boolean, optional
  167. Whether to use different nodes for common subexpressions.
  168. The default is ``True``.
  169. For example, for ``x + x*y`` with ``repeat=True``, it will have
  170. two nodes for ``x``; with ``repeat=False``, it will have one
  171. node.
  172. .. warning::
  173. Even if a node appears twice in the same object like ``x`` in
  174. ``Pow(x, x)``, it will still only appear once.
  175. Hence, with ``repeat=False``, the number of arrows out of an
  176. object might not equal the number of args it has.
  177. labelfunc : function, optional
  178. A function to create a label for a given leaf node.
  179. The default is ``str``.
  180. Another good option is ``srepr``.
  181. For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
  182. ``x`` and ``1``. With ``srepr``, they are labeled ``Symbol('x')``
  183. and ``Integer(1)``.
  184. **kwargs : optional
  185. Additional keyword arguments are included as styles for the graph.
  186. Examples
  187. ========
  188. >>> from sympy import dotprint
  189. >>> from sympy.abc import x
  190. >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
  191. digraph{
  192. <BLANKLINE>
  193. # Graph style
  194. "ordering"="out"
  195. "rankdir"="TD"
  196. <BLANKLINE>
  197. #########
  198. # Nodes #
  199. #########
  200. <BLANKLINE>
  201. "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
  202. "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
  203. "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
  204. <BLANKLINE>
  205. #########
  206. # Edges #
  207. #########
  208. <BLANKLINE>
  209. "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
  210. "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
  211. }
  212. """
  213. # repeat works by adding a signature tuple to the end of each node for its
  214. # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
  215. # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].
  216. graphstyle = _graphstyle.copy()
  217. graphstyle.update(kwargs)
  218. nodes = []
  219. edges = []
  220. def traverse(e, depth, pos=()):
  221. nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
  222. if maxdepth and depth >= maxdepth:
  223. return
  224. edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
  225. [traverse(arg, depth+1, pos + (i,)) for i, arg in enumerate(e.args) if not atom(arg)]
  226. traverse(expr, 0)
  227. return template%{'graphstyle': attrprint(graphstyle, delimiter='\n'),
  228. 'nodes': '\n'.join(nodes),
  229. 'edges': '\n'.join(edges)}