123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294 |
- from sympy.core.basic import Basic
- from sympy.core.expr import Expr
- from sympy.core.symbol import Symbol
- from sympy.core.numbers import Integer, Rational, Float
- from sympy.printing.repr import srepr
# Names exported by ``from sympy.printing.dot import *``.
__all__ = ['dotprint']

# Ordered (class, attributes) pairs.  ``styleof`` applies every entry whose
# class matches, in this order, so later entries override keys set earlier.
default_styles = (
    (Basic, {'color': 'blue', 'shape': 'ellipse'}),
    (Expr, {'color': 'black'}),
)

# NOTE(review): not referenced anywhere else in the visible part of this
# module -- confirm whether external code still relies on it.
slotClasses = (Symbol, Integer, Rational, Float)
def purestr(x, with_args=False):
    """A string that follows ``obj = type(obj)(*obj.args)`` exactly.

    Objects that are not ``Basic`` are rendered with ``str``, ``Basic``
    objects without args are rendered with ``srepr``, and every other
    ``Basic`` object is rendered as its class name applied to the
    ``purestr`` of each of its args.

    Parameters
    ==========

    with_args : boolean, optional
        If ``True``, the return value is a pair whose second item is a
        tuple containing ``purestr`` applied to each of the subnodes
        (an empty tuple when there are none).

        If ``False``, only the string is returned.

        Default is ``False``

    Examples
    ========

    >>> from sympy import Float, Symbol, MatrixSymbol
    >>> from sympy import Integer # noqa: F401
    >>> from sympy.core.symbol import Str # noqa: F401
    >>> from sympy.printing.dot import purestr

    Applying ``purestr`` for basic symbolic object:

    >>> code = purestr(Symbol('x'))
    >>> code
    "Symbol('x')"
    >>> eval(code) == Symbol('x')
    True

    For basic numeric object:

    >>> purestr(Float(2))
    "Float('2.0', precision=53)"

    For matrix symbol:

    >>> code = purestr(MatrixSymbol('x', 2, 2))
    >>> code
    "MatrixSymbol(Str('x'), Integer(2), Integer(2))"
    >>> eval(code) == MatrixSymbol('x', 2, 2)
    True

    With ``with_args=True``:

    >>> purestr(Float(2), with_args=True)
    ("Float('2.0', precision=53)", ())
    >>> purestr(MatrixSymbol('x', 2, 2), with_args=True)
    ("MatrixSymbol(Str('x'), Integer(2), Integer(2))",
     ("Str('x')", 'Integer(2)', 'Integer(2)'))

    """
    # Stays empty for leaves (non-Basic objects and Basic objects w/o args).
    child_strs = ()

    if isinstance(x, Basic):
        if x.args:
            child_strs = tuple(purestr(child) for child in x.args)
            text = "%s(%s)" % (type(x).__name__, ', '.join(child_strs))
        else:
            text = srepr(x)
    else:
        text = str(x)

    return (text, child_strs) if with_args else text
def styleof(expr, styles=default_styles):
    """ Merge style dictionaries in order

    ``styles`` is iterated as ``(class, attributes)`` pairs.  The attributes
    of every pair whose class ``expr`` is an instance of are merged into a
    fresh dictionary, so a later match overrides keys set by an earlier one.

    Examples
    ========

    >>> from sympy import Symbol, Basic, Expr, S
    >>> from sympy.printing.dot import styleof
    >>> styles = [(Basic, {'color': 'blue', 'shape': 'ellipse'}),
    ...           (Expr,  {'color': 'black'})]

    >>> styleof(Basic(S(1)), styles)
    {'color': 'blue', 'shape': 'ellipse'}

    >>> x = Symbol('x')
    >>> styleof(x + 1, styles)  # this is an Expr
    {'color': 'black', 'shape': 'ellipse'}
    """
    merged = {}
    for cls, attrs in styles:
        if not isinstance(expr, cls):
            continue
        merged.update(attrs)
    return merged
def attrprint(d, delimiter=', '):
    """ Print a dictionary of attributes

    Each item is rendered as ``"key"="value"``; the items are emitted in
    sorted order and joined with ``delimiter``.

    Examples
    ========

    >>> from sympy.printing.dot import attrprint
    >>> print(attrprint({'color': 'blue', 'shape': 'ellipse'}))
    "color"="blue", "shape"="ellipse"
    """
    ordered_items = sorted(d.items())
    rendered = ['"%s"="%s"' % (key, value) for key, value in ordered_items]
    return delimiter.join(rendered)
def dotnode(expr, styles=default_styles, labelfunc=str, pos=(), repeat=True):
    """ String defining a node

    The node identifier is ``purestr(expr)``, with ``_<pos>`` appended when
    ``repeat`` is true.  A ``Basic`` instance that is not an atom is labelled
    with its class name; anything else is labelled with ``labelfunc(expr)``.
    The remaining attributes come from ``styleof(expr, styles)``.

    Examples
    ========

    >>> from sympy.printing.dot import dotnode
    >>> from sympy.abc import x
    >>> print(dotnode(x))
    "Symbol('x')_()" ["color"="black", "label"="x", "shape"="ellipse"];
    """
    attrs = styleof(expr, styles)

    is_compound = isinstance(expr, Basic) and not expr.is_Atom
    if is_compound:
        attrs['label'] = str(expr.__class__.__name__)
    else:
        attrs['label'] = labelfunc(expr)

    node_id = purestr(expr)
    if repeat:
        # The position tuple keeps equal subexpressions apart in the graph.
        node_id = '%s_%s' % (node_id, pos)

    return '"%s" [%s];' % (node_id, attrprint(attrs))
def dotedges(expr, atom=lambda x: not isinstance(x, Basic), pos=(), repeat=True):
    """ List of strings for all expr->expr.arg pairs

    Returns an empty list when ``atom(expr)`` is true.  Otherwise one edge
    statement is produced per subnode string reported by ``purestr``; with
    ``repeat`` true, both ends of every edge carry their position tuple.

    See the docstring of dotprint for explanations of the options.

    Examples
    ========

    >>> from sympy.printing.dot import dotedges
    >>> from sympy.abc import x
    >>> for e in dotedges(x+2):
    ...     print(e)
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    """
    if atom(expr):
        return []

    parent, children = purestr(expr, with_args=True)
    if repeat:
        parent = '%s_%s' % (parent, pos)
        # Child i sits one level below ``pos``, at index i.
        children = ['%s_%s' % (child, pos + (index,))
                    for index, child in enumerate(children)]

    edge_lines = []
    for child in children:
        edge_lines.append('"%s" -> "%s";' % (parent, child))
    return edge_lines
# Skeleton of the DOT document returned by ``dotprint``.  The named
# placeholders ``%(graphstyle)s``, ``%(nodes)s`` and ``%(edges)s`` are filled
# in with ``%`` formatting from a mapping.
# NOTE(review): the ``dotprint`` doctest shows ``<BLANKLINE>`` separators
# between these sections -- confirm that no blank lines were dropped from this
# literal.
template = \
"""digraph{
# Graph style
%(graphstyle)s
#########
# Nodes #
#########
%(nodes)s
#########
# Edges #
#########
%(edges)s
}"""
# Default graph-level attributes.  ``dotprint`` copies this dict and overlays
# its extra keyword arguments on the copy, so this object is not mutated there.
_graphstyle = {'rankdir': 'TD', 'ordering': 'out'}
def dotprint(expr,
    styles=default_styles, atom=lambda x: not isinstance(x, Basic),
    maxdepth=None, repeat=True, labelfunc=str, **kwargs):
    """DOT description of a SymPy expression tree

    Parameters
    ==========

    styles : list of lists composed of (Class, mapping), optional
        Styles for different classes.

        The default is

        .. code-block:: python

            (
                (Basic, {'color': 'blue', 'shape': 'ellipse'}),
                (Expr,  {'color': 'black'})
            )

    atom : function, optional
        Function used to determine if an arg is an atom.

        A good choice is ``lambda x: not x.args``.

        The default is ``lambda x: not isinstance(x, Basic)``.

    maxdepth : integer, optional
        The maximum depth.  Nodes at this depth are still emitted, but
        their edges and subnodes are not; ``0`` therefore yields only the
        root node.

        The default is ``None``, meaning no limit.

    repeat : boolean, optional
        Whether to use different nodes for common subexpressions.

        The default is ``True``.

        For example, for ``x + x*y`` with ``repeat=True``, it will have
        two nodes for ``x``; with ``repeat=False``, it will have one
        node.

        .. warning::
            Even if a node appears twice in the same object like ``x`` in
            ``Pow(x, x)``, it will still only appear once.
            Hence, with ``repeat=False``, the number of arrows out of an
            object might not equal the number of args it has.

    labelfunc : function, optional
        A function to create a label for a given leaf node.

        The default is ``str``.

        Another good option is ``srepr``.

        For example with ``str``, the leaf nodes of ``x + 1`` are labeled,
        ``x`` and ``1``.  With ``srepr``, they are labeled ``Symbol('x')``
        and ``Integer(1)``.

    **kwargs : optional
        Additional keyword arguments are included as styles for the graph.
        They are applied on top of the default graph style, so they can
        also override ``rankdir`` and ``ordering``.

    Returns
    =======

    str
        The complete DOT source text.

    Examples
    ========

    >>> from sympy import dotprint
    >>> from sympy.abc import x
    >>> print(dotprint(x+2)) # doctest: +NORMALIZE_WHITESPACE
    digraph{
    <BLANKLINE>
    # Graph style
    "ordering"="out"
    "rankdir"="TD"
    <BLANKLINE>
    #########
    # Nodes #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol('x'))_()" ["color"="black", "label"="Add", "shape"="ellipse"];
    "Integer(2)_(0,)" ["color"="black", "label"="2", "shape"="ellipse"];
    "Symbol('x')_(1,)" ["color"="black", "label"="x", "shape"="ellipse"];
    <BLANKLINE>
    #########
    # Edges #
    #########
    <BLANKLINE>
    "Add(Integer(2), Symbol('x'))_()" -> "Integer(2)_(0,)";
    "Add(Integer(2), Symbol('x'))_()" -> "Symbol('x')_(1,)";
    }

    """
    # repeat works by adding a signature tuple to the end of each node for its
    # position in the graph. For example, for expr = Add(x, Pow(x, 2)), the x in the
    # Pow will have the tuple (1, 0), meaning it is expr.args[1].args[0].

    # Work on a copy so the module-level defaults are never mutated.
    graphstyle = _graphstyle.copy()
    graphstyle.update(kwargs)

    nodes = []
    edges = []

    def traverse(e, depth, pos=()):
        # Pre-order walk: record the node, then its edges, then recurse.
        nodes.append(dotnode(e, styles, labelfunc=labelfunc, pos=pos, repeat=repeat))
        # Compare against None explicitly: a plain truthiness test would make
        # ``maxdepth=0`` behave like "no limit" instead of "root only".
        if maxdepth is not None and depth >= maxdepth:
            return
        edges.extend(dotedges(e, atom=atom, pos=pos, repeat=repeat))
        # Args for which ``atom`` is true are not descended into.
        for i, arg in enumerate(e.args):
            if not atom(arg):
                traverse(arg, depth + 1, pos + (i,))

    traverse(expr, 0)

    return template % {'graphstyle': attrprint(graphstyle, delimiter='\n'),
                       'nodes': '\n'.join(nodes),
                       'edges': '\n'.join(edges)}
|