# visualization.py (10 KB)
  1. """
  2. Plotting (requires matplotlib)
  3. """
  4. from colorsys import hsv_to_rgb, hls_to_rgb
  5. from .libmp import NoConvergence
  6. from .libmp.backend import xrange
  # --- mixin carrying the plotting methods --------------------------------
class VisualizationMethods(object):
    """
    Mixin class for the plotting routines ``plot``, ``cplot``, ``splot``
    and the ``cplot`` color functions, which are defined as module-level
    functions and attached to this class at the end of the module.
    """

    # Exceptions that the plotting routines catch while sampling a
    # function. A sample that raises one of these is treated as having no
    # value: plot() breaks the curve there, cplot() paints the pixel gray.
    plot_ignore = (
        ValueError,
        ArithmeticError,
        ZeroDivisionError,
        NoConvergence,
    )
  # --- plot(): 2D graphs of real- or complex-valued functions -------------
def _plot_segments(ctx, func, grid, singularities):
    """
    Sample *func* at every point of *grid* and split the samples into the
    pieces that :func:`plot` draws as separate lines.

    Returns a list of segments. Each segment is a list of either
    ``(x, y)`` pairs (real samples) or ``(x, re, im)`` triples (samples
    with a nonzero imaginary part). A new segment starts whenever a sample
    is dropped or the samples switch between real and complex, so the
    returned list may contain empty segments.

    A sample is dropped when evaluating it raises one of
    ``ctx.plot_ignore``, when its value is NaN or has magnitude above
    1e300, or when some element of *singularities* lies in the closed
    interval between the previous grid point and this one.
    """
    segments = []
    segment = []
    in_complex = False
    for i, t in enumerate(grid):
        try:
            if i != 0:
                prev = grid[i-1]
                for sing in singularities:
                    if prev <= sing and t >= sing:
                        # Crossing a declared singularity: drop this sample
                        # so no vertical line is drawn across it.
                        raise ValueError
            value = func(t)
            if ctx.isnan(value) or abs(value) > 1e300:
                raise ValueError
            if hasattr(value, "imag") and value.imag:
                re_part = float(value.real)
                im_part = float(value.imag)
                if not in_complex:
                    # Real -> complex transition: close the current segment.
                    in_complex = True
                    segments.append(segment)
                    segment = []
                segment.append((float(t), re_part, im_part))
            else:
                if in_complex:
                    # Complex -> real transition: close the current segment.
                    in_complex = False
                    segments.append(segment)
                    segment = []
                if hasattr(value, "real"):
                    value = value.real
                segment.append((float(t), value))
        except ctx.plot_ignore:
            # Dropped sample (the ValueErrors raised above land here too,
            # since ValueError is one of the plot_ignore exceptions).
            if segment:
                segments.append(segment)
            segment = []
    if segment:
        segments.append(segment)
    return segments


def plot(ctx, f, xlim=(-5, 5), ylim=None, points=200, file=None, dpi=None,
         singularities=(), axes=None):
    r"""
    Shows a simple 2D plot of a function `f(x)` or list of functions
    `[f_0(x), f_1(x), \ldots, f_n(x)]` over a given interval
    specified by *xlim*. Some examples::

        plot(lambda x: exp(x)*li(x), [1, 4])
        plot([cos, sin], [-4, 4])
        plot([fresnels, fresnelc], [-4, 4])
        plot([sqrt, cbrt], [-4, 4])
        plot(lambda t: zeta(0.5+t*j), [-20, 20])
        plot([floor, ceil, abs, sign], [-5, 5])

    Points where the function raises a numerical exception or
    returns a NaN or infinite (magnitude above 1e300) value are removed
    from the graph. Singularities can also be excluded explicitly
    as follows (useful for removing erroneous vertical lines)::

        plot(cot, ylim=[-5, 5])   # bad
        plot(cot, ylim=[-5, 5], singularities=[-pi, 0, pi])   # good

    For parts where the function assumes complex values, the
    real part is plotted with dashes and the imaginary part
    is plotted with dots.

    Parameters:

    * *xlim* -- pair ``(a, b)``; the functions are sampled at
      ``ctx.arange(a, b, (b-a)/points)``.
    * *ylim* -- optional pair giving the y-axis limits.
    * *points* -- number of sampling steps across *xlim*.
    * *singularities* -- x values; a sample is dropped when a singularity
      lies between it and the previous sample (inclusive).
    * *axes* -- existing matplotlib axes to draw into (ignored when *file*
      is given). When omitted, a new figure is created and shown with
      ``pylab.show()``.
    * *file*, *dpi* -- if *file* is given, a new figure is created, saved
      to *file* with resolution *dpi*, and then closed.

    .. note :: This function requires matplotlib (pylab).
    """
    if file:
        # Saving to a file always uses a fresh figure.
        axes = None
    fig = None
    if not axes:
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    if not isinstance(f, (tuple, list)):
        f = [f]
    a, b = xlim
    colors = ['b', 'r', 'g', 'm', 'k']
    # The sample grid does not depend on the function, so build it once.
    grid = ctx.arange(a, b, (b-a)/float(points))
    for n, func in enumerate(f):
        c = colors[n % len(colors)]
        for segment in _plot_segments(ctx, func, grid, singularities):
            if not segment:
                continue
            xs = [s[0] for s in segment]
            ys = [s[1] for s in segment]
            if len(segment[0]) == 3:
                # Complex samples: real part dashed, imaginary part dotted.
                zs = [s[2] for s in segment]
                axes.plot(xs, ys, '--'+c, linewidth=3)
                axes.plot(xs, zs, ':'+c, linewidth=3)
            else:
                axes.plot(xs, ys, c, linewidth=3)
    axes.set_xlim([float(_) for _ in xlim])
    if ylim:
        axes.set_ylim([float(_) for _ in ylim])
    axes.set_xlabel('x')
    axes.set_ylabel('f(x)')
    axes.grid(True)
    if fig is not None:
        if file:
            pylab.savefig(file, dpi=dpi)
            # Release the figure so that repeated calls writing files do
            # not accumulate open pyplot figures.
            pylab.close(fig)
        else:
            pylab.show()
  # --- default color scheme for cplot() ----------------------------------
def default_color_function(ctx, z):
    """
    Default *color* function of :func:`cplot`: map the complex value *z*
    to an RGB triple computed with :func:`colorsys.hls_to_rgb`.

    The hue encodes the phase: it equals ``arg(z)/(2*pi)`` reduced modulo 1.
    The lightness is ``1 - 1/(1 + |z|**0.3)``, which is 0 at ``z = 0`` and
    increases towards 1 as ``|z|`` grows. The saturation is fixed at 0.8.
    Infinite values are drawn white and NaN values gray.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Convert pi once so the hue arithmetic stays in machine floats and
    # hls_to_rgb receives plain floats (cplot documents float RGB output).
    pi = float(ctx.pi)
    hue = (float(ctx.arg(z)) + pi) / (2*pi)
    hue = (hue + 0.5) % 1.0
    lightness = 1.0 - float(1/(1.0+abs(z)**0.3))
    return hls_to_rgb(hue, lightness, 0.8)
  # --- phase-only color scheme for cplot() --------------------------------
# Piecewise-linear color map used by phase_color_function(). Each entry is
# (w, (r, g, b)) with w = arg(z)/pi; entries must be sorted by w. The last
# entry (w = 2.0) is a sentinel so that w = 1.0 still has an upper bracket.
blue_orange_colors = [
    (-1.0,  (0.0, 0.0, 0.0)),
    (-0.95, (0.1, 0.2, 0.5)),   # dark blue
    (-0.5,  (0.0, 0.5, 1.0)),   # blueish
    (-0.05, (0.4, 0.8, 0.8)),   # cyanish
    ( 0.0,  (1.0, 1.0, 1.0)),
    ( 0.05, (1.0, 0.9, 0.3)),   # yellowish
    ( 0.5,  (0.9, 0.5, 0.0)),   # orangeish
    ( 0.95, (0.7, 0.1, 0.0)),   # redish
    ( 1.0,  (0.0, 0.0, 0.0)),
    ( 2.0,  (0.0, 0.0, 0.0)),
]


def phase_color_function(ctx, z):
    """
    Phase-only *color* function of :func:`cplot`.

    The color depends only on ``w = arg(z)/pi`` (clamped to [-1, 1]) and is
    linearly interpolated in ``blue_orange_colors``. Positive reals are
    white, the negative real axis is black, ``arg(z) > 0`` gives yellow to
    orange to red shades, and ``arg(z) < 0`` gives cyan to blue shades.
    Infinite values are drawn white and NaN values gray.
    """
    if ctx.isinf(z):
        return (1.0, 1.0, 1.0)
    if ctx.isnan(z):
        return (0.5, 0.5, 0.5)
    # Use the exact float value of pi rather than a truncated literal, so
    # that arg(z) == pi gives w == 1.0 exactly (pure black).
    w = float(ctx.arg(z)) / float(ctx.pi)
    w = max(min(w, 1.0), -1.0)
    # Interpolate inside the first consecutive pair whose upper key exceeds
    # w; the 2.0 sentinel guarantees a match for every w <= 1.
    for (lo, lo_rgb), (hi, hi_rgb) in zip(blue_orange_colors,
                                          blue_orange_colors[1:]):
        if hi > w:
            s = (w - lo) / (hi - lo)
            return tuple(a + (b - a)*s for a, b in zip(lo_rgb, hi_rgb))
  # --- cplot(): domain coloring of complex functions ----------------------
def cplot(ctx, f, re=(-5, 5), im=(-5, 5), points=2000, color=None,
          verbose=False, file=None, dpi=None, axes=None):
    """
    Plots the given complex-valued function *f* over a rectangular part
    of the complex plane specified by the pairs of intervals *re* and *im*.
    For example::

        cplot(lambda z: z, [-2, 2], [-10, 10])
        cplot(exp)
        cplot(zeta, [0, 1], [0, 50])

    By default, the complex argument (phase) is shown as color (hue) and
    the magnitude is shown as brightness. You can also supply a
    custom color function (*color*). This function should take a
    complex number as input and return an RGB 3-tuple containing
    floats in the range 0.0-1.0.

    Alternatively, you can select a builtin color function by passing
    a string as *color*:

      * "default" - default color scheme
      * "phase" - a color scheme that only renders the phase of the function,
        with white for positive reals, black for negative reals, gold in the
        upper half plane, and blue in the lower half plane.

    *f* is evaluated on a grid of ``N`` rows (along the imaginary axis) by
    ``M`` columns (along the real axis), with ``M*N`` roughly equal to
    *points* and ``M/N`` close to the width/height ratio of the rectangle,
    so the grid spacing is about the same in both directions. Grid points
    where *f* or the color function raises one of ``ctx.plot_ignore`` are
    drawn gray.

    If *axes* is given (and *file* is not), the image is drawn into it and
    nothing is shown. Otherwise a new figure is created and shown with
    ``pylab.show()``; if *file* is given, the new figure is saved to *file*
    with resolution *dpi* and then closed.

    To obtain a sharp image, the number of points may need to be
    increased to 100,000 or thereabout. Since evaluating the
    function that many times is likely to be slow, the 'verbose'
    option is useful to display progress (one line per grid row).

    .. note :: This function requires matplotlib (pylab).
    """
    import numpy
    if color is None or color == "default":
        color = ctx.default_color_function
    if color == "phase":
        color = ctx.phase_color_function
    if file:
        # Saving to a file always uses a fresh figure.
        axes = None
    fig = None
    if not axes:
        # Import pylab only when a figure must be created, as plot() does.
        import pylab
        fig = pylab.figure()
        axes = fig.add_subplot(111)
    rea, reb = re
    ima, imb = im
    dre = reb - rea
    dim = imb - ima
    # Split the point budget between the two axes in proportion to the
    # rectangle's sides.
    M = int(ctx.sqrt(points*dre/dim)+1)
    N = int(ctx.sqrt(points*dim/dre)+1)
    x = numpy.linspace(rea, reb, M)
    y = numpy.linspace(ima, imb, N)
    # Note: we have to be careful to get the right rotation.
    # Row n of the image corresponds to y[n]; imshow(origin='lower') puts
    # row 0 at the bottom. Test with these plots:
    # cplot(lambda z: z if z.real < 0 else 0)
    # cplot(lambda z: z if z.imag < 0 else 0)
    w = numpy.zeros((N, M, 3))
    for n in range(N):
        for m in range(M):
            z = ctx.mpc(x[m], y[n])
            try:
                v = color(f(z))
            except ctx.plot_ignore:
                v = (0.5, 0.5, 0.5)
            w[n,m] = v
        if verbose:
            print(str(n) + ' of ' + str(N))
    rea, reb, ima, imb = [float(_) for _ in [rea, reb, ima, imb]]
    axes.imshow(w, extent=(rea, reb, ima, imb), origin='lower')
    axes.set_xlabel('Re(z)')
    axes.set_ylabel('Im(z)')
    if fig is not None:
        if file:
            pylab.savefig(file, dpi=dpi)
            # Release the figure so repeated saves do not accumulate open
            # pyplot figures.
            pylab.close(fig)
        else:
            pylab.show()
  # --- splot(): 3D surface plots -------------------------------------------
def splot(ctx, f, u=(-5, 5), v=(-5, 5), points=100, keep_aspect=True,
          wireframe=False, file=None, dpi=None, axes=None):
    """
    Plots the surface defined by `f`.

    If `f` returns a single component, then this plots the surface
    defined by `z = f(x,y)` over the rectangular domain with
    `x = u` and `y = v`.

    If `f` returns three components, then this plots the parametric
    surface `x, y, z = f(u,v)` over the pairs of intervals `u` and `v`.

    For example, to plot a simple function::

        >>> from mpmath import *
        >>> f = lambda x, y: sin(x+y)*cos(y)
        >>> splot(f, [-pi,pi], [-pi,pi])    # doctest: +SKIP

    Plotting a donut::

        >>> r, R = 1, 2.5
        >>> f = lambda u, v: [r*cos(u), (R+r*sin(u))*cos(v), (R+r*sin(u))*sin(v)]
        >>> splot(f, [0, 2*pi], [0, 2*pi])    # doctest: +SKIP

    *points* is either one sample count used for both parameters or a pair
    ``(M, N)`` of counts along *u* and *v*. With *wireframe* the surface is
    drawn as a wireframe instead of a solid surface. With *keep_aspect*,
    every coordinate whose data range is shorter than the longest one gets
    its axis limits widened symmetrically to that longest range.

    If *axes* (3D axes) is given and *file* is not, the surface is drawn
    into it and nothing is shown. Otherwise a new figure is created and
    shown with ``pylab.show()``; if *file* is given, the new figure is
    saved to *file* with resolution *dpi* and then closed.

    .. note :: This function requires matplotlib (pylab) with the
       mplot3d toolkit.
    """
    import numpy
    if file:
        # Saving to a file always uses a fresh figure.
        axes = None
    fig = None
    if not axes:
        import pylab
        # Importing mplot3d registers the '3d' projection with matplotlib.
        import mpl_toolkits.mplot3d  # noqa: F401
        fig = pylab.figure()
        # add_subplot attaches the 3D axes to the figure; constructing
        # Axes3D(fig) directly no longer does that in recent matplotlib.
        axes = fig.add_subplot(111, projection='3d')
    ua, ub = u
    va, vb = v
    if not isinstance(points, (list, tuple)):
        points = [points, points]
    M, N = points
    us = numpy.linspace(ua, ub, M)
    vs = numpy.linspace(va, vb, N)
    x, y, z = [numpy.zeros((M, N)) for _ in range(3)]
    # Running [min, max] of each coordinate. Start from an empty box rather
    # than [0, 0], which would force the origin into every bounding box.
    inf = float('inf')
    bounds = [[inf, -inf] for _ in range(3)]
    for n in range(N):
        for m in range(M):
            fdata = f(ctx.convert(us[m]), ctx.convert(vs[n]))
            try:
                x[m,n], y[m,n], z[m,n] = fdata
            except TypeError:
                # fdata is a single (non-iterable) value: height z = f(x, y).
                x[m,n], y[m,n], z[m,n] = us[m], vs[n], fdata
            for c, box in zip((x[m,n], y[m,n], z[m,n]), bounds):
                if c < box[0]:
                    box[0] = c
                if c > box[1]:
                    box[1] = c
    if wireframe:
        axes.plot_wireframe(x, y, z, rstride=4, cstride=4)
    else:
        axes.plot_surface(x, y, z, rstride=4, cstride=4)
    axes.set_xlabel('x')
    axes.set_ylabel('y')
    axes.set_zlabel('z')
    # Skip when some coordinate recorded no comparable sample (e.g. zero
    # points, or only NaN values), since its box would still be empty.
    if keep_aspect and all(lo <= hi for lo, hi in bounds):
        spans = [hi - lo for lo, hi in bounds]
        maxd = max(spans)
        setters = ('set_xlim3d', 'set_ylim3d', 'set_zlim3d')
        for (lo, hi), span, setter in zip(bounds, spans, setters):
            if span < maxd:
                delta = maxd - span
                getattr(axes, setter)(float(lo - delta / 2.0),
                                      float(hi + delta / 2.0))
    if fig is not None:
        if file:
            pylab.savefig(file, dpi=dpi)
            # Release the figure so repeated saves do not accumulate open
            # pyplot figures.
            pylab.close(fig)
        else:
            pylab.show()
  # --- expose the plotting functions as methods of the mixin --------------
# Each function is attached under its own __name__ (plot, cplot, ...), in
# the same order as before; the loop variable is removed afterwards so it
# does not linger in the module namespace.
for _method in (plot, default_color_function, phase_color_function,
                cplot, splot):
    setattr(VisualizationMethods, _method.__name__, _method)
del _method