# circuitplot.py
"""Matplotlib based plotting of quantum circuits.

Todo:

* Optimize printing of large circuits.
* Get this to work with single gates.
* Do a better job checking the form of circuits to make sure it is a Mul of
  Gates.
* Get multi-target gates plotting.
* Get initial and final states to plot.
* Get measurements to plot. Might need to rethink measurement as a gate
  issue.
* Get scale and figsize to be handled in a better way.
* Write some tests/examples!
"""
  14. from typing import List, Dict as tDict
  15. from sympy.core.mul import Mul
  16. from sympy.external import import_module
  17. from sympy.physics.quantum.gate import Gate, OneQubitGate, CGate, CGateS
  18. from sympy.core.core import BasicMeta
  19. from sympy.core.assumptions import ManagedProperties
# Names exported by this module.
__all__ = [
    'CircuitPlot',
    'circuit_plot',
    'labeller',
    'Mz',
    'Mx',
    'CreateOneQubitGate',
    'CreateCGate',
]

# numpy and matplotlib are loaded through import_module rather than a plain
# import; CircuitPlot.__init__ checks both names for truthiness and raises
# ImportError itself when either one is unavailable.
np = import_module('numpy')
matplotlib = import_module(
    'matplotlib', import_kwargs={'fromlist': ['pyplot']},
    catch=(RuntimeError,)) # This is raised in environments that have no display.

# Bind the plotting names used by CircuitPlot only when both packages loaded.
if np and matplotlib:
    pyplot = matplotlib.pyplot
    Line2D = matplotlib.lines.Line2D
    Circle = matplotlib.patches.Circle

#from matplotlib import rc
#rc('text',usetex=True)
  39. class CircuitPlot:
  40. """A class for managing a circuit plot."""
  41. scale = 1.0
  42. fontsize = 20.0
  43. linewidth = 1.0
  44. control_radius = 0.05
  45. not_radius = 0.15
  46. swap_delta = 0.05
  47. labels = [] # type: List[str]
  48. inits = {} # type: tDict[str, str]
  49. label_buffer = 0.5
  50. def __init__(self, c, nqubits, **kwargs):
  51. if not np or not matplotlib:
  52. raise ImportError('numpy or matplotlib not available.')
  53. self.circuit = c
  54. self.ngates = len(self.circuit.args)
  55. self.nqubits = nqubits
  56. self.update(kwargs)
  57. self._create_grid()
  58. self._create_figure()
  59. self._plot_wires()
  60. self._plot_gates()
  61. self._finish()
  62. def update(self, kwargs):
  63. """Load the kwargs into the instance dict."""
  64. self.__dict__.update(kwargs)
  65. def _create_grid(self):
  66. """Create the grid of wires."""
  67. scale = self.scale
  68. wire_grid = np.arange(0.0, self.nqubits*scale, scale, dtype=float)
  69. gate_grid = np.arange(0.0, self.ngates*scale, scale, dtype=float)
  70. self._wire_grid = wire_grid
  71. self._gate_grid = gate_grid
  72. def _create_figure(self):
  73. """Create the main matplotlib figure."""
  74. self._figure = pyplot.figure(
  75. figsize=(self.ngates*self.scale, self.nqubits*self.scale),
  76. facecolor='w',
  77. edgecolor='w'
  78. )
  79. ax = self._figure.add_subplot(
  80. 1, 1, 1,
  81. frameon=True
  82. )
  83. ax.set_axis_off()
  84. offset = 0.5*self.scale
  85. ax.set_xlim(self._gate_grid[0] - offset, self._gate_grid[-1] + offset)
  86. ax.set_ylim(self._wire_grid[0] - offset, self._wire_grid[-1] + offset)
  87. ax.set_aspect('equal')
  88. self._axes = ax
  89. def _plot_wires(self):
  90. """Plot the wires of the circuit diagram."""
  91. xstart = self._gate_grid[0]
  92. xstop = self._gate_grid[-1]
  93. xdata = (xstart - self.scale, xstop + self.scale)
  94. for i in range(self.nqubits):
  95. ydata = (self._wire_grid[i], self._wire_grid[i])
  96. line = Line2D(
  97. xdata, ydata,
  98. color='k',
  99. lw=self.linewidth
  100. )
  101. self._axes.add_line(line)
  102. if self.labels:
  103. init_label_buffer = 0
  104. if self.inits.get(self.labels[i]): init_label_buffer = 0.25
  105. self._axes.text(
  106. xdata[0]-self.label_buffer-init_label_buffer,ydata[0],
  107. render_label(self.labels[i],self.inits),
  108. size=self.fontsize,
  109. color='k',ha='center',va='center')
  110. self._plot_measured_wires()
  111. def _plot_measured_wires(self):
  112. ismeasured = self._measurements()
  113. xstop = self._gate_grid[-1]
  114. dy = 0.04 # amount to shift wires when doubled
  115. # Plot doubled wires after they are measured
  116. for im in ismeasured:
  117. xdata = (self._gate_grid[ismeasured[im]],xstop+self.scale)
  118. ydata = (self._wire_grid[im]+dy,self._wire_grid[im]+dy)
  119. line = Line2D(
  120. xdata, ydata,
  121. color='k',
  122. lw=self.linewidth
  123. )
  124. self._axes.add_line(line)
  125. # Also double any controlled lines off these wires
  126. for i,g in enumerate(self._gates()):
  127. if isinstance(g, (CGate, CGateS)):
  128. wires = g.controls + g.targets
  129. for wire in wires:
  130. if wire in ismeasured and \
  131. self._gate_grid[i] > self._gate_grid[ismeasured[wire]]:
  132. ydata = min(wires), max(wires)
  133. xdata = self._gate_grid[i]-dy, self._gate_grid[i]-dy
  134. line = Line2D(
  135. xdata, ydata,
  136. color='k',
  137. lw=self.linewidth
  138. )
  139. self._axes.add_line(line)
  140. def _gates(self):
  141. """Create a list of all gates in the circuit plot."""
  142. gates = []
  143. if isinstance(self.circuit, Mul):
  144. for g in reversed(self.circuit.args):
  145. if isinstance(g, Gate):
  146. gates.append(g)
  147. elif isinstance(self.circuit, Gate):
  148. gates.append(self.circuit)
  149. return gates
  150. def _plot_gates(self):
  151. """Iterate through the gates and plot each of them."""
  152. for i, gate in enumerate(self._gates()):
  153. gate.plot_gate(self, i)
  154. def _measurements(self):
  155. """Return a dict {i:j} where i is the index of the wire that has
  156. been measured, and j is the gate where the wire is measured.
  157. """
  158. ismeasured = {}
  159. for i,g in enumerate(self._gates()):
  160. if getattr(g,'measurement',False):
  161. for target in g.targets:
  162. if target in ismeasured:
  163. if ismeasured[target] > i:
  164. ismeasured[target] = i
  165. else:
  166. ismeasured[target] = i
  167. return ismeasured
  168. def _finish(self):
  169. # Disable clipping to make panning work well for large circuits.
  170. for o in self._figure.findobj():
  171. o.set_clip_on(False)
  172. def one_qubit_box(self, t, gate_idx, wire_idx):
  173. """Draw a box for a single qubit gate."""
  174. x = self._gate_grid[gate_idx]
  175. y = self._wire_grid[wire_idx]
  176. self._axes.text(
  177. x, y, t,
  178. color='k',
  179. ha='center',
  180. va='center',
  181. bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
  182. size=self.fontsize
  183. )
  184. def two_qubit_box(self, t, gate_idx, wire_idx):
  185. """Draw a box for a two qubit gate. Doesn't work yet.
  186. """
  187. # x = self._gate_grid[gate_idx]
  188. # y = self._wire_grid[wire_idx]+0.5
  189. print(self._gate_grid)
  190. print(self._wire_grid)
  191. # unused:
  192. # obj = self._axes.text(
  193. # x, y, t,
  194. # color='k',
  195. # ha='center',
  196. # va='center',
  197. # bbox=dict(ec='k', fc='w', fill=True, lw=self.linewidth),
  198. # size=self.fontsize
  199. # )
  200. def control_line(self, gate_idx, min_wire, max_wire):
  201. """Draw a vertical control line."""
  202. xdata = (self._gate_grid[gate_idx], self._gate_grid[gate_idx])
  203. ydata = (self._wire_grid[min_wire], self._wire_grid[max_wire])
  204. line = Line2D(
  205. xdata, ydata,
  206. color='k',
  207. lw=self.linewidth
  208. )
  209. self._axes.add_line(line)
  210. def control_point(self, gate_idx, wire_idx):
  211. """Draw a control point."""
  212. x = self._gate_grid[gate_idx]
  213. y = self._wire_grid[wire_idx]
  214. radius = self.control_radius
  215. c = Circle(
  216. (x, y),
  217. radius*self.scale,
  218. ec='k',
  219. fc='k',
  220. fill=True,
  221. lw=self.linewidth
  222. )
  223. self._axes.add_patch(c)
  224. def not_point(self, gate_idx, wire_idx):
  225. """Draw a NOT gates as the circle with plus in the middle."""
  226. x = self._gate_grid[gate_idx]
  227. y = self._wire_grid[wire_idx]
  228. radius = self.not_radius
  229. c = Circle(
  230. (x, y),
  231. radius,
  232. ec='k',
  233. fc='w',
  234. fill=False,
  235. lw=self.linewidth
  236. )
  237. self._axes.add_patch(c)
  238. l = Line2D(
  239. (x, x), (y - radius, y + radius),
  240. color='k',
  241. lw=self.linewidth
  242. )
  243. self._axes.add_line(l)
  244. def swap_point(self, gate_idx, wire_idx):
  245. """Draw a swap point as a cross."""
  246. x = self._gate_grid[gate_idx]
  247. y = self._wire_grid[wire_idx]
  248. d = self.swap_delta
  249. l1 = Line2D(
  250. (x - d, x + d),
  251. (y - d, y + d),
  252. color='k',
  253. lw=self.linewidth
  254. )
  255. l2 = Line2D(
  256. (x - d, x + d),
  257. (y + d, y - d),
  258. color='k',
  259. lw=self.linewidth
  260. )
  261. self._axes.add_line(l1)
  262. self._axes.add_line(l2)
  263. def circuit_plot(c, nqubits, **kwargs):
  264. """Draw the circuit diagram for the circuit with nqubits.
  265. Parameters
  266. ==========
  267. c : circuit
  268. The circuit to plot. Should be a product of Gate instances.
  269. nqubits : int
  270. The number of qubits to include in the circuit. Must be at least
  271. as big as the largest `min_qubits`` of the gates.
  272. """
  273. return CircuitPlot(c, nqubits, **kwargs)
  274. def render_label(label, inits={}):
  275. """Slightly more flexible way to render labels.
  276. >>> from sympy.physics.quantum.circuitplot import render_label
  277. >>> render_label('q0')
  278. '$\\\\left|q0\\\\right\\\\rangle$'
  279. >>> render_label('q0', {'q0':'0'})
  280. '$\\\\left|q0\\\\right\\\\rangle=\\\\left|0\\\\right\\\\rangle$'
  281. """
  282. init = inits.get(label)
  283. if init:
  284. return r'$\left|%s\right\rangle=\left|%s\right\rangle$' % (label, init)
  285. return r'$\left|%s\right\rangle$' % label
  286. def labeller(n, symbol='q'):
  287. """Autogenerate labels for wires of quantum circuits.
  288. Parameters
  289. ==========
  290. n : int
  291. number of qubits in the circuit.
  292. symbol : string
  293. A character string to precede all gate labels. E.g. 'q_0', 'q_1', etc.
  294. >>> from sympy.physics.quantum.circuitplot import labeller
  295. >>> labeller(2)
  296. ['q_1', 'q_0']
  297. >>> labeller(3,'j')
  298. ['j_2', 'j_1', 'j_0']
  299. """
  300. return ['%s_%d' % (symbol,n-i-1) for i in range(n)]
  301. class Mz(OneQubitGate):
  302. """Mock-up of a z measurement gate.
  303. This is in circuitplot rather than gate.py because it's not a real
  304. gate, it just draws one.
  305. """
  306. measurement = True
  307. gate_name='Mz'
  308. gate_name_latex='M_z'
  309. class Mx(OneQubitGate):
  310. """Mock-up of an x measurement gate.
  311. This is in circuitplot rather than gate.py because it's not a real
  312. gate, it just draws one.
  313. """
  314. measurement = True
  315. gate_name='Mx'
  316. gate_name_latex='M_x'
  317. class CreateOneQubitGate(ManagedProperties):
  318. def __new__(mcl, name, latexname=None):
  319. if not latexname:
  320. latexname = name
  321. return BasicMeta.__new__(mcl, name + "Gate", (OneQubitGate,),
  322. {'gate_name': name, 'gate_name_latex': latexname})
  323. def CreateCGate(name, latexname=None):
  324. """Use a lexical closure to make a controlled gate.
  325. """
  326. if not latexname:
  327. latexname = name
  328. onequbitgate = CreateOneQubitGate(name, latexname)
  329. def ControlledGate(ctrls,target):
  330. return CGate(tuple(ctrls),onequbitgate(target))
  331. return ControlledGate