"""Plotting module for SymPy.

A plot is represented by the ``Plot`` class that contains a reference to the
backend and a list of the data series to be plotted. The data series are
instances of classes meant to simplify getting points and meshes from SymPy
expressions. ``plot_backends`` is a dictionary with all the backends.

This module provides only the essentials. For anything more elaborate, use the
backend directly. You can get the backend wrapper for every plot from the
``_backend`` attribute. Moreover, the data series classes have various useful
methods like ``get_points``, ``get_meshes``, etc., that may be useful if you
wish to use another plotting library.

This is especially relevant if you need publication-ready graphs and this
module is not enough for you: just get the ``_backend`` attribute and add
whatever you want directly to it. In the case of matplotlib (the common way to
graph data in Python), just copy ``_backend.fig``, which is the figure, and
``_backend.ax``, which is the axis, and work on them as you would on any other
matplotlib object.

Simplicity of code is valued far above performance. Don't use this module if
you care at all about performance. A new backend instance is initialized every
time you call ``show()`` and the old one is left to the garbage collector.
"""
from collections.abc import Callable

from sympy.core.containers import Tuple
from sympy.core.expr import Expr
from sympy.core.symbol import (Dummy, Symbol)
from sympy.core.sympify import sympify
from sympy.external import import_module
from sympy.core.function import arity
from sympy.utilities.iterables import is_sequence
from .experimental_lambdify import (vectorized_lambdify, lambdify)
from sympy.utilities.exceptions import sympy_deprecation_warning

# N.B.
# When changing the minimum module version for matplotlib, please change
# the same in the ``SymPyDocTestFinder`` in ``sympy/testing/runtests.py``

# Backend specific imports - textplot
from sympy.plotting.textplot import textplot

# Global variable
# Set to False when running tests / doctests so that the plots don't show.
_show = True


def unset_show():
    """
    Disable show(). For use in the tests.
    """
    global _show
    _show = False

##############################################################################
# The public interface
##############################################################################


class Plot:
    """The central class of the plotting module.

    Explanation
    ===========

    For interactive work the function ``plot`` is better suited.

    This class permits the plotting of SymPy expressions using numerous
    backends (matplotlib, textplot, the old pyglet module for SymPy, Google
    charts api, etc).

    The figure can contain an arbitrary number of plots of SymPy expressions,
    lists of coordinates of points, etc. Plot has a private attribute
    ``_series`` that contains all data series to be plotted (expressions for
    lines or surfaces, lists of points, etc.), all of which are subclasses of
    ``BaseSeries``. Those data series are instances of classes not imported by
    ``from sympy import *``.

    The customization of the figure is on two levels: global options that
    concern the figure as a whole (e.g. title, xlabel, scale) and
    per-data-series options (e.g. name) and aesthetics (e.g. color, point
    shape, line type).

    The difference between options and aesthetics is that an aesthetic can be
    a function of the coordinates (or parameters in a parametric plot). The
    supported values for an aesthetic are:

    - None (the backend uses default values)
    - a constant
    - a function of one variable (the first coordinate or parameter)
    - a function of two variables (the first and second coordinate or
      parameters)
    - a function of three variables (only in nonparametric 3D plots)

    Their implementation depends on the backend so they may not work in some
    backends.

    If the plot is parametric and the arity of the aesthetic function permits
    it, the aesthetic is calculated over parameters and not over coordinates.
    If the arity does not permit calculation over parameters, the calculation
    is done over coordinates.
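
    The dispatch rule above can be sketched in plain NumPy. This is a minimal
    illustration under stated assumptions, not the code this module runs:
    ``segment_centers`` and ``color_values`` are hypothetical names, the arity
    is read with ``inspect.signature`` (standing in for SymPy's ``arity``), and
    colors are evaluated at the midpoints of consecutive samples, one value per
    line segment, which is assumed to mirror the ``centers_of_segments`` helper
    used by ``get_color_array``.

    ```python
    import inspect

    import numpy as np


    def segment_centers(values):
        # Hypothetical helper: midpoints between consecutive samples.
        values = np.asarray(values, dtype=float)
        return (values[:-1] + values[1:]) / 2


    def color_values(color, coords, params=None):
        # A constant aesthetic is broadcast to every point.
        if not callable(color):
            return color * np.ones(len(coords[0]))
        nargs = len(inspect.signature(color).parameters)
        f = np.vectorize(color)
        # Parametric line and a one-argument function: use the parameter.
        if nargs == 1 and params is not None:
            return f(segment_centers(params))
        # Otherwise use the first nargs coordinates.
        centers = [segment_centers(c) for c in coords]
        return f(*centers[:nargs])
    ```

    For example, with ``t = np.linspace(0, 1, 5)`` and the parametric line
    ``(t, t**2)``, ``color_values(lambda u: u, (t, t**2), params=t)`` colors
    the four segments by their parameter midpoints 0.125, 0.375, 0.625 and
    0.875.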

    Only cartesian coordinates are supported for the moment, but you can use
    the parametric plots to plot in polar, spherical and cylindrical
    coordinates.

    The arguments for the constructor Plot must be subclasses of BaseSeries.

    Any global option can be specified as a keyword argument.

    The global options for a figure are:

    - title : str
    - xlabel : str
    - ylabel : str
    - zlabel : str
    - legend : bool
    - xscale : {'linear', 'log'}
    - yscale : {'linear', 'log'}
    - axis : bool
    - axis_center : tuple of two floats or {'center', 'auto'}
    - xlim : tuple of two floats
    - ylim : tuple of two floats
    - aspect_ratio : tuple of two floats or {'auto'}
    - autoscale : bool
    - margin : float in [0, 1]
    - backend : {'default', 'matplotlib', 'text'} or a subclass of BaseBackend
    - size : optional tuple of two floats, (width, height); default: None

    There are no per-data-series options or aesthetics in the base series.
    See below for the options of the subclasses.

    Some data series support additional aesthetics or options:

    List2DSeries, LineOver1DRangeSeries, Parametric2DLineSeries and
    Parametric3DLineSeries support the following:

    Aesthetics:

    - line_color : string, or float, or function, optional
        Specifies the color for the plot, which depends on the backend being
        used.

        For example, if ``MatplotlibBackend`` is being used, then
        Matplotlib string colors are acceptable ("red", "r", "cyan", "c", ...).
        Alternatively, we can use a float number `0 < color < 1` wrapped in a
        string (for example, `line_color="0.5"`) to specify grayscale colors.
        Alternatively, we can specify a function returning a single
        float value: this will be used to apply a color-loop (for example,
        `line_color=lambda x: math.cos(x)`).

        Note that setting line_color applies it simultaneously to all the
        series.

    Options:

    - label : str
    - steps : bool
    - only_integers : bool

    SurfaceOver2DRangeSeries and ParametricSurfaceSeries support the following:

    Aesthetics:

    - surface_color : function which returns a float.
    """

    def __init__(self, *args,
            title=None, xlabel=None, ylabel=None, zlabel=None, aspect_ratio='auto',
            xlim=None, ylim=None, axis_center='auto', axis=True,
            xscale='linear', yscale='linear', legend=False, autoscale=True,
            margin=0, annotations=None, markers=None, rectangles=None,
            fill=None, backend='default', size=None, **kwargs):
        super().__init__()

        # Options for the graph as a whole.
        # The possible values for each option are described in the docstring of
        # Plot. They are based purely on convention; no checking is done.
        self.title = title
        self.xlabel = xlabel
        self.ylabel = ylabel
        self.zlabel = zlabel
        self.aspect_ratio = aspect_ratio
        self.axis_center = axis_center
        self.axis = axis
        self.xscale = xscale
        self.yscale = yscale
        self.legend = legend
        self.autoscale = autoscale
        self.margin = margin
        self.annotations = annotations
        self.markers = markers
        self.rectangles = rectangles
        self.fill = fill

        # Contains the data objects to be plotted. The backend should be smart
        # enough to iterate over this list.
        self._series = []
        self._series.extend(args)

        # The backend type. On every show() a new backend instance is created
        # in self._backend which is tightly coupled to the Plot instance
        # (thanks to the parent attribute of the backend).
        if isinstance(backend, str):
            self.backend = plot_backends[backend]
        elif isinstance(backend, type) and issubclass(backend, BaseBackend):
            self.backend = backend
        else:
            raise TypeError(
                "backend must be either a string or a subclass of BaseBackend")

        is_real = \
            lambda lim: all(getattr(i, 'is_real', True) for i in lim)
        is_finite = \
            lambda lim: all(getattr(i, 'is_finite', True) for i in lim)

        # Validate and store xlim, ylim and size without repeating code.
        def check_and_set(t_name, t):
            if t:
                if not is_real(t):
                    raise ValueError(
                        "All numbers from {}={} must be real".format(t_name, t))
                if not is_finite(t):
                    raise ValueError(
                        "All numbers from {}={} must be finite".format(t_name, t))
                setattr(self, t_name, (float(t[0]), float(t[1])))

        self.xlim = None
        check_and_set("xlim", xlim)
        self.ylim = None
        check_and_set("ylim", ylim)
        self.size = None
        check_and_set("size", size)

    def show(self):
        # TODO move this to the backend (also for save)
        if hasattr(self, '_backend'):
            self._backend.close()
        self._backend = self.backend(self)
        self._backend.show()

    def save(self, path):
        if hasattr(self, '_backend'):
            self._backend.close()
        self._backend = self.backend(self)
        self._backend.save(path)

    def __str__(self):
        series_strs = [('[%d]: ' % i) + str(s)
                       for i, s in enumerate(self._series)]
        return 'Plot object containing:\n' + '\n'.join(series_strs)

    def __getitem__(self, index):
        return self._series[index]

    def __setitem__(self, index, *args):
        # Store the series itself, not the 1-tuple of arguments.
        if len(args) == 1 and isinstance(args[0], BaseSeries):
            self._series[index] = args[0]

    def __delitem__(self, index):
        del self._series[index]

    def append(self, arg):
        """Adds an element from a plot's series to an existing plot.

        Examples
        ========

        Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
        second plot's first series object to the first, use the
        ``append`` method, like so:

        .. plot::
            :format: doctest
            :include-source: True

            >>> from sympy import symbols
            >>> from sympy.plotting import plot
            >>> x = symbols('x')
            >>> p1 = plot(x*x, show=False)
            >>> p2 = plot(x, show=False)
            >>> p1.append(p2[0])
            >>> p1
            Plot object containing:
            [0]: cartesian line: x**2 for x over (-10.0, 10.0)
            [1]: cartesian line: x for x over (-10.0, 10.0)
            >>> p1.show()

        See Also
        ========

        extend

        """
        if isinstance(arg, BaseSeries):
            self._series.append(arg)
        else:
            raise TypeError('Must specify element of plot to append.')

    def extend(self, arg):
        """Adds all series from another plot.

        Examples
        ========

        Consider two ``Plot`` objects, ``p1`` and ``p2``. To add the
        second plot to the first, use the ``extend`` method, like so:

        .. plot::
            :format: doctest
            :include-source: True

            >>> from sympy import symbols
            >>> from sympy.plotting import plot
            >>> x = symbols('x')
            >>> p1 = plot(x**2, show=False)
            >>> p2 = plot(x, -x, show=False)
            >>> p1.extend(p2)
            >>> p1
            Plot object containing:
            [0]: cartesian line: x**2 for x over (-10.0, 10.0)
            [1]: cartesian line: x for x over (-10.0, 10.0)
            [2]: cartesian line: -x for x over (-10.0, 10.0)
            >>> p1.show()

        """
        if isinstance(arg, Plot):
            self._series.extend(arg._series)
        elif is_sequence(arg):
            self._series.extend(arg)
        else:
            raise TypeError('Expecting Plot or sequence of BaseSeries')


class PlotGrid:
    """This class helps to plot subplots from already created SymPy plots
    in a single figure.

    Examples
    ========

    .. plot::
        :context: close-figs
        :format: doctest
        :include-source: True

        >>> from sympy import symbols
        >>> from sympy.plotting import plot, plot3d, PlotGrid
        >>> x, y = symbols('x, y')
        >>> p1 = plot(x, x**2, x**3, (x, -5, 5))
        >>> p2 = plot((x**2, (x, -6, 6)), (x, (x, -5, 5)))
        >>> p3 = plot(x**3, (x, -5, 5))
        >>> p4 = plot3d(x*y, (x, -5, 5), (y, -5, 5))

    Plotting vertically in a single line:

    .. plot::
        :context: close-figs
        :format: doctest
        :include-source: True

        >>> PlotGrid(2, 1, p1, p2)
        PlotGrid object containing:
        Plot[0]:Plot object containing:
        [0]: cartesian line: x for x over (-5.0, 5.0)
        [1]: cartesian line: x**2 for x over (-5.0, 5.0)
        [2]: cartesian line: x**3 for x over (-5.0, 5.0)
        Plot[1]:Plot object containing:
        [0]: cartesian line: x**2 for x over (-6.0, 6.0)
        [1]: cartesian line: x for x over (-5.0, 5.0)

    Plotting horizontally in a single line:

    .. plot::
        :context: close-figs
        :format: doctest
        :include-source: True

        >>> PlotGrid(1, 3, p2, p3, p4)
        PlotGrid object containing:
        Plot[0]:Plot object containing:
        [0]: cartesian line: x**2 for x over (-6.0, 6.0)
        [1]: cartesian line: x for x over (-5.0, 5.0)
        Plot[1]:Plot object containing:
        [0]: cartesian line: x**3 for x over (-5.0, 5.0)
        Plot[2]:Plot object containing:
        [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)

    Plotting in a grid form:

    .. plot::
        :context: close-figs
        :format: doctest
        :include-source: True

        >>> PlotGrid(2, 2, p1, p2, p3, p4)
        PlotGrid object containing:
        Plot[0]:Plot object containing:
        [0]: cartesian line: x for x over (-5.0, 5.0)
        [1]: cartesian line: x**2 for x over (-5.0, 5.0)
        [2]: cartesian line: x**3 for x over (-5.0, 5.0)
        Plot[1]:Plot object containing:
        [0]: cartesian line: x**2 for x over (-6.0, 6.0)
        [1]: cartesian line: x for x over (-5.0, 5.0)
        Plot[2]:Plot object containing:
        [0]: cartesian line: x**3 for x over (-5.0, 5.0)
        Plot[3]:Plot object containing:
        [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)

    """

    def __init__(self, nrows, ncolumns, *args, show=True, size=None, **kwargs):
        """
        Parameters
        ==========

        nrows :
            The number of rows that should be in the grid of the
            required subplot.
        ncolumns :
            The number of columns that should be in the grid
            of the required subplot.

        nrows and ncolumns together define the required grid.

        Arguments
        =========

        A list of predefined plot objects entered in a row-wise sequence,
        i.e. plot objects which are to be in the top row of the required
        grid are written first, then the second row objects, and so on.
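
        As a sketch of this ordering, assuming the backend fills the grid
        row-major the way matplotlib's ``add_subplot`` numbering does, the
        0-based position of a plot in the argument list maps to a
        (row, column) cell as follows (``grid_position`` is a hypothetical
        helper, not part of this module):

        ```python
        def grid_position(index, ncolumns):
            # Row-wise input order: fill each row left to right before
            # moving on to the next row.
            return divmod(index, ncolumns)


        # PlotGrid(2, 2, p1, p2, p3, p4): p1, p2 on the top row, p3, p4 below.
        positions = [grid_position(i, 2) for i in range(4)]
        ```

        Here ``positions`` is ``[(0, 0), (0, 1), (1, 0), (1, 1)]``.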

        Keyword arguments
        =================

        show : Boolean
            The default value is set to ``True``. Set show to ``False`` and
            the function will not display the subplot. The returned instance
            of the ``PlotGrid`` class can then be used to save or display the
            plot by calling the ``save()`` and ``show()`` methods
            respectively.
        size : (float, float), optional
            A tuple in the form (width, height) in inches to specify the size of
            the overall figure. The default value is set to ``None``, meaning
            the size will be set by the default backend.
        """
        self.nrows = nrows
        self.ncolumns = ncolumns
        self._series = []
        self.args = args
        for arg in args:
            self._series.append(arg._series)
        self.backend = DefaultBackend
        self.size = size
        if show:
            self.show()

    def show(self):
        if hasattr(self, '_backend'):
            self._backend.close()
        self._backend = self.backend(self)
        self._backend.show()

    def save(self, path):
        if hasattr(self, '_backend'):
            self._backend.close()
        self._backend = self.backend(self)
        self._backend.save(path)

    def __str__(self):
        plot_strs = [('Plot[%d]:' % i) + str(plot)
                      for i, plot in enumerate(self.args)]

        return 'PlotGrid object containing:\n' + '\n'.join(plot_strs)


##############################################################################
# Data Series
##############################################################################
#TODO more general way to calculate aesthetics (see get_color_array)

### The base class for all series
class BaseSeries:
    """Base class for the data objects containing stuff to be plotted.

    Explanation
    ===========

    The backend should check if it supports the data series that it is given
    (e.g. TextBackend supports only LineOver1DRangeSeries).
    It is the backend's responsibility to know how to use the class of
    data series that it is given.

    Some data series classes are grouped (using a class attribute like
    is_2Dline) according to the api they present (based only on convention).
    The backend is not obliged to use that api (e.g. LineOver1DRangeSeries
    belongs to the is_2Dline group and presents the get_points method, but the
    TextBackend does not use the get_points method).
    """

    # Some flags follow. The rationale for using flags instead of checking base
    # classes is that setting multiple flags is simpler than multiple
    # inheritance.

    is_2Dline = False
    # Some of the backends expect:
    #  - get_points returning 1D np.arrays list_x, list_y
    #  - get_color_array returning 1D np.array (done in Line2DBaseSeries)
    # with the colors calculated at the points from get_points

    is_3Dline = False
    # Some of the backends expect:
    #  - get_points returning 1D np.arrays list_x, list_y, list_z
    #  - get_color_array returning 1D np.array (done in Line2DBaseSeries)
    # with the colors calculated at the points from get_points

    is_3Dsurface = False
    # Some of the backends expect:
    #   - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays)
    #   - get_points an alias for get_meshes

    is_contour = False
    # Some of the backends expect:
    #   - get_meshes returning mesh_x, mesh_y, mesh_z (2D np.arrays)
    #   - get_points an alias for get_meshes

    is_implicit = False
    # Some of the backends expect:
    #   - get_meshes returning mesh_x (1D array), mesh_y (1D array),
    #     mesh_z (2D np.arrays)
    #   - get_points an alias for get_meshes
    # Different from is_contour as the colormap in backend will be
    # different

    is_parametric = False
    # The calculation of aesthetics expects:
    #   - get_parameter_points returning one or two np.arrays (1D or 2D)
    # used for calculating aesthetics

    def __init__(self):
        super().__init__()

    @property
    def is_3D(self):
        flags3D = [
            self.is_3Dline,
            self.is_3Dsurface
        ]
        return any(flags3D)

    @property
    def is_line(self):
        flagslines = [
            self.is_2Dline,
            self.is_3Dline
        ]
        return any(flagslines)


### 2D lines
class Line2DBaseSeries(BaseSeries):
    """A base class for 2D lines.

    - adding the label, steps and only_integers options
    - making is_2Dline true
    - defining get_segments and get_color_array
    """

    is_2Dline = True

    _dim = 2

    def __init__(self):
        super().__init__()
        self.label = None
        self.steps = False
        self.only_integers = False
        self.line_color = None

    def get_data(self):
        """ Return lists of coordinates for plotting the line.

        Returns
        =======

        x: list
            List of x-coordinates

        y: list
            List of y-coordinates

        z: list
            List of z-coordinates in case of Parametric3DLineSeries
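
        When ``steps`` is True, the 2D branch of this method duplicates every
        sample and offsets the copies by one so that the line moves
        horizontally, then vertically. A minimal equivalent sketch using
        ``np.repeat``, which yields the same sequence as
        ``np.array((a, a)).T.flatten()`` for a 1D array ``a``
        (``step_points`` is a hypothetical name, not part of this module):

        ```python
        import numpy as np


        def step_points(x, y):
            # [x0, x0, x1, x1, ...] and [y0, y0, y1, y1, ...]
            xs = np.repeat(np.asarray(x), 2)
            ys = np.repeat(np.asarray(y), 2)
            # Dropping the first x and the last y pairs the samples as
            # (x0, y0), (x1, y0), (x1, y1), (x2, y1), ...
            return xs[1:], ys[:-1]
        ```

        For instance, ``step_points([0, 1, 2], [10, 20, 30])`` returns the x
        values ``[0, 1, 1, 2, 2]`` and the y values ``[10, 10, 20, 20, 30]``.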
        """
        np = import_module('numpy')
        points = self.get_points()
        if self.steps is True:
            if len(points) == 2:
                x = np.array((points[0], points[0])).T.flatten()[1:]
                y = np.array((points[1], points[1])).T.flatten()[:-1]
                points = (x, y)
            else:
                x = np.repeat(points[0], 3)[2:]
                y = np.repeat(points[1], 3)[:-2]
                z = np.repeat(points[2], 3)[1:-1]
                points = (x, y, z)
        return points

    def get_segments(self):
        sympy_deprecation_warning(
            """
            The Line2DBaseSeries.get_segments() method is deprecated.

            Instead, use the MatplotlibBackend.get_segments() method, or use
            the get_points() or get_data() methods.
            """,
            deprecated_since_version="1.9",
            active_deprecations_target="deprecated-get-segments")

        np = import_module('numpy')
        points = type(self).get_data(self)
        points = np.ma.array(points).T.reshape(-1, 1, self._dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    def get_color_array(self):
        np = import_module('numpy')
        c = self.line_color
        if hasattr(c, '__call__'):
            f = np.vectorize(c)
            nargs = arity(c)
            if nargs == 1 and self.is_parametric:
                x = self.get_parameter_points()
                return f(centers_of_segments(x))
            else:
                variables = list(map(centers_of_segments, self.get_points()))
                if nargs == 1:
                    return f(variables[0])
                elif nargs == 2:
                    return f(*variables[:2])
                else:  # only if the line is 3D (otherwise raises an error)
                    return f(*variables)
        else:
            return c*np.ones(self.nb_of_points)


class List2DSeries(Line2DBaseSeries):
    """Representation for a line consisting of a list of points."""

    def __init__(self, list_x, list_y):
        np = import_module('numpy')
        super().__init__()
        self.list_x = np.array(list_x)
        self.list_y = np.array(list_y)
        self.label = 'list'

    def __str__(self):
        return 'list plot'

    def get_points(self):
        return (self.list_x, self.list_y)


class LineOver1DRangeSeries(Line2DBaseSeries):
    """Representation for a line consisting of a SymPy expression over a range."""

    def __init__(self, expr, var_start_end, **kwargs):
        super().__init__()
        self.expr = sympify(expr)
        self.label = kwargs.get('label', None) or str(self.expr)
        self.var = sympify(var_start_end[0])
        self.start = float(var_start_end[1])
        self.end = float(var_start_end[2])
        self.nb_of_points = kwargs.get('nb_of_points', 300)
        self.adaptive = kwargs.get('adaptive', True)
        self.depth = kwargs.get('depth', 12)
        self.line_color = kwargs.get('line_color', None)
        self.xscale = kwargs.get('xscale', 'linear')

    def __str__(self):
        return 'cartesian line: %s for %s over %s' % (
            str(self.expr), str(self.var), str((self.start, self.end)))

    def get_points(self):
        """ Return lists of coordinates for plotting. Depending on the
        `adaptive` option, this function will either use an adaptive algorithm
        or it will uniformly sample the expression over the provided range.

        Returns
        =======

        x: list
            List of x-coordinates

        y: list
            List of y-coordinates

        Explanation
        ===========

        The adaptive sampling is done by recursively checking if three
        points are almost collinear. If they are not collinear, then more
        points are added between those points.
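
        A simplified, pure-Python sketch of this recursion, assuming a
        real-valued ``f`` on a linear x scale; the handling of complex
        (``None``) values and of ``xscale='log'`` is omitted.
        ``adaptive_sample`` and ``is_flat`` are hypothetical names: ``is_flat``
        stands in for the module's ``flat`` helper (not shown here) and checks
        whether the angle at the middle point is close to 180 degrees. A seeded
        ``random.Random`` replaces ``np.random.rand`` so the result is
        reproducible.

        ```python
        import math
        import random


        def is_flat(a, b, c, eps=1e-3):
            # Cosine of the angle at b between b->a and b->c; it is close to
            # -1 when the three points are (almost) collinear.
            ux, uy = a[0] - b[0], a[1] - b[1]
            vx, vy = c[0] - b[0], c[1] - b[1]
            norm = math.hypot(ux, uy) * math.hypot(vx, vy)
            if norm == 0:
                return True
            return abs((ux * vx + uy * vy) / norm + 1) < eps


        def adaptive_sample(f, start, end, min_depth=6, max_depth=12, seed=0):
            rng = random.Random(seed)
            xs, ys = [start], [f(start)]

            def sample(p, q, depth):
                # Split near the middle, with jitter to avoid aliasing.
                t = 0.45 + 0.1 * rng.random()
                xm = p[0] + t * (q[0] - p[0])
                m = (xm, f(xm))
                # Stop past max_depth, or once past min_depth if flat;
                # otherwise subdivide both halves.
                if depth > max_depth or (depth >= min_depth and is_flat(p, m, q)):
                    xs.append(q[0])
                    ys.append(q[1])
                else:
                    sample(p, m, depth + 1)
                    sample(m, q, depth + 1)

            sample((start, f(start)), (end, f(end)), 0)
            return xs, ys
        ```

        With ``f(x) = 2*x + 1`` every triple is collinear, so the recursion
        stops at the minimum depth of 6 and returns 1 + 2**6 = 65 points; a
        curved function such as ``sin(10*x)`` triggers further subdivision.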

        References
        ==========

        .. [1] Adaptive polygonal approximation of parametric curves,
            Luiz Henrique de Figueiredo.

        """
        if self.only_integers or not self.adaptive:
            return self._uniform_sampling()
        else:
            f = lambdify([self.var], self.expr)
            x_coords = []
            y_coords = []
            np = import_module('numpy')

            def sample(p, q, depth):
                """ Samples recursively if three points are almost collinear.
                For depth < 6, points are added irrespective of whether they
                satisfy the collinearity condition or not. The maximum depth
                is set by the ``depth`` option (12 by default).
                """
                # Randomly sample to avoid aliasing.
                random = 0.45 + np.random.rand() * 0.1
                if self.xscale == 'log':
                    xnew = 10**(np.log10(p[0]) + random * (np.log10(q[0]) -
                                                           np.log10(p[0])))
                else:
                    xnew = p[0] + random * (q[0] - p[0])
                ynew = f(xnew)
                new_point = np.array([xnew, ynew])

                # Maximum depth
                if depth > self.depth:
                    x_coords.append(q[0])
                    y_coords.append(q[1])

                # Sample irrespective of whether the line is flat till the
                # depth of 6. We are not using linspace to avoid aliasing.
                elif depth < 6:
                    sample(p, new_point, depth + 1)
                    sample(new_point, q, depth + 1)

                # Sample ten points if complex values are encountered
                # at both ends. If there is a real value in between, then
                # sample those points further.
                elif p[1] is None and q[1] is None:
                    if self.xscale == 'log':
                        # p[0] and q[0] are x values, not exponents.
                        xarray = np.logspace(np.log10(p[0]), np.log10(q[0]), 10)
                    else:
                        xarray = np.linspace(p[0], q[0], 10)
                    yarray = list(map(f, xarray))
                    if not all(y is None for y in yarray):
                        for i in range(len(yarray) - 1):
                            if not (yarray[i] is None and yarray[i + 1] is None):
                                sample([xarray[i], yarray[i]],
                                       [xarray[i + 1], yarray[i + 1]], depth + 1)

                # Sample further if one of the end points is None (i.e. a
                # complex value) or the three points are not almost collinear.
                elif (p[1] is None or q[1] is None or new_point[1] is None
                        or not flat(p, new_point, q)):
                    sample(p, new_point, depth + 1)
                    sample(new_point, q, depth + 1)
                else:
                    x_coords.append(q[0])
                    y_coords.append(q[1])

            f_start = f(self.start)
            f_end = f(self.end)
            x_coords.append(self.start)
            y_coords.append(f_start)
            sample(np.array([self.start, f_start]),
                   np.array([self.end, f_end]), 0)

            return (x_coords, y_coords)

    def _uniform_sampling(self):
        np = import_module('numpy')
        if self.only_integers is True:
            if self.xscale == 'log':
                list_x = np.logspace(int(self.start), int(self.end),
                        num=int(self.end) - int(self.start) + 1)
            else:
                list_x = np.linspace(int(self.start), int(self.end),
                    num=int(self.end) - int(self.start) + 1)
        else:
            if self.xscale == 'log':
                # self.start and self.end are x values, not exponents.
                list_x = np.logspace(np.log10(self.start), np.log10(self.end),
                                     num=self.nb_of_points)
            else:
                list_x = np.linspace(self.start, self.end, num=self.nb_of_points)
        f = vectorized_lambdify([self.var], self.expr)
        list_y = f(list_x)
        return (list_x, list_y)
  647. class Parametric2DLineSeries(Line2DBaseSeries):
  648. """Representation for a line consisting of two parametric SymPy expressions
  649. over a range."""
  650. is_parametric = True
  651. def __init__(self, expr_x, expr_y, var_start_end, **kwargs):
  652. super().__init__()
  653. self.expr_x = sympify(expr_x)
  654. self.expr_y = sympify(expr_y)
  655. self.label = kwargs.get('label', None) or \
  656. "(%s, %s)" % (str(self.expr_x), str(self.expr_y))
  657. self.var = sympify(var_start_end[0])
  658. self.start = float(var_start_end[1])
  659. self.end = float(var_start_end[2])
  660. self.nb_of_points = kwargs.get('nb_of_points', 300)
  661. self.adaptive = kwargs.get('adaptive', True)
  662. self.depth = kwargs.get('depth', 12)
  663. self.line_color = kwargs.get('line_color', None)
  664. def __str__(self):
  665. return 'parametric cartesian line: (%s, %s) for %s over %s' % (
  666. str(self.expr_x), str(self.expr_y), str(self.var),
  667. str((self.start, self.end)))
  668. def get_parameter_points(self):
  669. np = import_module('numpy')
  670. return np.linspace(self.start, self.end, num=self.nb_of_points)
  671. def _uniform_sampling(self):
  672. param = self.get_parameter_points()
  673. fx = vectorized_lambdify([self.var], self.expr_x)
  674. fy = vectorized_lambdify([self.var], self.expr_y)
  675. list_x = fx(param)
  676. list_y = fy(param)
  677. return (list_x, list_y)
  678. def get_points(self):
  679. """ Return lists of coordinates for plotting. Depending on the
  680. `adaptive` option, this function will either use an adaptive algorithm
  681. or it will uniformly sample the expression over the provided range.
  682. Returns
  683. =======
  684. x: list
  685. List of x-coordinates
  686. y: list
  687. List of y-coordinates
  688. Explanation
  689. ===========
  690. The adaptive sampling is done by recursively checking if three
  691. points are almost collinear. If they are not collinear, then more
  692. points are added between those points.
  693. References
  694. ==========
  695. .. [1] Adaptive polygonal approximation of parametric curves,
  696. Luiz Henrique de Figueiredo.
  697. """
  698. if not self.adaptive:
  699. return self._uniform_sampling()
  700. f_x = lambdify([self.var], self.expr_x)
  701. f_y = lambdify([self.var], self.expr_y)
  702. x_coords = []
  703. y_coords = []
  704. def sample(param_p, param_q, p, q, depth):
  705. """ Samples recursively if three points are almost collinear.
  706. For depth < 6, points are added irrespective of whether they
  707. satisfy the collinearity condition or not. The maximum depth
  708. allowed is 12.
  709. """
  710. # Randomly sample to avoid aliasing.
  711. np = import_module('numpy')
  712. random = 0.45 + np.random.rand() * 0.1
            param_new = param_p + random * (param_q - param_p)
            xnew = f_x(param_new)
            ynew = f_y(param_new)
            new_point = np.array([xnew, ynew])

            # Maximum depth
            if depth > self.depth:
                x_coords.append(q[0])
                y_coords.append(q[1])

            # Sample irrespective of whether the line is flat till the
            # depth of 6. We are not using linspace to avoid aliasing.
            elif depth < 6:
                sample(param_p, param_new, p, new_point, depth + 1)
                sample(param_new, param_q, new_point, q, depth + 1)

            # Sample ten points if complex values are encountered
            # at both ends. If there is a real value in between, then
            # sample those points further.
            elif ((p[0] is None and q[0] is None) or
                    (p[1] is None and q[1] is None)):
                param_array = np.linspace(param_p, param_q, 10)
                x_array = list(map(f_x, param_array))
                y_array = list(map(f_y, param_array))
                if not all(x is None and y is None
                           for x, y in zip(x_array, y_array)):
                    for i in range(len(y_array) - 1):
                        if ((x_array[i] is not None and y_array[i] is not None) or
                                (x_array[i + 1] is not None and y_array[i + 1] is not None)):
                            point_a = [x_array[i], y_array[i]]
                            point_b = [x_array[i + 1], y_array[i + 1]]
                            # Refine the sub-interval [param_array[i], param_array[i + 1]]
                            sample(param_array[i], param_array[i + 1], point_a,
                                   point_b, depth + 1)

            # Sample further if one of the end points is None (i.e. a complex
            # value) or the three points are not almost collinear.
            elif (p[0] is None or p[1] is None
                    or q[1] is None or q[0] is None
                    or not flat(p, new_point, q)):
                sample(param_p, param_new, p, new_point, depth + 1)
                sample(param_new, param_q, new_point, q, depth + 1)
            else:
                x_coords.append(q[0])
                y_coords.append(q[1])

        f_start_x = f_x(self.start)
        f_start_y = f_y(self.start)
        start = [f_start_x, f_start_y]
        f_end_x = f_x(self.end)
        f_end_y = f_y(self.end)
        end = [f_end_x, f_end_y]
        x_coords.append(f_start_x)
        y_coords.append(f_start_y)
        sample(self.start, self.end, start, end, 0)

        return x_coords, y_coords
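The nested `sample` routine of the 2D parametric series subdivides the parameter interval, always refining down to depth 6 and then only where the three points are not almost collinear. The following is a minimal pure-Python sketch of that idea. It rests on a few assumptions: it splits at the exact midpoint rather than at a random point near it, so the output is reproducible; it skips the complex-value (`None`) handling; and `adaptive_parametric` and `almost_collinear` are hypothetical names, not SymPy API.

```python
import math


def almost_collinear(p, m, q, eps=1e-3):
    """True if the angle p-m-q is close to 180 degrees (the test ``flat`` uses)."""
    ax, ay = p[0] - m[0], p[1] - m[1]
    bx, by = q[0] - m[0], q[1] - m[1]
    norm = math.hypot(ax, ay) * math.hypot(bx, by)
    if norm == 0:
        return True  # assumption: coincident points count as flat here
    cos_theta = (ax * bx + ay * by) / norm
    return abs(cos_theta + 1) < eps


def adaptive_parametric(fx, fy, start, end, max_depth=12, min_depth=6):
    """Adaptively sample the curve (fx(t), fy(t)) for t in [start, end]."""
    xs, ys = [fx(start)], [fy(start)]

    def sample(t_p, t_q, p, q, depth):
        t_new = t_p + 0.5 * (t_q - t_p)  # fixed midpoint (assumption)
        m = (fx(t_new), fy(t_new))
        if depth > max_depth:
            xs.append(q[0])
            ys.append(q[1])
        elif depth < min_depth or not almost_collinear(p, m, q):
            # Left half first, so the points stay ordered in t.
            sample(t_p, t_new, p, m, depth + 1)
            sample(t_new, t_q, m, q, depth + 1)
        else:
            xs.append(q[0])
            ys.append(q[1])

    sample(start, end, (fx(start), fy(start)), (fx(end), fy(end)), 0)
    return xs, ys
```

A straight line never fails the flatness test, so it stops at the forced depth of 6 (64 segments, 65 points). A unit circle needs one extra level, because at depth 6 the deviation $1 - \cos(\pi/64) \approx 0.0012$ still exceeds the tolerance.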


### 3D lines
class Line3DBaseSeries(Line2DBaseSeries):
    """A base class for 3D lines.

    Most of the functionality is inherited from Line2DBaseSeries."""

    is_2Dline = False
    is_3Dline = True
    _dim = 3

    def __init__(self):
        super().__init__()


class Parametric3DLineSeries(Line3DBaseSeries):
    """Representation for a 3D line consisting of three parametric SymPy
    expressions and a range."""

    is_parametric = True

    def __init__(self, expr_x, expr_y, expr_z, var_start_end, **kwargs):
        super().__init__()
        self.expr_x = sympify(expr_x)
        self.expr_y = sympify(expr_y)
        self.expr_z = sympify(expr_z)
        # The default label lists all three components.
        self.label = kwargs.get('label', None) or \
            "(%s, %s, %s)" % (str(self.expr_x), str(self.expr_y),
                              str(self.expr_z))
        self.var = sympify(var_start_end[0])
        self.start = float(var_start_end[1])
        self.end = float(var_start_end[2])
        self.nb_of_points = kwargs.get('nb_of_points', 300)
        self.line_color = kwargs.get('line_color', None)
        self._xlim = None
        self._ylim = None
        self._zlim = None

    def __str__(self):
        return '3D parametric cartesian line: (%s, %s, %s) for %s over %s' % (
            str(self.expr_x), str(self.expr_y), str(self.expr_z),
            str(self.var), str((self.start, self.end)))

    def get_parameter_points(self):
        np = import_module('numpy')
        return np.linspace(self.start, self.end, num=self.nb_of_points)

    def get_points(self):
        np = import_module('numpy')
        param = self.get_parameter_points()

        fx = vectorized_lambdify([self.var], self.expr_x)
        fy = vectorized_lambdify([self.var], self.expr_y)
        fz = vectorized_lambdify([self.var], self.expr_z)

        list_x = fx(param)
        list_y = fy(param)
        list_z = fz(param)

        list_x = np.array(list_x, dtype=np.float64)
        list_y = np.array(list_y, dtype=np.float64)
        list_z = np.array(list_z, dtype=np.float64)

        list_x = np.ma.masked_invalid(list_x)
        list_y = np.ma.masked_invalid(list_y)
        list_z = np.ma.masked_invalid(list_z)

        self._xlim = (np.amin(list_x), np.amax(list_x))
        self._ylim = (np.amin(list_y), np.amax(list_y))
        self._zlim = (np.amin(list_z), np.amax(list_z))
        return list_x, list_y, list_z
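`get_points` casts the evaluated coordinates to `float64`, masks NaN and infinite entries with `np.ma.masked_invalid`, and then takes the axis limits from what is left. The following is a stdlib-only sketch of that limit computation. It assumes that `None` stands for a failed (complex) evaluation, since NumPy turns `None` into NaN under `dtype=float64`. `finite_limits` is a hypothetical helper, not part of this module.

```python
import math


def finite_limits(values):
    """Return (min, max) over the finite entries of ``values``.

    None, NaN and +/-inf are skipped, which mimics masking invalid
    entries before calling amin/amax.
    """
    finite = [float(v) for v in values
              if v is not None and math.isfinite(float(v))]
    if not finite:
        return None  # assumption: signal "no valid data" with None
    return min(finite), max(finite)
```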


### Surfaces
class SurfaceBaseSeries(BaseSeries):
    """A base class for 3D surfaces."""

    is_3Dsurface = True

    def __init__(self):
        super().__init__()
        self.surface_color = None

    def get_color_array(self):
        np = import_module('numpy')
        c = self.surface_color
        if isinstance(c, Callable):
            f = np.vectorize(c)
            nargs = arity(c)
            if self.is_parametric:
                variables = list(map(centers_of_faces, self.get_parameter_meshes()))
                if nargs == 1:
                    return f(variables[0])
                elif nargs == 2:
                    return f(*variables)
            variables = list(map(centers_of_faces, self.get_meshes()))
            if nargs == 1:
                return f(variables[0])
            elif nargs == 2:
                return f(*variables[:2])
            else:
                return f(*variables)
        else:
            if isinstance(self, SurfaceOver2DRangeSeries):
                return c*np.ones(min(self.nb_of_points_x, self.nb_of_points_y))
            else:
                return c*np.ones(min(self.nb_of_points_u, self.nb_of_points_v))
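`get_color_array` uses the number of arguments of a callable `surface_color` to decide which coordinates to pass: one argument gets x only, two get (x, y), and anything else gets (x, y, z), all evaluated at face centers. The following is a minimal sketch of that dispatch on flat lists. It counts positional parameters with `inspect.signature` as a stand-in for SymPy's `arity`, omits the parametric (u, v) branch, and uses the hypothetical names `positional_arity` and `color_values`.

```python
import inspect


def positional_arity(fn):
    """Count the positional parameters of ``fn`` (stand-in for ``arity``)."""
    params = inspect.signature(fn).parameters.values()
    return sum(p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
               for p in params)


def color_values(color, xs, ys, zs):
    """Evaluate a surface color at each center, dispatching on arity.

    A plain number is broadcast to every center.
    """
    if callable(color):
        nargs = positional_arity(color)
        if nargs == 1:
            return [color(x) for x in xs]
        if nargs == 2:
            return [color(x, y) for x, y in zip(xs, ys)]
        return [color(x, y, z) for x, y, z in zip(xs, ys, zs)]
    return [color] * len(xs)
```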


class SurfaceOver2DRangeSeries(SurfaceBaseSeries):
    """Representation for a 3D surface consisting of a SymPy expression and 2D
    range."""

    def __init__(self, expr, var_start_end_x, var_start_end_y, **kwargs):
        super().__init__()
        self.expr = sympify(expr)
        self.var_x = sympify(var_start_end_x[0])
        self.start_x = float(var_start_end_x[1])
        self.end_x = float(var_start_end_x[2])
        self.var_y = sympify(var_start_end_y[0])
        self.start_y = float(var_start_end_y[1])
        self.end_y = float(var_start_end_y[2])
        self.nb_of_points_x = kwargs.get('nb_of_points_x', 50)
        self.nb_of_points_y = kwargs.get('nb_of_points_y', 50)
        self.surface_color = kwargs.get('surface_color', None)

        self._xlim = (self.start_x, self.end_x)
        self._ylim = (self.start_y, self.end_y)

    def __str__(self):
        return ('cartesian surface: %s for'
                ' %s over %s and %s over %s') % (
                    str(self.expr),
                    str(self.var_x),
                    str((self.start_x, self.end_x)),
                    str(self.var_y),
                    str((self.start_y, self.end_y)))

    def get_meshes(self):
        np = import_module('numpy')
        mesh_x, mesh_y = np.meshgrid(np.linspace(self.start_x, self.end_x,
                                                 num=self.nb_of_points_x),
                                     np.linspace(self.start_y, self.end_y,
                                                 num=self.nb_of_points_y))
        f = vectorized_lambdify((self.var_x, self.var_y), self.expr)
        mesh_z = f(mesh_x, mesh_y)
        mesh_z = np.array(mesh_z, dtype=np.float64)
        mesh_z = np.ma.masked_invalid(mesh_z)
        self._zlim = (np.amin(mesh_z), np.amax(mesh_z))
        return mesh_x, mesh_y, mesh_z
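`get_meshes` builds the grid with `np.meshgrid` under its default `'xy'` indexing, so every mesh has one row per y value and one column per x value. The following pure-Python sketch shows that layout. The `linspace` and `surface_grid` helpers are hypothetical, and the sketch skips the `float64` cast and invalid-value masking.

```python
def linspace(start, stop, num):
    """Evenly spaced values including both endpoints (like numpy.linspace)."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def surface_grid(f, x_range, y_range, nx=50, ny=50):
    """Evaluate f(x, y) on a grid laid out like meshgrid's 'xy' indexing.

    Row i corresponds to the i-th y value, column j to the j-th x value.
    """
    xs = linspace(*x_range, nx)
    ys = linspace(*y_range, ny)
    mesh_x = [list(xs) for _ in ys]
    mesh_y = [[y] * nx for y in ys]
    mesh_z = [[f(x, y) for x in xs] for y in ys]
    return mesh_x, mesh_y, mesh_z
```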


class ParametricSurfaceSeries(SurfaceBaseSeries):
    """Representation for a 3D surface consisting of three parametric SymPy
    expressions and two parameter ranges."""

    is_parametric = True

    def __init__(self, expr_x, expr_y, expr_z, var_start_end_u,
                 var_start_end_v, **kwargs):
        super().__init__()
        self.expr_x = sympify(expr_x)
        self.expr_y = sympify(expr_y)
        self.expr_z = sympify(expr_z)
        self.var_u = sympify(var_start_end_u[0])
        self.start_u = float(var_start_end_u[1])
        self.end_u = float(var_start_end_u[2])
        self.var_v = sympify(var_start_end_v[0])
        self.start_v = float(var_start_end_v[1])
        self.end_v = float(var_start_end_v[2])
        self.nb_of_points_u = kwargs.get('nb_of_points_u', 50)
        self.nb_of_points_v = kwargs.get('nb_of_points_v', 50)
        self.surface_color = kwargs.get('surface_color', None)

    def __str__(self):
        return ('parametric cartesian surface: (%s, %s, %s) for'
                ' %s over %s and %s over %s') % (
                    str(self.expr_x),
                    str(self.expr_y),
                    str(self.expr_z),
                    str(self.var_u),
                    str((self.start_u, self.end_u)),
                    str(self.var_v),
                    str((self.start_v, self.end_v)))

    def get_parameter_meshes(self):
        np = import_module('numpy')
        return np.meshgrid(np.linspace(self.start_u, self.end_u,
                                       num=self.nb_of_points_u),
                           np.linspace(self.start_v, self.end_v,
                                       num=self.nb_of_points_v))

    def get_meshes(self):
        np = import_module('numpy')

        mesh_u, mesh_v = self.get_parameter_meshes()
        fx = vectorized_lambdify((self.var_u, self.var_v), self.expr_x)
        fy = vectorized_lambdify((self.var_u, self.var_v), self.expr_y)
        fz = vectorized_lambdify((self.var_u, self.var_v), self.expr_z)

        mesh_x = fx(mesh_u, mesh_v)
        mesh_y = fy(mesh_u, mesh_v)
        mesh_z = fz(mesh_u, mesh_v)

        mesh_x = np.array(mesh_x, dtype=np.float64)
        mesh_y = np.array(mesh_y, dtype=np.float64)
        mesh_z = np.array(mesh_z, dtype=np.float64)

        mesh_x = np.ma.masked_invalid(mesh_x)
        mesh_y = np.ma.masked_invalid(mesh_y)
        mesh_z = np.ma.masked_invalid(mesh_z)

        self._xlim = (np.amin(mesh_x), np.amax(mesh_x))
        self._ylim = (np.amin(mesh_y), np.amax(mesh_y))
        self._zlim = (np.amin(mesh_z), np.amax(mesh_z))

        return mesh_x, mesh_y, mesh_z


### Contours
class ContourSeries(BaseSeries):
    """Representation for a contour plot."""
    # The code is mostly a repetition of SurfaceOver2DRangeSeries.
    # Presently used in the plot_contour function.

    is_contour = True

    def __init__(self, expr, var_start_end_x, var_start_end_y):
        super().__init__()
        self.nb_of_points_x = 50
        self.nb_of_points_y = 50
        self.expr = sympify(expr)
        self.var_x = sympify(var_start_end_x[0])
        self.start_x = float(var_start_end_x[1])
        self.end_x = float(var_start_end_x[2])
        self.var_y = sympify(var_start_end_y[0])
        self.start_y = float(var_start_end_y[1])
        self.end_y = float(var_start_end_y[2])

        self.get_points = self.get_meshes

        self._xlim = (self.start_x, self.end_x)
        self._ylim = (self.start_y, self.end_y)

    def __str__(self):
        return ('contour: %s for '
                '%s over %s and %s over %s') % (
                    str(self.expr),
                    str(self.var_x),
                    str((self.start_x, self.end_x)),
                    str(self.var_y),
                    str((self.start_y, self.end_y)))

    def get_meshes(self):
        np = import_module('numpy')
        mesh_x, mesh_y = np.meshgrid(np.linspace(self.start_x, self.end_x,
                                                 num=self.nb_of_points_x),
                                     np.linspace(self.start_y, self.end_y,
                                                 num=self.nb_of_points_y))
        f = vectorized_lambdify((self.var_x, self.var_y), self.expr)
        return (mesh_x, mesh_y, f(mesh_x, mesh_y))


##############################################################################
# Backends
##############################################################################


class BaseBackend:
    """Base class for all backends. A backend represents the plotting library
    that implements the functionality needed by the SymPy plotting
    functions.

    How the plotting module works:

    1. Whenever a plotting function is called, the provided expressions are
       processed and a list of instances of the `BaseSeries` class is created,
       containing the information needed to plot the expressions (e.g. the
       expression, ranges, series name, ...). Eventually, these objects will
       generate the numerical data to be plotted.
    2. A Plot object is instantiated, which stores the list of series and the
       main attributes of the plot (e.g. axis labels, title, ...).
    3. When the "show" command is executed, a new backend is instantiated,
       which loops through each series object to generate and plot the
       numerical data. The backend also sets the axis labels, title, ...,
       according to the values stored in the Plot instance.

    The backend should check whether it supports the data series it is given
    (e.g. TextBackend supports only LineOver1DRangeSeries).
    It is the backend's responsibility to know how to use the class of data
    series that it is given. Note that the current implementation of the
    `*Series` classes is "matplotlib-centric": the numerical data returned by
    the `get_points` and `get_meshes` methods is meant to be used directly by
    Matplotlib. Therefore, a new backend will have to pre-process the
    numerical data to make it compatible with the chosen plotting library.
    Keep in mind that future SymPy versions may change the `*Series` classes
    to return numerical data that is not "matplotlib-centric"; hence, if you
    write a new backend, you are responsible for checking that it still works
    with each SymPy release.

    Please explore the `MatplotlibBackend` source code to understand how a
    backend should be coded.

    Methods
    =======

    In order to be used by SymPy plotting functions, a backend must implement
    the following methods:

    * `show(self)`: used to loop over the data series, generate the numerical
      data, plot it and set the axis labels, title, ...
    * `save(self, path)`: used to save the current plot to the specified file
      path.
    * `close(self)`: used to close the current plot backend (note: some
      plotting libraries do not support this functionality; in that case,
      just raise a warning).

    See also
    ========

    MatplotlibBackend
    """

    def __init__(self, parent):
        super().__init__()
        self.parent = parent

    def show(self):
        raise NotImplementedError

    def save(self, path):
        raise NotImplementedError

    def close(self):
        raise NotImplementedError
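The three required methods, `show`, `save` and `close`, can be illustrated with a minimal sketch. To stay self-contained, it defines a stand-in base class and a fake parent object that exposes `_series` and `title`. A real backend would subclass `BaseBackend` and receive a `Plot` instance as `parent`. `BaseBackendStub`, `RecordingBackend` and `FakePlot` are all hypothetical names, and this backend "plots" by recording one text line per series.

```python
import warnings


class BaseBackendStub:
    """Stand-in for BaseBackend: it only stores the parent plot."""

    def __init__(self, parent):
        self.parent = parent


class RecordingBackend(BaseBackendStub):
    """Hypothetical backend that renders each series as a line of text."""

    def __init__(self, parent):
        super().__init__(parent)
        self.lines = []

    def _render(self):
        # Loop over the data series and apply the plot-level attributes.
        self.lines = [str(s) for s in self.parent._series]
        if getattr(self.parent, 'title', None):
            self.lines.insert(0, 'title: %s' % self.parent.title)

    def show(self):
        self._render()
        print('\n'.join(self.lines))

    def save(self, path):
        self._render()
        with open(path, 'w') as fh:
            fh.write('\n'.join(self.lines))

    def close(self):
        # Nothing to release, so follow the docstring's advice and warn.
        warnings.warn('RecordingBackend has nothing to close.')


class FakePlot:
    """Minimal parent exposing the attributes the backend reads."""

    def __init__(self, *series, title=None):
        self._series = list(series)
        self.title = title
```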


# Don't have to check for the success of importing matplotlib in each case;
# we will only be using this backend if we can successfully import matplotlib
class MatplotlibBackend(BaseBackend):
    """ This class implements the functionalities to use Matplotlib with SymPy
    plotting functions.
    """

    def __init__(self, parent):
        super().__init__(parent)
        self.matplotlib = import_module('matplotlib',
            import_kwargs={'fromlist': ['pyplot', 'cm', 'collections']},
            min_module_version='1.1.0', catch=(RuntimeError,))
        self.plt = self.matplotlib.pyplot
        self.cm = self.matplotlib.cm
        self.LineCollection = self.matplotlib.collections.LineCollection
        aspect = getattr(self.parent, 'aspect_ratio', 'auto')
        if aspect != 'auto':
            aspect = float(aspect[1]) / aspect[0]

        if isinstance(self.parent, Plot):
            nrows, ncolumns = 1, 1
            series_list = [self.parent._series]
        elif isinstance(self.parent, PlotGrid):
            nrows, ncolumns = self.parent.nrows, self.parent.ncolumns
            series_list = self.parent._series

        self.ax = []
        self.fig = self.plt.figure(figsize=parent.size)

        for i, series in enumerate(series_list):
            are_3D = [s.is_3D for s in series]

            if any(are_3D) and not all(are_3D):
                raise ValueError('The matplotlib backend cannot mix 2D and 3D.')
            elif all(are_3D):
                # mpl_toolkits.mplot3d is necessary for
                # projection='3d'
                mpl_toolkits = import_module('mpl_toolkits',  # noqa
                                     import_kwargs={'fromlist': ['mplot3d']})
                self.ax.append(self.fig.add_subplot(nrows, ncolumns, i + 1, projection='3d', aspect=aspect))

            elif not any(are_3D):
                self.ax.append(self.fig.add_subplot(nrows, ncolumns, i + 1, aspect=aspect))
                self.ax[i].spines['left'].set_position('zero')
                self.ax[i].spines['right'].set_color('none')
                self.ax[i].spines['bottom'].set_position('zero')
                self.ax[i].spines['top'].set_color('none')
                self.ax[i].xaxis.set_ticks_position('bottom')
                self.ax[i].yaxis.set_ticks_position('left')

    @staticmethod
    def get_segments(x, y, z=None):
        """ Convert two (or, for a 3D line, three) lists of coordinates to a
        list of segments to be used with Matplotlib's LineCollection.

        Parameters
        ==========

        x: list
            List of x-coordinates

        y: list
            List of y-coordinates

        z: list
            List of z-coordinates for a 3D line.
        """
        np = import_module('numpy')
        if z is not None:
            dim = 3
            points = (x, y, z)
        else:
            dim = 2
            points = (x, y)
        points = np.ma.array(points).T.reshape(-1, 1, dim)
        return np.ma.concatenate([points[:-1], points[1:]], axis=1)

    def _process_series(self, series, ax, parent):
        np = import_module('numpy')
        mpl_toolkits = import_module(
            'mpl_toolkits', import_kwargs={'fromlist': ['mplot3d']})

        # XXX Workaround for matplotlib issue
        # https://github.com/matplotlib/matplotlib/issues/17130
        xlims, ylims, zlims = [], [], []

        for s in series:
            # Create the collections
            if s.is_2Dline:
                x, y = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    segments = self.get_segments(x, y)
                    collection = self.LineCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    line, = ax.plot(x, y, label=s.label, color=s.line_color)
            elif s.is_contour:
                ax.contour(*s.get_meshes())
            elif s.is_3Dline:
                x, y, z = s.get_data()
                if (isinstance(s.line_color, (int, float)) or
                        callable(s.line_color)):
                    art3d = mpl_toolkits.mplot3d.art3d
                    segments = self.get_segments(x, y, z)
                    collection = art3d.Line3DCollection(segments)
                    collection.set_array(s.get_color_array())
                    ax.add_collection(collection)
                else:
                    ax.plot(x, y, z, label=s.label,
                        color=s.line_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_3Dsurface:
                x, y, z = s.get_meshes()
                collection = ax.plot_surface(x, y, z,
                    cmap=getattr(self.cm, 'viridis', self.cm.jet),
                    rstride=1, cstride=1, linewidth=0.1)
                if isinstance(s.surface_color, (float, int, Callable)):
                    color_array = s.get_color_array()
                    color_array = color_array.reshape(color_array.size)
                    collection.set_array(color_array)
                else:
                    collection.set_color(s.surface_color)

                xlims.append(s._xlim)
                ylims.append(s._ylim)
                zlims.append(s._zlim)
            elif s.is_implicit:
                points = s.get_raster()
                if len(points) == 2:
                    # interval math plotting
                    x, y = _matplotlib_list(points[0])
                    ax.fill(x, y, facecolor=s.line_color, edgecolor='None')
                else:
                    # use contourf or contour depending on whether it is
                    # an inequality or equality.
                    # XXX: ``contour`` plots multiple lines. Should be fixed.
                    ListedColormap = self.matplotlib.colors.ListedColormap
                    colormap = ListedColormap(["white", s.line_color])
                    xarray, yarray, zarray, plot_type = points
                    if plot_type == 'contour':
                        ax.contour(xarray, yarray, zarray, cmap=colormap)
                    else:
                        ax.contourf(xarray, yarray, zarray, cmap=colormap)
            else:
                # Report the unsupported series itself, not the axes.
                raise NotImplementedError(
                    '{} is not supported in the SymPy plotting module '
                    'with matplotlib backend. Please report this issue.'
                    .format(s))

        Axes3D = mpl_toolkits.mplot3d.Axes3D
        if not isinstance(ax, Axes3D):
            ax.autoscale_view(
                scalex=ax.get_autoscalex_on(),
                scaley=ax.get_autoscaley_on())
        else:
            # XXX Workaround for matplotlib issue
            # https://github.com/matplotlib/matplotlib/issues/17130
            if xlims:
                xlims = np.array(xlims)
                xlim = (np.amin(xlims[:, 0]), np.amax(xlims[:, 1]))
                ax.set_xlim(xlim)
            else:
                ax.set_xlim([0, 1])

            if ylims:
                ylims = np.array(ylims)
                ylim = (np.amin(ylims[:, 0]), np.amax(ylims[:, 1]))
                ax.set_ylim(ylim)
            else:
                ax.set_ylim([0, 1])

            if zlims:
                zlims = np.array(zlims)
                zlim = (np.amin(zlims[:, 0]), np.amax(zlims[:, 1]))
                ax.set_zlim(zlim)
            else:
                ax.set_zlim([0, 1])

        # Set global options.
        # TODO The 3D stuff
        # XXX The order of those is important.
        if parent.xscale and not isinstance(ax, Axes3D):
            ax.set_xscale(parent.xscale)
        if parent.yscale and not isinstance(ax, Axes3D):
            ax.set_yscale(parent.yscale)
        if not isinstance(ax, Axes3D) or self.matplotlib.__version__ >= '1.2.0':  # XXX in the distant future remove this check
            ax.set_autoscale_on(parent.autoscale)
        if parent.axis_center:
            val = parent.axis_center
            if isinstance(ax, Axes3D):
                pass
            elif val == 'center':
                ax.spines['left'].set_position('center')
                ax.spines['bottom'].set_position('center')
            elif val == 'auto':
                xl, xh = ax.get_xlim()
                yl, yh = ax.get_ylim()
                pos_left = ('data', 0) if xl*xh <= 0 else 'center'
                pos_bottom = ('data', 0) if yl*yh <= 0 else 'center'
                ax.spines['left'].set_position(pos_left)
                ax.spines['bottom'].set_position(pos_bottom)
            else:
                ax.spines['left'].set_position(('data', val[0]))
                ax.spines['bottom'].set_position(('data', val[1]))
        if not parent.axis:
            ax.set_axis_off()
        if parent.legend:
            if ax.legend():
                ax.legend_.set_visible(parent.legend)
        if parent.margin:
            ax.set_xmargin(parent.margin)
            ax.set_ymargin(parent.margin)
        if parent.title:
            ax.set_title(parent.title)
        if parent.xlabel:
            ax.set_xlabel(parent.xlabel, position=(1, 0))
        if parent.ylabel:
            ax.set_ylabel(parent.ylabel, position=(0, 1))
        if isinstance(ax, Axes3D) and parent.zlabel:
            ax.set_zlabel(parent.zlabel, position=(0, 1))
        if parent.annotations:
            for a in parent.annotations:
                ax.annotate(**a)
        if parent.markers:
            for marker in parent.markers:
                # make a copy of the marker dictionary
                # so that it doesn't get altered
                m = marker.copy()
                args = m.pop('args')
                ax.plot(*args, **m)
        if parent.rectangles:
            for r in parent.rectangles:
                rect = self.matplotlib.patches.Rectangle(**r)
                ax.add_patch(rect)
        if parent.fill:
            ax.fill_between(**parent.fill)

        # xlim and ylim should always be set last so that the plot limits
        # don't get altered during the process.
        if parent.xlim:
            ax.set_xlim(parent.xlim)
        if parent.ylim:
            ax.set_ylim(parent.ylim)

    def process_series(self):
        """
        Iterates over every ``Plot`` object and calls ``_process_series()``
        on each one.
        """
        parent = self.parent
        if isinstance(parent, Plot):
            series_list = [parent._series]
        else:
            series_list = parent._series

        for i, (series, ax) in enumerate(zip(series_list, self.ax)):
            if isinstance(self.parent, PlotGrid):
                parent = self.parent.args[i]
            self._process_series(series, ax, parent)

    def show(self):
        self.process_series()
        #TODO after fixing https://github.com/ipython/ipython/issues/1255
        # you can uncomment the next line and remove the pyplot.show() call
        #self.fig.show()
        if _show:
            self.fig.tight_layout()
            self.plt.show()
        else:
            self.close()

    def save(self, path):
        self.process_series()
        self.fig.savefig(path)

    def close(self):
        self.plt.close(self.fig)
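`MatplotlibBackend.get_segments` stacks the coordinate lists into points and pairs each point with its successor, so n points become n - 1 segments in the layout a `LineCollection` expects. The following pure-Python sketch shows the same pairing without NumPy masked arrays. `get_segments_py` is a hypothetical name.

```python
def get_segments_py(x, y, z=None):
    """Pair consecutive points into segments.

    Returns a list of (start_point, end_point) tuples; each point is an
    (x, y) or (x, y, z) tuple, and n points give n - 1 segments.
    """
    coords = (x, y) if z is None else (x, y, z)
    points = list(zip(*coords))
    return [(points[i], points[i + 1]) for i in range(len(points) - 1)]
```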


class TextBackend(BaseBackend):
    def __init__(self, parent):
        super().__init__(parent)

    def show(self):
        if not _show:
            return
        if len(self.parent._series) != 1:
            raise ValueError(
                'The TextBackend supports only one graph per Plot.')
        elif not isinstance(self.parent._series[0], LineOver1DRangeSeries):
            raise ValueError(
                'The TextBackend supports only expressions over a 1D range')
        else:
            ser = self.parent._series[0]
            textplot(ser.expr, ser.start, ser.end)

    def close(self):
        pass


class DefaultBackend(BaseBackend):
    def __new__(cls, parent):
        matplotlib = import_module('matplotlib', min_module_version='1.1.0', catch=(RuntimeError,))
        if matplotlib:
            return MatplotlibBackend(parent)
        else:
            return TextBackend(parent)


plot_backends = {
    'matplotlib': MatplotlibBackend,
    'text': TextBackend,
    'default': DefaultBackend
}
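`DefaultBackend.__new__` returns a `MatplotlibBackend` when matplotlib can be imported and a `TextBackend` otherwise. The following stdlib-only sketch shows the same fallback idea. It checks availability with `importlib.util.find_spec` rather than SymPy's `import_module` and skips the minimum-version check. `pick_backend` is a hypothetical helper that returns a key of a registry such as `plot_backends`.

```python
import importlib.util


def pick_backend(preferred='matplotlib', fallback='text'):
    """Return the backend key to use, falling back when the preferred
    top-level module cannot be found (no version check is done)."""
    if importlib.util.find_spec(preferred) is not None:
        return preferred
    return fallback
```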


##############################################################################
# Finding the centers of line segments or mesh faces
##############################################################################


def centers_of_segments(array):
    np = import_module('numpy')
    return np.mean(np.vstack((array[:-1], array[1:])), 0)


def centers_of_faces(array):
    np = import_module('numpy')
    # Average the four corners of every face of the mesh.
    return np.mean(np.dstack((array[:-1, :-1],
                             array[1:, :-1],
                             array[:-1, 1:],
                             array[1:, 1:],
                             )), 2)
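`centers_of_faces` reduces an (m × n) mesh of vertex values to an ((m - 1) × (n - 1)) array of face values by averaging the four corners of each cell. The following pure-Python sketch works on a list of rows. `centers_of_faces_py` is a hypothetical name.

```python
def centers_of_faces_py(grid):
    """Average the four corners of every cell of a 2D grid (list of rows)."""
    return [[(grid[i][j] + grid[i + 1][j]
              + grid[i][j + 1] + grid[i + 1][j + 1]) / 4
             for j in range(len(grid[0]) - 1)]
            for i in range(len(grid) - 1)]
```

For a grid sampled from a linear function, each face value equals the function at the cell's center, which makes a convenient sanity check.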


def flat(x, y, z, eps=1e-3):
    """Checks whether three points are almost collinear, with ``y`` lying
    between ``x`` and ``z`` (i.e. the angle at ``y`` is close to 180 degrees)."""
    np = import_module('numpy')
    # Workaround plotting piecewise (#8577):
    # workaround for cases where `lambdify` in `.experimental_lambdify`
    # fails to return numerical values. A lower-level fix in `lambdify`
    # is possible.
    vector_a = (x - y).astype(np.float64)
    vector_b = (z - y).astype(np.float64)
    dot_product = np.dot(vector_a, vector_b)
    vector_a_norm = np.linalg.norm(vector_a)
    vector_b_norm = np.linalg.norm(vector_b)
    cos_theta = dot_product / (vector_a_norm * vector_b_norm)
    return abs(cos_theta + 1) < eps
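The tolerance in `flat` can be translated into an angle. Let $\theta$ be the angle at $\mathbf{y}$ and write $\theta = \pi - \varphi$, where $\varphi$ is the deviation from a straight line. The following worked derivation then gives the largest deviation that still counts as flat:

```latex
\cos\theta = \frac{(\mathbf{x}-\mathbf{y})\cdot(\mathbf{z}-\mathbf{y})}
                  {\lVert\mathbf{x}-\mathbf{y}\rVert\,\lVert\mathbf{z}-\mathbf{y}\rVert},
\qquad
1 + \cos\theta = 1 - \cos\varphi \approx \frac{\varphi^{2}}{2},

\text{flat} \iff \varphi < \arccos(1 - \varepsilon) \approx \sqrt{2\varepsilon}
\;\overset{\varepsilon = 10^{-3}}{\approx}\; 0.0447\ \text{rad} \approx 2.56^{\circ}.
```

Collinear points with $\mathbf{y}$ outside the segment give $\theta = 0$ and $1 + \cos\theta = 2$, so they are not flat. If $\mathbf{y}$ coincides with $\mathbf{x}$ or $\mathbf{z}$, the denominator is zero and NumPy produces `nan`. The comparison is then `False`, so the segment is treated as not flat.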


def _matplotlib_list(interval_list):
    """
    Returns lists for the matplotlib ``fill`` command from a list of bounding
    rectangular intervals.
    """
    xlist = []
    ylist = []
    if len(interval_list):
        for intervals in interval_list:
            intervalx = intervals[0]
            intervaly = intervals[1]
            xlist.extend([intervalx.start, intervalx.start,
                          intervalx.end, intervalx.end, None])
            ylist.extend([intervaly.start, intervaly.end,
                          intervaly.end, intervaly.start, None])
    else:
        #XXX Ugly hack. Matplotlib does not accept empty lists for ``fill``
        xlist.extend((None, None, None, None))
        ylist.extend((None, None, None, None))
    return xlist, ylist
  1365. ####New API for plotting module ####
  1366. # TODO: Add color arrays for plots.
  1367. # TODO: Add more plotting options for 3d plots.
  1368. # TODO: Adaptive sampling for 3D plots.
  1369. def plot(*args, show=True, **kwargs):
  1370. """Plots a function of a single variable as a curve.
  1371. Parameters
  1372. ==========
  1373. args :
  1374. The first argument is the expression representing the function
  1375. of single variable to be plotted.
  1376. The last argument is a 3-tuple denoting the range of the free
  1377. variable. e.g. ``(x, 0, 5)``
  1378. Typical usage examples are in the followings:
  1379. - Plotting a single expression with a single range.
  1380. ``plot(expr, range, **kwargs)``
  1381. - Plotting a single expression with the default range (-10, 10).
  1382. ``plot(expr, **kwargs)``
  1383. - Plotting multiple expressions with a single range.
  1384. ``plot(expr1, expr2, ..., range, **kwargs)``
  1385. - Plotting multiple expressions with multiple ranges.
  1386. ``plot((expr1, range1), (expr2, range2), ..., **kwargs)``
  1387. It is best practice to specify range explicitly because default
  1388. range may change in the future if a more advanced default range
  1389. detection algorithm is implemented.
  1390. show : bool, optional
  1391. The default value is set to ``True``. Set show to ``False`` and
  1392. the function will not display the plot. The returned instance of
  1393. the ``Plot`` class can then be used to save or display the plot
  1394. by calling the ``save()`` and ``show()`` methods respectively.
  1395. line_color : string, or float, or function, optional
  1396. Specifies the color for the plot.
  1397. See ``Plot`` to see how to set color for the plots.
  1398. Note that by setting ``line_color``, it would be applied simultaneously
  1399. to all the series.
  1400. title : str, optional
  1401. Title of the plot. It is set to the latex representation of
  1402. the expression, if the plot has only one expression.
  1403. label : str, optional
  1404. The label of the expression in the plot. It will be used when
  1405. called with ``legend``. Default is the name of the expression.
  1406. e.g. ``sin(x)``
  1407. xlabel : str, optional
  1408. Label for the x-axis.
  1409. ylabel : str, optional
  1410. Label for the y-axis.
  1411. xscale : 'linear' or 'log', optional
  1412. Sets the scaling of the x-axis.
  1413. yscale : 'linear' or 'log', optional
  1414. Sets the scaling of the y-axis.
  1415. axis_center : (float, float), optional
  1416. Tuple of two floats denoting the coordinates of the center or
  1417. {'center', 'auto'}
  1418. xlim : (float, float), optional
  1419. Denotes the x-axis limits, ``(min, max)```.
  1420. ylim : (float, float), optional
  1421. Denotes the y-axis limits, ``(min, max)```.
  1422. annotations : list, optional
  1423. A list of dictionaries specifying the type of annotation
  1424. required. The keys in the dictionary should be equivalent
  1425. to the arguments of the matplotlib's annotate() function.
  1426. markers : list, optional
  1427. A list of dictionaries specifying the type the markers required.
  1428. The keys in the dictionary should be equivalent to the arguments
  1429. of the matplotlib's plot() function along with the marker
  1430. related keyworded arguments.
  1431. rectangles : list, optional
  1432. A list of dictionaries specifying the dimensions of the
  1433. rectangles to be plotted. The keys in the dictionary should be
  1434. equivalent to the arguments of the matplotlib's
  1435. patches.Rectangle class.
  1436. fill : dict, optional
  1437. A dictionary specifying the type of color filling required in
  1438. the plot. The keys in the dictionary should be equivalent to the
  1439. arguments of the matplotlib's fill_between() function.
  1440. adaptive : bool, optional
  1441. The default value is set to ``True``. Set adaptive to ``False``
  1442. and specify ``nb_of_points`` if uniform sampling is required.
  1443. The plotting uses an adaptive algorithm which samples
  1444. recursively to accurately plot. The adaptive algorithm uses a
  1445. random point near the midpoint of two points that has to be
  1446. further sampled. Hence the same plots can appear slightly
  1447. different.
  1448. depth : int, optional
  1449. Recursion depth of the adaptive algorithm. A depth of value
  1450. ``n`` samples a maximum of `2^{n}` points.
  1451. If the ``adaptive`` flag is set to ``False``, this will be
  1452. ignored.
  1453. nb_of_points : int, optional
  1454. Used when the ``adaptive`` is set to ``False``. The function
  1455. is uniformly sampled at ``nb_of_points`` number of points.
  1456. If the ``adaptive`` flag is set to ``True``, this will be
  1457. ignored.
  1458. size : (float, float), optional
  1459. A tuple in the form (width, height) in inches to specify the size of
  1460. the overall figure. The default value is set to ``None``, meaning
  1461. the size will be set by the default backend.
  1462. Examples
  1463. ========
  1464. .. plot::
  1465. :context: close-figs
  1466. :format: doctest
  1467. :include-source: True
  1468. >>> from sympy import symbols
  1469. >>> from sympy.plotting import plot
  1470. >>> x = symbols('x')
  1471. Single Plot
  1472. .. plot::
  1473. :context: close-figs
  1474. :format: doctest
  1475. :include-source: True
  1476. >>> plot(x**2, (x, -5, 5))
  1477. Plot object containing:
  1478. [0]: cartesian line: x**2 for x over (-5.0, 5.0)
  1479. Multiple plots with single range.
  1480. .. plot::
  1481. :context: close-figs
  1482. :format: doctest
  1483. :include-source: True
  1484. >>> plot(x, x**2, x**3, (x, -5, 5))
  1485. Plot object containing:
  1486. [0]: cartesian line: x for x over (-5.0, 5.0)
  1487. [1]: cartesian line: x**2 for x over (-5.0, 5.0)
  1488. [2]: cartesian line: x**3 for x over (-5.0, 5.0)
  1489. Multiple plots with different ranges.
  1490. .. plot::
  1491. :context: close-figs
  1492. :format: doctest
  1493. :include-source: True
  1494. >>> plot((x**2, (x, -6, 6)), (x, (x, -5, 5)))
  1495. Plot object containing:
  1496. [0]: cartesian line: x**2 for x over (-6.0, 6.0)
  1497. [1]: cartesian line: x for x over (-5.0, 5.0)
  1498. No adaptive sampling.
  1499. .. plot::
  1500. :context: close-figs
  1501. :format: doctest
  1502. :include-source: True
  1503. >>> plot(x**2, adaptive=False, nb_of_points=400)
  1504. Plot object containing:
  1505. [0]: cartesian line: x**2 for x over (-10.0, 10.0)
  1506. See Also
  1507. ========
  1508. Plot, LineOver1DRangeSeries
  1509. """
    args = list(map(sympify, args))
    free = set()
    for a in args:
        if isinstance(a, Expr):
            free |= a.free_symbols
            if len(free) > 1:
                raise ValueError(
                    'The same variable should be used in all '
                    'univariate expressions being plotted.')
    x = free.pop() if free else Symbol('x')
    kwargs.setdefault('xlabel', x.name)
    kwargs.setdefault('ylabel', 'f(%s)' % x.name)
    plot_expr = check_arguments(args, 1, 1)
    series = [LineOver1DRangeSeries(*arg, **kwargs) for arg in plot_expr]

    plots = Plot(*series, **kwargs)
    if show:
        plots.show()
    return plots
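

# Illustrative sketch (hypothetical helper, not part of SymPy's API): what
# uniform sampling with ``adaptive=False`` amounts to. The function is
# evaluated at ``nb_of_points`` evenly spaced values of the variable. That
# both endpoints are included (as with ``numpy.linspace``) is an assumption
# of this sketch, and the name ``_sketch_uniform_line`` is made up.
def _sketch_uniform_line(f, start, end, nb_of_points):
    """Return ``nb_of_points`` evenly spaced ``(x, f(x))`` pairs on [start, end]."""
    if nb_of_points < 2:
        raise ValueError("nb_of_points must be at least 2")
    start, end = float(start), float(end)
    step = (end - start) / (nb_of_points - 1)
    points = []
    for k in range(nb_of_points):
        # Pin the last sample to ``end`` to avoid accumulated rounding error.
        x = end if k == nb_of_points - 1 else start + k * step
        points.append((x, f(x)))
    return points

# Usage: _sketch_uniform_line(lambda t: t**2, -5, 5, 3) returns
# [(-5.0, 25.0), (0.0, 0.0), (5.0, 25.0)]. The call
# plot(x**2, adaptive=False, nb_of_points=400) corresponds roughly to
# _sketch_uniform_line(lambda t: t**2, -10, 10, 400).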


def plot_parametric(*args, show=True, **kwargs):
    """
    Plots a 2D parametric curve.

    Parameters
    ==========

    args
        Common specifications are:

        - Plotting a single parametric curve with a range:
          ``plot_parametric((expr_x, expr_y), range)``
        - Plotting multiple parametric curves with the same range:
          ``plot_parametric((expr_x, expr_y), ..., range)``
        - Plotting multiple parametric curves with different ranges:
          ``plot_parametric((expr_x, expr_y, range), ...)``

        ``expr_x`` is the expression representing the $x$ component of the
        parametric function.

        ``expr_y`` is the expression representing the $y$ component of the
        parametric function.

        ``range`` is a 3-tuple denoting the parameter symbol, start and
        stop. For example, ``(u, 0, 5)``.

        If the range is not specified, then a default range of (-10, 10)
        is used.

        However, if the arguments are specified as
        ``(expr_x, expr_y, range), ...``, you must specify the range
        for each expression manually.

        The default range may change in the future if a more advanced
        algorithm is implemented.

    adaptive : bool, optional
        Specifies whether to use adaptive sampling or not.

        The default value is set to ``True``. Set adaptive to ``False``
        and specify ``nb_of_points`` if uniform sampling is required.

    depth : int, optional
        The recursion depth of the adaptive algorithm. A depth of
        value $n$ samples a maximum of $2^n$ points.

    nb_of_points : int, optional
        Used when the ``adaptive`` flag is set to ``False``.

        Specifies the number of points used for the uniform
        sampling.

    line_color : string, or float, or function, optional
        Specifies the color for the plot.
        See ``Plot`` to see how to set color for the plots.
        Note that ``line_color`` is applied to all the series
        simultaneously.

    label : str, optional
        The label of the expression in the plot. It will be used when
        called with ``legend``. Default is the name of the expression,
        e.g. ``sin(x)``.

    xlabel : str, optional
        Label for the x-axis.

    ylabel : str, optional
        Label for the y-axis.

    xscale : 'linear' or 'log', optional
        Sets the scaling of the x-axis.

    yscale : 'linear' or 'log', optional
        Sets the scaling of the y-axis.

    axis_center : (float, float), optional
        Tuple of two floats denoting the coordinates of the center, or
        one of ``'center'`` and ``'auto'``.

    xlim : (float, float), optional
        Denotes the x-axis limits, ``(min, max)``.

    ylim : (float, float), optional
        Denotes the y-axis limits, ``(min, max)``.

    size : (float, float), optional
        A tuple in the form (width, height) in inches to specify the size of
        the overall figure. The default value is set to ``None``, meaning
        the size will be set by the default backend.

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import symbols, cos, sin
       >>> from sympy.plotting import plot_parametric
       >>> u = symbols('u')

    A parametric plot with a single expression:

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot_parametric((cos(u), sin(u)), (u, -5, 5))
       Plot object containing:
       [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0)

    A parametric plot with multiple expressions with the same range:

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot_parametric((cos(u), sin(u)), (u, cos(u)), (u, -10, 10))
       Plot object containing:
       [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-10.0, 10.0)
       [1]: parametric cartesian line: (u, cos(u)) for u over (-10.0, 10.0)

    A parametric plot with multiple expressions with different ranges
    for each curve:

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot_parametric((cos(u), sin(u), (u, -5, 5)),
       ...     (cos(u), u, (u, -5, 5)))
       Plot object containing:
       [0]: parametric cartesian line: (cos(u), sin(u)) for u over (-5.0, 5.0)
       [1]: parametric cartesian line: (cos(u), u) for u over (-5.0, 5.0)

    Notes
    =====

    By default, the plotting uses an adaptive algorithm which samples
    recursively to accurately plot the curve. The adaptive algorithm uses a
    random point near the midpoint of two points that have to be further
    sampled. Hence, repeating the same plot command can give slightly
    different results because of the random sampling.

    If there are multiple plots, then the same optional arguments are
    applied to all the plots drawn in the same canvas. If you want to
    set these options separately, you can index the returned ``Plot``
    object and set them on each series.

    For example, when you specify ``line_color`` once, it is applied
    to both series.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> from sympy import pi
       >>> expr1 = (u, cos(2*pi*u)/2 + 1/2)
       >>> expr2 = (u, sin(2*pi*u)/2 + 1/2)
       >>> p = plot_parametric(expr1, expr2, (u, 0, 1), line_color='blue')

    If you want to specify the line color for a specific series, you
    should index that series and set the property manually.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> p[0].line_color = 'red'
       >>> p.show()

    See Also
    ========

    Plot, Parametric2DLineSeries

    """
    args = list(map(sympify, args))
    plot_expr = check_arguments(args, 2, 1)
    series = [Parametric2DLineSeries(*arg, **kwargs) for arg in plot_expr]
    plots = Plot(*series, **kwargs)
    if show:
        plots.show()
    return plots
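

# Illustrative sketch (not SymPy's actual implementation): a simplified,
# hypothetical version of the adaptive sampling described in the Notes of
# ``plot_parametric``, shown for a function y = f(x). Each segment is probed
# at a random point near its midpoint and split recursively while the curve
# deviates from the straight chord by more than ``tol``. The ``min_depth``
# and ``tol`` parameters and this flatness test are assumptions made for the
# example. Recursing at most ``depth`` levels yields at most 2**depth segments.
def _sketch_adaptive_sample(f, start, end, depth=12, min_depth=2, tol=1e-3,
                            seed=None):
    """Return a list of ``(x, f(x))`` pairs, sorted by ``x``."""
    import random
    rng = random.Random(seed)

    def refine(p, r, level):
        # Returns the sampled points in the half-open segment (p, r].
        if level >= depth:
            return [r]
        # Probe a random point near the midpoint of the segment.
        t = 0.45 + 0.1 * rng.random()
        x = p[0] + t * (r[0] - p[0])
        q = (x, f(x))
        # Value predicted at x by the straight chord from p to r.
        chord = p[1] + t * (r[1] - p[1])
        if level >= min_depth and abs(q[1] - chord) <= tol:
            return [r]  # flat enough: the chord approximates the curve
        return refine(p, q, level + 1) + refine(q, r, level + 1)

    a, b = float(start), float(end)
    first = (a, f(a))
    return [first] + refine(first, (b, f(b)), 0)

# Design note: a straight line stops refining as soon as ``min_depth`` is
# reached, while a curved function keeps splitting until it is locally flat
# or ``depth`` is exhausted. Passing ``seed`` makes the random probes
# reproducible, which the real algorithm does not promise.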


def plot3d_parametric_line(*args, show=True, **kwargs):
    """
    Plots a 3D parametric line plot.

    Usage
    =====

    Single plot:

    ``plot3d_parametric_line(expr_x, expr_y, expr_z, range, **kwargs)``

    If the range is not specified, then a default range of (-10, 10) is used.

    Multiple plots:

    ``plot3d_parametric_line((expr_x, expr_y, expr_z, range), ..., **kwargs)``

    Ranges have to be specified for every expression.

    Default range may change in the future if a more advanced default range
    detection algorithm is implemented.

    Arguments
    =========

    ``expr_x`` : Expression representing the function along x.

    ``expr_y`` : Expression representing the function along y.

    ``expr_z`` : Expression representing the function along z.

    ``range``: ``(u, 0, 5)``, A 3-tuple denoting the range of the parameter
    variable.

    Keyword Arguments
    =================

    Arguments for ``Parametric3DLineSeries`` class.

    ``nb_of_points``: The range is uniformly sampled at ``nb_of_points``
    points.

    Aesthetics:

    ``line_color``: string, or float, or function, optional
        Specifies the color for the plot.
        See ``Plot`` to see how to set color for the plots.
        Note that ``line_color`` is applied to all the series
        simultaneously.

    ``label``: str
        The label of the plot. It will be used when called with ``legend=True``
        to denote the function with the given label in the plot.

    If there are multiple plots, then the same series arguments are applied to
    all the plots. If you want to set these options separately, you can index
    the returned ``Plot`` object and set them on each series.

    Arguments for ``Plot`` class.

    ``title`` : str. Title of the plot.

    ``size`` : (float, float), optional
        A tuple in the form (width, height) in inches to specify the size of
        the overall figure. The default value is set to ``None``, meaning
        the size will be set by the default backend.

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import symbols, cos, sin
       >>> from sympy.plotting import plot3d_parametric_line
       >>> u = symbols('u')

    Single plot.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d_parametric_line(cos(u), sin(u), u, (u, -5, 5))
       Plot object containing:
       [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0)

    Multiple plots.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d_parametric_line((cos(u), sin(u), u, (u, -5, 5)),
       ...     (sin(u), u**2, u, (u, -5, 5)))
       Plot object containing:
       [0]: 3D parametric cartesian line: (cos(u), sin(u), u) for u over (-5.0, 5.0)
       [1]: 3D parametric cartesian line: (sin(u), u**2, u) for u over (-5.0, 5.0)

    See Also
    ========

    Plot, Parametric3DLineSeries

    """
    args = list(map(sympify, args))
    plot_expr = check_arguments(args, 3, 1)
    series = [Parametric3DLineSeries(*arg, **kwargs) for arg in plot_expr]
    kwargs.setdefault("xlabel", "x")
    kwargs.setdefault("ylabel", "y")
    kwargs.setdefault("zlabel", "z")
    plots = Plot(*series, **kwargs)
    if show:
        plots.show()
    return plots
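

# Illustrative sketch (hypothetical helper, not part of SymPy): with uniform
# sampling, a 3D parametric line evaluates its three component functions at
# the same ``nb_of_points`` evenly spaced parameter values. Including both
# endpoints is an assumption of this sketch (as in ``numpy.linspace``).
def _sketch_parametric_line_3d(fx, fy, fz, start, end, nb_of_points):
    """Return ``(x(u), y(u), z(u))`` for evenly spaced ``u`` in [start, end]."""
    if nb_of_points < 2:
        raise ValueError("nb_of_points must be at least 2")
    step = (end - start) / (nb_of_points - 1)
    params = [start + k * step for k in range(nb_of_points)]
    # All components share the same parameter values, so the coordinates
    # of each returned point belong to the same u.
    return [(fx(u), fy(u), fz(u)) for u in params]

# Usage: _sketch_parametric_line_3d(lambda u: u, lambda u: 2*u,
# lambda u: u*u, 0, 2, 3) returns
# [(0.0, 0.0, 0.0), (1.0, 2.0, 1.0), (2.0, 4.0, 4.0)].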


def plot3d(*args, show=True, **kwargs):
    """
    Plots a 3D surface plot.

    Usage
    =====

    Single plot

    ``plot3d(expr, range_x, range_y, **kwargs)``

    If the ranges are not specified, then a default range of (-10, 10) is used.

    Multiple plots with the same range.

    ``plot3d(expr1, expr2, range_x, range_y, **kwargs)``

    If the ranges are not specified, then a default range of (-10, 10) is used.

    Multiple plots with different ranges.

    ``plot3d((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)``

    Ranges have to be specified for every expression.

    Default range may change in the future if a more advanced default range
    detection algorithm is implemented.

    Arguments
    =========

    ``expr`` : Expression in x and y giving the height (z value) of the
    surface.

    ``range_x``: (x, 0, 5), A 3-tuple denoting the range of the x
    variable.

    ``range_y``: (y, 0, 5), A 3-tuple denoting the range of the y
    variable.

    Keyword Arguments
    =================

    Arguments for ``SurfaceOver2DRangeSeries`` class:

    ``nb_of_points_x``: int. The x range is sampled uniformly at
    ``nb_of_points_x`` points.

    ``nb_of_points_y``: int. The y range is sampled uniformly at
    ``nb_of_points_y`` points.

    Aesthetics:

    ``surface_color``: Function which returns a float. Specifies the color for
    the surface of the plot. See ``sympy.plotting.Plot`` for more details.

    If there are multiple plots, then the same series arguments are applied to
    all the plots. If you want to set these options separately, you can index
    the returned ``Plot`` object and set them on each series.

    Arguments for ``Plot`` class:

    ``title`` : str. Title of the plot.

    ``size`` : (float, float), optional
        A tuple in the form (width, height) in inches to specify the size of the
        overall figure. The default value is set to ``None``, meaning the size will
        be set by the default backend.

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import symbols
       >>> from sympy.plotting import plot3d
       >>> x, y = symbols('x y')

    Single plot

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d(x*y, (x, -5, 5), (y, -5, 5))
       Plot object containing:
       [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)

    Multiple plots with same range

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d(x*y, -x*y, (x, -5, 5), (y, -5, 5))
       Plot object containing:
       [0]: cartesian surface: x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)
       [1]: cartesian surface: -x*y for x over (-5.0, 5.0) and y over (-5.0, 5.0)

    Multiple plots with different ranges.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d((x**2 + y**2, (x, -5, 5), (y, -5, 5)),
       ...     (x*y, (x, -3, 3), (y, -3, 3)))
       Plot object containing:
       [0]: cartesian surface: x**2 + y**2 for x over (-5.0, 5.0) and y over (-5.0, 5.0)
       [1]: cartesian surface: x*y for x over (-3.0, 3.0) and y over (-3.0, 3.0)

    See Also
    ========

    Plot, SurfaceOver2DRangeSeries

    """
    args = list(map(sympify, args))
    plot_expr = check_arguments(args, 1, 2)
    series = [SurfaceOver2DRangeSeries(*arg, **kwargs) for arg in plot_expr]
    xlabel = series[0].var_x.name
    ylabel = series[0].var_y.name
    kwargs.setdefault("xlabel", xlabel)
    kwargs.setdefault("ylabel", ylabel)
    kwargs.setdefault("zlabel", "f(%s, %s)" % (xlabel, ylabel))
    plots = Plot(*series, **kwargs)
    if show:
        plots.show()
    return plots
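

# Illustrative sketch (hypothetical helper, not part of SymPy): a surface
# plot samples ``range_x`` at ``nb_of_points_x`` points and ``range_y`` at
# ``nb_of_points_y`` points, then evaluates the expression at every (x, y)
# pair of that grid. The layout with one row per y value mirrors the default
# ``indexing='xy'`` of ``numpy.meshgrid``; that layout is an assumption here.
def _sketch_surface_grid(f, x_range, y_range, nb_of_points_x, nb_of_points_y):
    """Return ``(xs, ys, zs)`` where ``zs[i][j] == f(xs[j], ys[i])``."""
    def linspace(a, b, n):
        if n < 2:
            raise ValueError("need at least 2 points per direction")
        step = (b - a) / (n - 1)
        return [a + k * step for k in range(n)]

    xs = linspace(float(x_range[0]), float(x_range[1]), nb_of_points_x)
    ys = linspace(float(y_range[0]), float(y_range[1]), nb_of_points_y)
    # One row per y value, one column per x value.
    zs = [[f(x, y) for x in xs] for y in ys]
    return xs, ys, zs

# Usage: _sketch_surface_grid(lambda x, y: x*y, (-1, 1), (0, 1), 3, 2)
# samples x at -1, 0, 1 and y at 0, 1, giving a 2 x 3 grid of heights.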


def plot3d_parametric_surface(*args, show=True, **kwargs):
    """
    Plots a 3D parametric surface plot.

    Explanation
    ===========

    Single plot.

    ``plot3d_parametric_surface(expr_x, expr_y, expr_z, range_u, range_v, **kwargs)``

    If the ranges are not specified, then a default range of (-10, 10) is used.

    Multiple plots.

    ``plot3d_parametric_surface((expr_x, expr_y, expr_z, range_u, range_v), ..., **kwargs)``

    Ranges have to be specified for every expression.

    Default range may change in the future if a more advanced default range
    detection algorithm is implemented.

    Arguments
    =========

    ``expr_x``: Expression representing the function along ``x``.

    ``expr_y``: Expression representing the function along ``y``.

    ``expr_z``: Expression representing the function along ``z``.

    ``range_u``: ``(u, 0, 5)``, A 3-tuple denoting the range of the ``u``
    variable.

    ``range_v``: ``(v, 0, 5)``, A 3-tuple denoting the range of the ``v``
    variable.

    Keyword Arguments
    =================

    Arguments for ``ParametricSurfaceSeries`` class:

    ``nb_of_points_u``: int. The ``u`` range is sampled uniformly at
    ``nb_of_points_u`` points.

    ``nb_of_points_v``: int. The ``v`` range is sampled uniformly at
    ``nb_of_points_v`` points.

    Aesthetics:

    ``surface_color``: Function which returns a float. Specifies the color for
    the surface of the plot. See ``sympy.plotting.Plot`` for more details.

    If there are multiple plots, then the same series arguments are applied to
    all the plots. If you want to set these options separately, you can index
    the returned ``Plot`` object and set them on each series.

    Arguments for ``Plot`` class:

    ``title`` : str. Title of the plot.

    ``size`` : (float, float), optional
        A tuple in the form (width, height) in inches to specify the size of the
        overall figure. The default value is set to ``None``, meaning the size will
        be set by the default backend.

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import symbols, cos, sin
       >>> from sympy.plotting import plot3d_parametric_surface
       >>> u, v = symbols('u v')

    Single plot.

    .. plot::
       :context: close-figs
       :format: doctest
       :include-source: True

       >>> plot3d_parametric_surface(cos(u + v), sin(u - v), u - v,
       ...     (u, -5, 5), (v, -5, 5))
       Plot object containing:
       [0]: parametric cartesian surface: (cos(u + v), sin(u - v), u - v) for u over (-5.0, 5.0) and v over (-5.0, 5.0)

    See Also
    ========

    Plot, ParametricSurfaceSeries

    """
    args = list(map(sympify, args))
    plot_expr = check_arguments(args, 3, 2)
    series = [ParametricSurfaceSeries(*arg, **kwargs) for arg in plot_expr]
    kwargs.setdefault("xlabel", "x")
    kwargs.setdefault("ylabel", "y")
    kwargs.setdefault("zlabel", "z")
    plots = Plot(*series, **kwargs)
    if show:
        plots.show()
    return plots
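

# Illustrative sketch (hypothetical helper, not part of SymPy): a parametric
# surface maps every (u, v) pair of a uniform grid (``nb_of_points_u`` by
# ``nb_of_points_v`` points) through the three component functions, giving
# three coordinate grids of the same shape. The layout with one row per v
# value is an assumption of this sketch.
def _sketch_parametric_surface_grid(fx, fy, fz, u_range, v_range,
                                    nb_of_points_u, nb_of_points_v):
    """Return nested lists ``X, Y, Z``, each with one row per v value."""
    def linspace(a, b, n):
        if n < 2:
            raise ValueError("need at least 2 points per direction")
        step = (b - a) / (n - 1)
        return [a + k * step for k in range(n)]

    us = linspace(float(u_range[0]), float(u_range[1]), nb_of_points_u)
    vs = linspace(float(v_range[0]), float(v_range[1]), nb_of_points_v)
    X = [[fx(u, v) for u in us] for v in vs]
    Y = [[fy(u, v) for u in us] for v in vs]
    Z = [[fz(u, v) for u in us] for v in vs]
    return X, Y, Z

# Usage: with math.cos/math.sin components such as
# fx = cos(u)*sin(v), fy = sin(u)*sin(v), fz = cos(v), every grid point
# lies on the unit sphere.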


def plot_contour(*args, show=True, **kwargs):
    """
    Draws a contour plot of a function.

    Usage
    =====

    Single plot

    ``plot_contour(expr, range_x, range_y, **kwargs)``

    If the ranges are not specified, then a default range of (-10, 10) is used.

    Multiple plots with the same range.

    ``plot_contour(expr1, expr2, range_x, range_y, **kwargs)``

    If the ranges are not specified, then a default range of (-10, 10) is used.

    Multiple plots with different ranges.

    ``plot_contour((expr1, range_x, range_y), (expr2, range_x, range_y), ..., **kwargs)``

    Ranges have to be specified for every expression.

    Default range may change in the future if a more advanced default range
    detection algorithm is implemented.

    Arguments
    =========

    ``expr`` : Expression in x and y whose contour lines are plotted.

    ``range_x``: (x, 0, 5), A 3-tuple denoting the range of the x
    variable.

    ``range_y``: (y, 0, 5), A 3-tuple denoting the range of the y
    variable.

    Keyword Arguments
    =================

    Arguments for ``ContourSeries`` class:

    ``nb_of_points_x``: int. The x range is sampled uniformly at
    ``nb_of_points_x`` points.

    ``nb_of_points_y``: int. The y range is sampled uniformly at
    ``nb_of_points_y`` points.

    Aesthetics:

    ``surface_color``: Function which returns a float. Specifies the color for
    the surface of the plot. See ``sympy.plotting.Plot`` for more details.

    If there are multiple plots, then the same series arguments are applied to
    all the plots. If you want to set these options separately, you can index
    the returned ``Plot`` object and set them on each series.

    Arguments for ``Plot`` class:

    ``title`` : str. Title of the plot.

    ``size`` : (float, float), optional
        A tuple in the form (width, height) in inches to specify the size of
        the overall figure. The default value is set to ``None``, meaning
        the size will be set by the default backend.

    See Also
    ========

    Plot, ContourSeries

    """
    args = list(map(sympify, args))
    plot_expr = check_arguments(args, 1, 2)
    # Validate before building any series: a contour plot needs a function
    # of at most two variables.
    if len(plot_expr[0].free_symbols) > 2:
        raise ValueError('Contour Plot cannot Plot for more than two variables.')
    series = [ContourSeries(*arg) for arg in plot_expr]
    plot_contours = Plot(*series, **kwargs)
    if show:
        plot_contours.show()
    return plot_contours
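

# Illustrative sketch (hypothetical helper, not part of SymPy): the contour
# lines themselves are computed by the plotting backend (for example
# matplotlib's ``contour``) from the sampled grid. The simplified idea below
# is the edge-interpolation step of marching squares, a swapped-in technique
# rather than SymPy code. It finds where the level set ``f(x, y) == level``
# crosses the horizontal grid edges. Crossings that land exactly on a grid
# node are skipped in this simplified version.
def _sketch_contour_crossings(f, x_range, y_range, nb_of_points_x,
                              nb_of_points_y, level):
    """Return ``(x, y)`` points where ``f`` crosses ``level`` on row edges."""
    def linspace(a, b, n):
        if n < 2:
            raise ValueError("need at least 2 points per direction")
        step = (b - a) / (n - 1)
        return [a + k * step for k in range(n)]

    xs = linspace(float(x_range[0]), float(x_range[1]), nb_of_points_x)
    ys = linspace(float(y_range[0]), float(y_range[1]), nb_of_points_y)
    crossings = []
    for y in ys:
        row = [f(x, y) - level for x in xs]
        for x0, x1, d0, d1 in zip(xs, xs[1:], row, row[1:]):
            # A strict sign change means the level set crosses this edge.
            if d0 * d1 < 0:
                t = d0 / (d0 - d1)  # linear interpolation along the edge
                crossings.append((x0 + t * (x1 - x0), y))
    return crossings

# Usage: for f(x, y) = x + y on a 5 x 5 grid over [-1, 1]^2 and level 0.25,
# the crossings lie exactly on the line x + y = 0.25.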


def check_arguments(args, expr_len, nb_of_free_symbols):
    """
    Checks the arguments and converts them into tuples of the
    form (exprs, ranges).

    Examples
    ========

    .. plot::
       :context: reset
       :format: doctest
       :include-source: True

       >>> from sympy import cos, sin, symbols
       >>> from sympy.plotting.plot import check_arguments
       >>> x = symbols('x')
       >>> check_arguments([cos(x), sin(x)], 2, 1)
       [(cos(x), sin(x), (x, -10, 10))]

       >>> check_arguments([x, x**2], 1, 1)
       [(x, (x, -10, 10)), (x**2, (x, -10, 10))]

    """
    if not args:
        return []
    if expr_len > 1 and isinstance(args[0], Expr):
        # A single plot whose ``expr_len`` components are passed as separate
        # expressions sharing one range, e.g. ``cos(u), sin(u), (u, -5, 5)``.
        if len(args) < expr_len:
            raise ValueError("len(args) should not be less than expr_len")
        for i in range(len(args)):
            if isinstance(args[i], Tuple):
                break
        else:
            i = len(args) + 1

        exprs = Tuple(*args[:i])
        free_symbols = list(set().union(*[e.free_symbols for e in exprs]))
        if len(args) == expr_len + nb_of_free_symbols:
            # Ranges given
            plots = [exprs + Tuple(*args[expr_len:])]
        else:
            default_range = Tuple(-10, 10)
            ranges = []
            for symbol in free_symbols:
                ranges.append(Tuple(symbol) + default_range)

            # Pad with dummy ranges when the expressions have fewer free
            # symbols than the series requires (e.g. constant components).
            for _ in range(nb_of_free_symbols - len(free_symbols)):
                ranges.append(Tuple(Dummy()) + default_range)
            plots = [exprs + Tuple(*ranges)]
        return plots

    if isinstance(args[0], Expr) or (isinstance(args[0], Tuple) and
                                     len(args[0]) == expr_len and
                                     expr_len != 3):
        # Series of plots with the same range.
        # This branch cannot handle expr_len == 3, because a 3-tuple of
        # expressions cannot be distinguished from a range (var, start, end).
        for i in range(len(args)):
            if isinstance(args[i], Tuple) and len(args[i]) != expr_len:
                break
            if not isinstance(args[i], Tuple):
                args[i] = Tuple(args[i])
        else:
            i = len(args) + 1

        exprs = args[:i]
        assert all(isinstance(e, Expr) for expr in exprs for e in expr)
        free_symbols = list(set().union(*[e.free_symbols for expr in exprs
                                          for e in expr]))

        if len(free_symbols) > nb_of_free_symbols:
            raise ValueError("The number of free_symbols in the expression "
                             "is greater than %d" % nb_of_free_symbols)
        if len(args) == i + nb_of_free_symbols and isinstance(args[i], Tuple):
            ranges = Tuple(*args[i:i + nb_of_free_symbols])
            plots = [expr + ranges for expr in exprs]
            return plots
        else:
            # Use default ranges.
            default_range = Tuple(-10, 10)
            ranges = []
            for symbol in free_symbols:
                ranges.append(Tuple(symbol) + default_range)

            for _ in range(nb_of_free_symbols - len(free_symbols)):
                ranges.append(Tuple(Dummy()) + default_range)
            ranges = Tuple(*ranges)
            plots = [expr + ranges for expr in exprs]
            return plots

    elif isinstance(args[0], Tuple) and len(args[0]) == expr_len + nb_of_free_symbols:
        # Multiple plots with different ranges.
        for arg in args:
            for i in range(expr_len):
                if not isinstance(arg[i], Expr):
                    raise ValueError("Expected an expression, given %s" %
                                     str(arg[i]))
            for i in range(nb_of_free_symbols):
                if not len(arg[i + expr_len]) == 3:
                    raise ValueError("The ranges should be a tuple of "
                                     "length 3, got %s" % str(arg[i + expr_len]))
        return args
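

# Illustrative sketch (hypothetical helper using plain strings instead of
# SymPy symbols): the default-range rule applied by ``check_arguments``.
# Every free symbol gets the range ``(symbol, -10, 10)``. Placeholder ranges
# are then appended until there are ``nb_of_free_symbols`` of them;
# ``check_arguments`` uses ``Dummy()`` symbols as placeholders, while the
# ``_dummy<k>`` names here are made up. Sorting the symbols is a choice made
# for deterministic output, since ``check_arguments`` iterates over a set.
def _sketch_default_ranges(free_symbols, nb_of_free_symbols,
                           default=(-10, 10)):
    """Return a list of ``(name, start, end)`` ranges."""
    names = sorted(set(free_symbols))
    if len(names) > nb_of_free_symbols:
        raise ValueError("The number of free_symbols in the expression "
                         "is greater than %d" % nb_of_free_symbols)
    ranges = [(name,) + tuple(default) for name in names]
    # Pad with placeholder variables, e.g. for constant expressions.
    for k in range(nb_of_free_symbols - len(names)):
        ranges.append(("_dummy%d" % k,) + tuple(default))
    return ranges

# Usage: _sketch_default_ranges(['x'], 2) returns
# [('x', -10, 10), ('_dummy0', -10, 10)], analogous to how a surface
# expression in x alone still receives a second (dummy) range.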