123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372 |
- """Matplotlib based plotting of quantum circuits.
- Todo:
- * Optimize printing of large circuits.
- * Get this to work with single gates.
- * Do a better job checking the form of circuits to make sure it is a Mul of
- Gates.
- * Get multi-target gates plotting.
- * Get initial and final states to plot.
- * Get measurements to plot. Might need to rethink measurement as a gate
- issue.
- * Get scale and figsize to be handled in a better way.
- * Write some tests/examples!
- """
- from typing import List, Dict as tDict
- from sympy.core.mul import Mul
- from sympy.external import import_module
- from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS
- from sympy.core.core import BasicMeta
- from sympy.core.assumptions import ManagedProperties
__all__ = [
    'CircuitPlot',
    'circuit_plot',
    'labeller',
    'Mz',
    'Mx',
    'CreateOneQubitGate',
    'CreateCGate',
]

# Optional dependencies. Each name is falsy when the import is unavailable;
# CircuitPlot checks this before it tries to draw anything.
np = import_module('numpy')
matplotlib = import_module(
    'matplotlib',
    import_kwargs={'fromlist': ['pyplot']},
    catch=(RuntimeError,),  # This is raised in environments that have no display.
)

if np and matplotlib:
    # Short module-level aliases used by the drawing methods.
    pyplot, Line2D, Circle = (
        matplotlib.pyplot,
        matplotlib.lines.Line2D,
        matplotlib.patches.Circle,
    )

# To render labels through LaTeX instead of mathtext, enable:
#     from matplotlib import rc
#     rc('text', usetex=True)
class CircuitPlot:
    """A class for managing a circuit plot.

    Everything is drawn during construction. The wire/gate grid is laid out,
    a matplotlib figure is created, and then the wires, gates and doubled
    (measured) wires are added. Keyword arguments passed to the constructor
    are stored on the instance and override the class-level defaults below.

    Parameters
    ==========

    c : circuit
        A ``Mul`` of ``Gate`` instances, or a single ``Gate``.
    nqubits : int
        The number of wires to draw.
    """
    # Spacing between neighbouring wires and between neighbouring gates.
    scale = 1.0
    fontsize = 20.0
    linewidth = 1.0
    # Radius of a control dot; multiplied by ``scale`` when drawn.
    control_radius = 0.05
    # Radius of the NOT-target circle; used as-is (not multiplied by scale).
    not_radius = 0.15
    # Half-width of each arm of a swap cross; used as-is.
    swap_delta = 0.05
    # Optional wire labels, one per wire, in wire order.
    labels = []  # type: List[str]
    # Optional initial states keyed by wire label (rendered by render_label).
    inits = {}  # type: tDict[str, str]
    # Horizontal distance from the left end of a wire to its label's centre.
    label_buffer = 0.5

    def __init__(self, c, nqubits, **kwargs):
        if not np or not matplotlib:
            raise ImportError('numpy or matplotlib not available.')
        self.circuit = c
        # One grid column per gate that _plot_gates will place. Using
        # len(c.args) miscounts a lone gate (e.g. a CGate's args are
        # (controls, target_gate)) and any non-Gate factor of a Mul.
        # At least one column is kept so the wires can always be drawn.
        self.ngates = max(len(self._gates()), 1)
        self.nqubits = nqubits
        self.update(kwargs)
        self._create_grid()
        self._create_figure()
        self._plot_wires()
        self._plot_gates()
        self._finish()

    def update(self, kwargs):
        """Load the kwargs into the instance dict."""
        self.__dict__.update(kwargs)

    def _create_grid(self):
        """Create the grid of wire (y) and gate (x) positions.

        Sets ``_wire_grid`` to ``nqubits`` and ``_gate_grid`` to ``ngates``
        evenly spaced positions starting at 0 with spacing ``scale``.
        """
        scale = self.scale
        # Scale integer positions instead of calling np.arange with a float
        # step. With a step such as 0.1, the rounded stop value (n*scale) can
        # make arange return one point too many.
        self._wire_grid = np.arange(self.nqubits, dtype=float)*scale
        self._gate_grid = np.arange(self.ngates, dtype=float)*scale

    def _create_figure(self):
        """Create the main matplotlib figure and a single axes sized to the grid."""
        self._figure = pyplot.figure(
            figsize=(self.ngates*self.scale, self.nqubits*self.scale),
            facecolor='w',
            edgecolor='w'
        )
        ax = self._figure.add_subplot(
            1, 1, 1,
            frameon=True
        )
        ax.set_axis_off()
        # Pad the view by half a grid step on every side.
        offset = 0.5*self.scale
        ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
        ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
        ax.set_aspect('equal')
        self._axes = ax

    def _plot_wires(self):
        """Plot the wires of the circuit diagram and their optional labels."""
        xstart = self._gate_grid[0]
        xstop = self._gate_grid[-1]
        # Wires extend one grid step beyond the first and last gates.
        xdata = (xstart - self.scale, xstop + self.scale)
        for i in range(self.nqubits):
            ydata = (self._wire_grid[i], self._wire_grid[i])
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)
            if self.labels:
                # A label that also shows an initial state is wider, so
                # shift it further left.
                init_label_buffer = 0.25 if self.inits.get(self.labels[i]) else 0
                self._axes.text(
                    xdata[0] - self.label_buffer - init_label_buffer, ydata[0],
                    render_label(self.labels[i], self.inits),
                    size=self.fontsize,
                    color='k', ha='center', va='center')
        self._plot_measured_wires()

    def _plot_measured_wires(self):
        """Draw doubled wires after each wire's first measurement.

        Any controlled gate acting on a wire after that wire was measured
        also gets its vertical control line doubled.
        """
        ismeasured = self._measurements()
        xstop = self._gate_grid[-1]
        dy = 0.04  # amount to shift wires when doubled
        # Plot doubled wires after they are measured.
        for wire_idx, gate_idx in ismeasured.items():
            xdata = (self._gate_grid[gate_idx], xstop + self.scale)
            ydata = (self._wire_grid[wire_idx] + dy, self._wire_grid[wire_idx] + dy)
            line = Line2D(
                xdata, ydata,
                color='k',
                lw=self.linewidth
            )
            self._axes.add_line(line)
        # Also double any controlled lines off these wires.
        for i, g in enumerate(self._gates()):
            if not isinstance(g, (CGate, CGateS)):
                continue
            wires = g.controls + g.targets
            for wire in wires:
                if (wire in ismeasured and
                        self._gate_grid[i] > self._gate_grid[ismeasured[wire]]):
                    # Convert wire indices to plot coordinates, matching
                    # control_line; raw indices are only right for scale=1.
                    ydata = (self._wire_grid[min(wires)], self._wire_grid[max(wires)])
                    xdata = (self._gate_grid[i] - dy, self._gate_grid[i] - dy)
                    line = Line2D(
                        xdata, ydata,
                        color='k',
                        lw=self.linewidth
                    )
                    self._axes.add_line(line)
                    # The doubled line is identical for every measured wire
                    # of this gate, so draw it once.
                    break

    def _gates(self):
        """Create a list of all gates in the circuit plot.

        A ``Mul``'s Gate factors are returned in reverse argument order
        (rightmost factor first). A lone ``Gate`` is returned on its own.
        Anything else yields an empty list.
        """
        if isinstance(self.circuit, Mul):
            return [g for g in reversed(self.circuit.args) if isinstance(g, Gate)]
        if isinstance(self.circuit, Gate):
            return [self.circuit]
        return []

    def _plot_gates(self):
        """Iterate through the gates and let each one draw itself in its column."""
        for i, gate in enumerate(self._gates()):
            gate.plot_gate(self, i)

    def _measurements(self):
        """Return a dict {i:j} where i is the index of the wire that has
        been measured, and j is the gate where the wire is measured.

        Only the earliest measuring gate is recorded for each wire. A gate
        counts as a measurement when it has a truthy ``measurement``
        attribute.
        """
        ismeasured = {}
        for i, g in enumerate(self._gates()):
            if getattr(g, 'measurement', False):
                for target in g.targets:
                    # Gate indices increase, so the first entry is the earliest.
                    ismeasured.setdefault(target, i)
        return ismeasured

    def _finish(self):
        """Disable clipping on every figure object.

        This makes panning work well for large circuits.
        """
        for o in self._figure.findobj():
            o.set_clip_on(False)

    def one_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box containing text ``t`` for a single qubit gate."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        self._axes.text(
            x, y, t,
            color='k',
            ha='center',
            va='center',
            bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
            size=self.fontsize
        )

    def two_qubit_box(self, t, gate_idx, wire_idx):
        """Draw a box for a two qubit gate. Doesn't work yet.

        NOTE(review): this currently only prints the grid arrays (debug
        output). Nothing is drawn and the arguments are unused.
        """
        print(self._gate_grid)
        print(self._wire_grid)

    def control_line(self, gate_idx, min_wire, max_wire):
        """Draw a vertical control line between two wires at a gate column."""
        xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
        ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
        line = Line2D(
            xdata, ydata,
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(line)

    def control_point(self, gate_idx, wire_idx):
        """Draw a filled control point (dot)."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        radius = self.control_radius
        c = Circle(
            (x, y),
            radius*self.scale,
            ec='k',
            fc='k',
            fill=True,
            lw=self.linewidth
        )
        self._axes.add_patch(c)

    def not_point(self, gate_idx, wire_idx):
        """Draw a NOT gate as a circle with a plus in the middle.

        Only the vertical stroke is drawn here. The horizontal stroke is the
        wire itself, which passes through the circle's centre.
        """
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        radius = self.not_radius
        c = Circle(
            (x, y),
            radius,
            ec='k',
            fc='w',
            fill=False,
            lw=self.linewidth
        )
        self._axes.add_patch(c)
        l = Line2D(
            (x, x), (y - radius, y + radius),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l)

    def swap_point(self, gate_idx, wire_idx):
        """Draw a swap point as a cross."""
        x = self._gate_grid[gate_idx]
        y = self._wire_grid[wire_idx]
        d = self.swap_delta
        l1 = Line2D(
            (x - d, x + d),
            (y - d, y + d),
            color='k',
            lw=self.linewidth
        )
        l2 = Line2D(
            (x - d, x + d),
            (y + d, y - d),
            color='k',
            lw=self.linewidth
        )
        self._axes.add_line(l1)
        self._axes.add_line(l2)
def circuit_plot(c, nqubits, **kwargs):
    """Draw the circuit diagram for the circuit with nqubits.

    Parameters
    ==========

    c : circuit
        The circuit to plot. Should be a product of Gate instances.
    nqubits : int
        The number of qubits to include in the circuit. Must be at least
        as big as the largest ``min_qubits`` of the gates.
    **kwargs
        Forwarded to ``CircuitPlot``. Each one overrides the class attribute
        of the same name (for example ``scale``, ``labels`` or ``inits``).

    Returns
    =======

    CircuitPlot
        The plot object. All drawing happens while it is constructed.
    """
    return CircuitPlot(c, nqubits, **kwargs)
def render_label(label, inits=None):
    """Slightly more flexible way to render labels.

    Parameters
    ==========

    label : str
        The wire label, rendered as a ket.
    inits : dict, optional
        Maps wire labels to initial states. When ``label`` has a truthy
        entry, the initial state is appended as a second ket after ``=``.

    Returns a LaTeX math-mode string.

    >>> from sympy.physics.quantum.circuitplot import render_label
    >>> render_label('q0')
    '$\\\\left|q0\\\\right\\\\rangle$'
    >>> render_label('q0', {'q0':'0'})
    '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
    """
    # ``None`` replaces the former shared mutable ``{}`` default.
    init = inits.get(label) if inits else None
    if init:
        return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
    return r'$\left|%s\right\rangle$' % label
def labeller(n, symbol='q'):
    """Autogenerate labels for wires of quantum circuits.

    Labels are returned highest index first, so the first wire gets
    ``symbol_{n-1}`` and the last gets ``symbol_0``.

    Parameters
    ==========

    n : int
        number of qubits in the circuit.
    symbol : string
        A character string to precede all wire labels. E.g. 'q_0', 'q_1', etc.

    >>> from sympy.physics.quantum.circuitplot import labeller
    >>> labeller(2)
    ['q_1', 'q_0']
    >>> labeller(3,'j')
    ['j_2', 'j_1', 'j_0']
    """
    return ['%s_%d' % (symbol, idx) for idx in range(n - 1, -1, -1)]
class Mz(OneQubitGate):
    """Mock-up of a z measurement gate.

    It lives in circuitplot rather than gate.py because it is not a real
    gate; it just draws one.
    """
    gate_name = 'Mz'
    gate_name_latex = 'M_z'
    # Read by CircuitPlot._measurements to start a doubled wire.
    measurement = True
class Mx(OneQubitGate):
    """Mock-up of an x measurement gate.

    It lives in circuitplot rather than gate.py because it is not a real
    gate; it just draws one.
    """
    gate_name = 'Mx'
    gate_name_latex = 'M_x'
    # Read by CircuitPlot._measurements to start a doubled wire.
    measurement = True
class CreateOneQubitGate(ManagedProperties):
    """Build a new ``OneQubitGate`` subclass from a name.

    ``CreateOneQubitGate(name, latexname)`` returns a class named
    ``name + "Gate"`` with ``gate_name`` set to ``name`` and
    ``gate_name_latex`` set to ``latexname``. ``name`` is used instead when
    ``latexname`` is omitted or empty.
    """
    def __new__(mcl, name, latexname=None):
        attrs = {'gate_name': name, 'gate_name_latex': latexname or name}
        return BasicMeta.__new__(mcl, name + "Gate", (OneQubitGate,), attrs)
def CreateCGate(name, latexname=None):
    """Use a lexical closure to make a controlled gate.

    A one-qubit gate class is built once with
    ``CreateOneQubitGate(name, latexname)``; ``latexname`` falls back to
    ``name``. The returned function ``ControlledGate(ctrls, target)`` wraps
    an instance of that class acting on ``target`` in a ``CGate``
    controlled by ``ctrls``.
    """
    target_gate_cls = CreateOneQubitGate(name, latexname or name)

    def ControlledGate(ctrls, target):
        return CGate(tuple(ctrls), target_gate_cls(target))

    return ControlledGate
|