ast.py 55 KB


  1. """
  2. Types used to represent a full function/module as an Abstract Syntax Tree.
  3. Most types are small, and are merely used as tokens in the AST. A tree diagram
  4. has been included below to illustrate the relationships between the AST types.
  5. AST Type Tree
  6. -------------
  7. ::
  8. *Basic*
  9. |
  10. |
  11. CodegenAST
  12. |
  13. |--->AssignmentBase
  14. | |--->Assignment
  15. | |--->AugmentedAssignment
  16. | |--->AddAugmentedAssignment
  17. | |--->SubAugmentedAssignment
  18. | |--->MulAugmentedAssignment
  19. | |--->DivAugmentedAssignment
  20. | |--->ModAugmentedAssignment
  21. |
  22. |--->CodeBlock
  23. |
  24. |
  25. |--->Token
  26. |--->Attribute
  27. |--->For
  28. |--->String
  29. | |--->QuotedString
  30. | |--->Comment
  31. |--->Type
  32. | |--->IntBaseType
  33. | | |--->_SizedIntType
  34. | | |--->SignedIntType
  35. | | |--->UnsignedIntType
  36. | |--->FloatBaseType
  37. | |--->FloatType
  38. | |--->ComplexBaseType
  39. | |--->ComplexType
  40. |--->Node
  41. | |--->Variable
  42. | | |---> Pointer
  43. | |--->FunctionPrototype
  44. | |--->FunctionDefinition
  45. |--->Element
  46. |--->Declaration
  47. |--->While
  48. |--->Scope
  49. |--->Stream
  50. |--->Print
  51. |--->FunctionCall
  52. |--->BreakToken
  53. |--->ContinueToken
  54. |--->NoneToken
  55. |--->Return
  56. Predefined types
  57. ----------------
  58. A number of ``Type`` instances are provided in the ``sympy.codegen.ast`` module
  59. for convenience. Perhaps the two most common ones for code-generation (of numeric
  60. codes) are ``float32`` and ``float64`` (known as single and double precision respectively).
  61. There are also precision-generic versions of Types (for which the code printers select the
  62. underlying data type at time of printing): ``real``, ``integer``, ``complex_``, ``bool_``.
  63. The other ``Type`` instances defined are:
  64. - ``intc``: Integer type used by C's "int".
  65. - ``intp``: Integer type used by C's "unsigned".
  66. - ``int8``, ``int16``, ``int32``, ``int64``: n-bit integers.
  67. - ``uint8``, ``uint16``, ``uint32``, ``uint64``: n-bit unsigned integers.
  68. - ``float80``: known as "extended precision" on modern x86/amd64 hardware.
  69. - ``complex64``: Complex number represented by two ``float32`` numbers
  70. - ``complex128``: Complex number represented by two ``float64`` numbers
  71. Using the nodes
  72. ---------------
  73. It is possible to construct simple algorithms using the AST nodes. Let's construct a loop applying
  74. Newton's method::
  75. >>> from sympy import symbols, cos
  76. >>> from sympy.codegen.ast import While, Assignment, aug_assign, Print
  77. >>> t, dx, x = symbols('tol delta val')
  78. >>> expr = cos(x) - x**3
  79. >>> whl = While(abs(dx) > t, [
  80. ... Assignment(dx, -expr/expr.diff(x)),
  81. ... aug_assign(x, '+', dx),
  82. ... Print([x])
  83. ... ])
  84. >>> from sympy import pycode
  85. >>> py_str = pycode(whl)
  86. >>> print(py_str)
  87. while (abs(delta) > tol):
  88. delta = (val**3 - math.cos(val))/(-3*val**2 - math.sin(val))
  89. val += delta
  90. print(val)
  91. >>> import math
  92. >>> tol, val, delta = 1e-5, 0.5, float('inf')
  93. >>> exec(py_str)
  94. 1.1121416371
  95. 0.909672693737
  96. 0.867263818209
  97. 0.865477135298
  98. 0.865474033111
  99. >>> print('%3.1g' % (math.cos(val) - val**3))
  100. -3e-11
  101. If we want to generate Fortran code for the same while loop we simply call ``fcode``::
  102. >>> from sympy import fcode
  103. >>> print(fcode(whl, standard=2003, source_format='free'))
  104. do while (abs(delta) > tol)
  105. delta = (val**3 - cos(val))/(-3*val**2 - sin(val))
  106. val = val + delta
  107. print *, val
  108. end do
  109. There is a function constructing a loop (or a complete function) like this in
  110. :mod:`sympy.codegen.algorithms`.
  111. """
  112. from typing import Any, Dict as tDict, List
  113. from collections import defaultdict
  114. from sympy.core.relational import (Ge, Gt, Le, Lt)
  115. from sympy.core import Symbol, Tuple, Dummy
  116. from sympy.core.basic import Basic
  117. from sympy.core.expr import Expr, Atom
  118. from sympy.core.numbers import Float, Integer, oo
  119. from sympy.core.sympify import _sympify, sympify, SympifyError
  120. from sympy.utilities.iterables import (iterable, topological_sort,
  121. numbered_symbols, filter_symbols)
  122. def _mk_Tuple(args):
  123. """
  124. Create a SymPy Tuple object from an iterable, converting Python strings to
  125. AST strings.
  126. Parameters
  127. ==========
  128. args: iterable
  129. Arguments to :class:`sympy.Tuple`.
  130. Returns
  131. =======
  132. sympy.Tuple
  133. """
  134. args = [String(arg) if isinstance(arg, str) else arg for arg in args]
  135. return Tuple(*args)
  136. class CodegenAST(Basic):
  137. pass
class Token(CodegenAST):
    """ Base class for the AST types.

    Explanation
    ===========

    Defining fields are set in ``__slots__``. Attributes (defined in __slots__)
    are only allowed to contain instances of Basic (unless atomic, see
    ``String``). The arguments to ``__new__()`` correspond to the attributes in
    the order defined in ``__slots__``. The ``defaults`` class attribute is a
    dictionary mapping attribute names to their default values.

    Subclasses should not need to override the ``__new__()`` method. They may
    define a class or static method named ``_construct_<attr>`` for each
    attribute to process the value passed to ``__new__()``. Attributes listed
    in the class attribute ``not_in_args`` are not passed to :class:`~.Basic`.
    Attributes listed in ``indented_args`` are printed at an increased indent
    level by ``_sympyrepr``.
    """

    __slots__ = ()
    defaults = {} # type: tDict[str, Any]
    not_in_args = [] # type: List[str]
    # Attributes whose values ``_sympyrepr`` prints at an increased indent level.
    indented_args = ['body']

    @property
    def is_Atom(self):
        # A token that declares no fields (e.g. ``BreakToken``) is atomic.
        return len(self.__slots__) == 0

    @classmethod
    def _get_constructor(cls, attr):
        """ Get the constructor function for an attribute by name. """
        # Falls back to the identity function when the class defines no
        # ``_construct_<attr>`` hook.
        return getattr(cls, '_construct_%s' % attr, lambda x: x)

    @classmethod
    def _construct(cls, attr, arg):
        """ Construct an attribute value from argument passed to ``__new__()``.

        ``None`` (or a ``NoneToken``, which compares equal to ``None``) is
        replaced by the attribute's entry in ``defaults``, falling back to
        ``none``. ``Dummy`` instances are returned untouched; any other value
        goes through the attribute's constructor hook.
        """
        # arg may be ``NoneToken()``, so comparison is done using == instead of ``is`` operator
        if arg == None:
            return cls.defaults.get(attr, none)
        else:
            if isinstance(arg, Dummy): # SymPy's replace uses Dummy instances
                return arg
            else:
                return cls._get_constructor(attr)(arg)

    def __new__(cls, *args, **kwargs):
        # Pass through existing instances when given as sole argument
        if len(args) == 1 and not kwargs and isinstance(args[0], cls):
            return args[0]

        if len(args) > len(cls.__slots__):
            raise ValueError("Too many arguments (%d), expected at most %d" % (len(args), len(cls.__slots__)))

        attrvals = []

        # Process positional arguments
        for attrname, argval in zip(cls.__slots__, args):
            if attrname in kwargs:
                raise TypeError('Got multiple values for attribute %r' % attrname)

            attrvals.append(cls._construct(attrname, argval))

        # Process keyword arguments
        # Slots not filled positionally come from kwargs (popped, so that any
        # leftover keys can be reported below) or from ``defaults``.
        for attrname in cls.__slots__[len(args):]:
            if attrname in kwargs:
                argval = kwargs.pop(attrname)

            elif attrname in cls.defaults:
                argval = cls.defaults[attrname]

            else:
                raise TypeError('No value for %r given and attribute has no default' % attrname)

            attrvals.append(cls._construct(attrname, argval))

        # Anything still in kwargs did not match a slot name.
        if kwargs:
            raise ValueError("Unknown keyword arguments: %s" % ' '.join(kwargs))

        # Parent constructor
        # Only attributes not listed in ``not_in_args`` become Basic args.
        basic_args = [
            val for attr, val in zip(cls.__slots__, attrvals)
            if attr not in cls.not_in_args
        ]
        obj = CodegenAST.__new__(cls, *basic_args)

        # Set attributes
        # Every slot (including ``not_in_args`` ones) is stored on the instance.
        for attr, arg in zip(cls.__slots__, attrvals):
            setattr(obj, attr, arg)

        return obj

    def __eq__(self, other):
        # Equal only to instances of this class whose slot values all compare equal.
        if not isinstance(other, self.__class__):
            return False
        for attr in self.__slots__:
            if getattr(self, attr) != getattr(other, attr):
                return False
        return True

    def _hashable_content(self):
        # Hash on the slot values, consistent with __eq__ above.
        return tuple([getattr(self, attr) for attr in self.__slots__])

    def __hash__(self):
        # Re-declared because defining __eq__ in a class body sets __hash__ to None.
        return super().__hash__()

    def _joiner(self, k, indent_level):
        # Indented attributes are separated by a newline plus indentation;
        # all others by ", ".
        return (',\n' + ' '*indent_level) if k in self.indented_args else ', '

    def _indented(self, printer, k, v, *args, **kwargs):
        """ Print the value ``v`` of attribute ``k`` at the printer's current indent level. """
        il = printer._context['indent_level']
        def _print(arg):
            # Nested Tokens receive the joiner so they lay out their own
            # arguments consistently with this attribute.
            if isinstance(arg, Token):
                return printer._print(arg, *args, joiner=self._joiner(k, il), **kwargs)
            else:
                return printer._print(arg, *args, **kwargs)

        if isinstance(v, Tuple):
            joined = self._joiner(k, il).join([_print(arg) for arg in v.args])
            if k in self.indented_args:
                return '(\n' + ' '*il + joined + ',\n' + ' '*(il - 4) + ')'
            else:
                # Single-element tuples keep Python's trailing comma.
                return ('({0},)' if len(v.args) == 1 else '({0})').format(joined)
        else:
            return _print(v)

    def _sympyrepr(self, printer, *args, joiner=', ', **kwargs):
        """ Render as ``ClassName(first, attr=value, ...)``, omitting default-valued attributes. """
        from sympy.printing.printer import printer_context
        exclude = kwargs.get('exclude', ())
        values = [getattr(self, k) for k in self.__slots__]
        indent_level = printer._context.get('indent_level', 0)

        arg_reprs = []

        for i, (attr, value) in enumerate(zip(self.__slots__, values)):
            if attr in exclude:
                continue

            # Skip attributes which have the default value
            if attr in self.defaults and value == self.defaults[attr]:
                continue

            ilvl = indent_level + 4 if attr in self.indented_args else 0
            with printer_context(printer, indent_level=ilvl):
                indented = self._indented(printer, attr, value, *args, **kwargs)
            # The first slot is printed positionally, later ones as key=value.
            arg_reprs.append(('{1}' if i == 0 else '{0}={1}').format(attr, indented.lstrip()))

        return "{}({})".format(self.__class__.__name__, joiner.join(arg_reprs))

    # str() output uses the same format as repr().
    _sympystr = _sympyrepr

    def __repr__(self): # sympy.core.Basic.__repr__ uses sstr
        from sympy.printing import srepr
        return srepr(self)

    def kwargs(self, exclude=(), apply=None):
        """ Get instance's attributes as dict of keyword arguments.

        Parameters
        ==========

        exclude : collection of str
            Collection of keywords to exclude.

        apply : callable, optional
            Function to apply to all values.
        """
        kwargs = {k: getattr(self, k) for k in self.__slots__ if k not in exclude}
        if apply is not None:
            return {k: apply(v) for k, v in kwargs.items()}
        else:
            return kwargs
  270. class BreakToken(Token):
  271. """ Represents 'break' in C/Python ('exit' in Fortran).
  272. Use the premade instance ``break_`` or instantiate manually.
  273. Examples
  274. ========
  275. >>> from sympy import ccode, fcode
  276. >>> from sympy.codegen.ast import break_
  277. >>> ccode(break_)
  278. 'break'
  279. >>> fcode(break_, source_format='free')
  280. 'exit'
  281. """
  282. break_ = BreakToken()
  283. class ContinueToken(Token):
  284. """ Represents 'continue' in C/Python ('cycle' in Fortran)
  285. Use the premade instance ``continue_`` or instantiate manually.
  286. Examples
  287. ========
  288. >>> from sympy import ccode, fcode
  289. >>> from sympy.codegen.ast import continue_
  290. >>> ccode(continue_)
  291. 'continue'
  292. >>> fcode(continue_, source_format='free')
  293. 'cycle'
  294. """
  295. continue_ = ContinueToken()
  296. class NoneToken(Token):
  297. """ The AST equivalence of Python's NoneType
  298. The corresponding instance of Python's ``None`` is ``none``.
  299. Examples
  300. ========
  301. >>> from sympy.codegen.ast import none, Variable
  302. >>> from sympy import pycode
  303. >>> print(pycode(Variable('x').as_Declaration(value=none)))
  304. x = None
  305. """
  306. def __eq__(self, other):
  307. return other is None or isinstance(other, NoneToken)
  308. def _hashable_content(self):
  309. return ()
  310. def __hash__(self):
  311. return super().__hash__()
  312. none = NoneToken()
  313. class AssignmentBase(CodegenAST):
  314. """ Abstract base class for Assignment and AugmentedAssignment.
  315. Attributes:
  316. ===========
  317. op : str
  318. Symbol for assignment operator, e.g. "=", "+=", etc.
  319. """
  320. def __new__(cls, lhs, rhs):
  321. lhs = _sympify(lhs)
  322. rhs = _sympify(rhs)
  323. cls._check_args(lhs, rhs)
  324. return super().__new__(cls, lhs, rhs)
  325. @property
  326. def lhs(self):
  327. return self.args[0]
  328. @property
  329. def rhs(self):
  330. return self.args[1]
  331. @classmethod
  332. def _check_args(cls, lhs, rhs):
  333. """ Check arguments to __new__ and raise exception if any problems found.
  334. Derived classes may wish to override this.
  335. """
  336. from sympy.matrices.expressions.matexpr import (
  337. MatrixElement, MatrixSymbol)
  338. from sympy.tensor.indexed import Indexed
  339. # Tuple of things that can be on the lhs of an assignment
  340. assignable = (Symbol, MatrixSymbol, MatrixElement, Indexed, Element, Variable)
  341. if not isinstance(lhs, assignable):
  342. raise TypeError("Cannot assign to lhs of type %s." % type(lhs))
  343. # Indexed types implement shape, but don't define it until later. This
  344. # causes issues in assignment validation. For now, matrices are defined
  345. # as anything with a shape that is not an Indexed
  346. lhs_is_mat = hasattr(lhs, 'shape') and not isinstance(lhs, Indexed)
  347. rhs_is_mat = hasattr(rhs, 'shape') and not isinstance(rhs, Indexed)
  348. # If lhs and rhs have same structure, then this assignment is ok
  349. if lhs_is_mat:
  350. if not rhs_is_mat:
  351. raise ValueError("Cannot assign a scalar to a matrix.")
  352. elif lhs.shape != rhs.shape:
  353. raise ValueError("Dimensions of lhs and rhs do not align.")
  354. elif rhs_is_mat and not lhs_is_mat:
  355. raise ValueError("Cannot assign a matrix to a scalar.")
  356. class Assignment(AssignmentBase):
  357. """
  358. Represents variable assignment for code generation.
  359. Parameters
  360. ==========
  361. lhs : Expr
  362. SymPy object representing the lhs of the expression. These should be
  363. singular objects, such as one would use in writing code. Notable types
  364. include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that
  365. subclass these types are also supported.
  366. rhs : Expr
  367. SymPy object representing the rhs of the expression. This can be any
  368. type, provided its shape corresponds to that of the lhs. For example,
  369. a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as
  370. the dimensions will not align.
  371. Examples
  372. ========
  373. >>> from sympy import symbols, MatrixSymbol, Matrix
  374. >>> from sympy.codegen.ast import Assignment
  375. >>> x, y, z = symbols('x, y, z')
  376. >>> Assignment(x, y)
  377. Assignment(x, y)
  378. >>> Assignment(x, 0)
  379. Assignment(x, 0)
  380. >>> A = MatrixSymbol('A', 1, 3)
  381. >>> mat = Matrix([x, y, z]).T
  382. >>> Assignment(A, mat)
  383. Assignment(A, Matrix([[x, y, z]]))
  384. >>> Assignment(A[0, 1], x)
  385. Assignment(A[0, 1], x)
  386. """
  387. op = ':='
  388. class AugmentedAssignment(AssignmentBase):
  389. """
  390. Base class for augmented assignments.
  391. Attributes:
  392. ===========
  393. binop : str
  394. Symbol for binary operation being applied in the assignment, such as "+",
  395. "*", etc.
  396. """
  397. binop = None # type: str
  398. @property
  399. def op(self):
  400. return self.binop + '='
  401. class AddAugmentedAssignment(AugmentedAssignment):
  402. binop = '+'
  403. class SubAugmentedAssignment(AugmentedAssignment):
  404. binop = '-'
  405. class MulAugmentedAssignment(AugmentedAssignment):
  406. binop = '*'
  407. class DivAugmentedAssignment(AugmentedAssignment):
  408. binop = '/'
  409. class ModAugmentedAssignment(AugmentedAssignment):
  410. binop = '%'
  411. # Mapping from binary op strings to AugmentedAssignment subclasses
  412. augassign_classes = {
  413. cls.binop: cls for cls in [
  414. AddAugmentedAssignment, SubAugmentedAssignment, MulAugmentedAssignment,
  415. DivAugmentedAssignment, ModAugmentedAssignment
  416. ]
  417. }
  418. def aug_assign(lhs, op, rhs):
  419. """
  420. Create 'lhs op= rhs'.
  421. Explanation
  422. ===========
  423. Represents augmented variable assignment for code generation. This is a
  424. convenience function. You can also use the AugmentedAssignment classes
  425. directly, like AddAugmentedAssignment(x, y).
  426. Parameters
  427. ==========
  428. lhs : Expr
  429. SymPy object representing the lhs of the expression. These should be
  430. singular objects, such as one would use in writing code. Notable types
  431. include Symbol, MatrixSymbol, MatrixElement, and Indexed. Types that
  432. subclass these types are also supported.
  433. op : str
  434. Operator (+, -, /, \\*, %).
  435. rhs : Expr
  436. SymPy object representing the rhs of the expression. This can be any
  437. type, provided its shape corresponds to that of the lhs. For example,
  438. a Matrix type can be assigned to MatrixSymbol, but not to Symbol, as
  439. the dimensions will not align.
  440. Examples
  441. ========
  442. >>> from sympy import symbols
  443. >>> from sympy.codegen.ast import aug_assign
  444. >>> x, y = symbols('x, y')
  445. >>> aug_assign(x, '+', y)
  446. AddAugmentedAssignment(x, y)
  447. """
  448. if op not in augassign_classes:
  449. raise ValueError("Unrecognized operator %s" % op)
  450. return augassign_classes[op](lhs, rhs)
  451. class CodeBlock(CodegenAST):
  452. """
  453. Represents a block of code.
  454. Explanation
  455. ===========
  456. For now only assignments are supported. This restriction will be lifted in
  457. the future.
  458. Useful attributes on this object are:
  459. ``left_hand_sides``:
  460. Tuple of left-hand sides of assignments, in order.
  461. ``left_hand_sides``:
  462. Tuple of right-hand sides of assignments, in order.
  463. ``free_symbols``: Free symbols of the expressions in the right-hand sides
  464. which do not appear in the left-hand side of an assignment.
  465. Useful methods on this object are:
  466. ``topological_sort``:
  467. Class method. Return a CodeBlock with assignments
  468. sorted so that variables are assigned before they
  469. are used.
  470. ``cse``:
  471. Return a new CodeBlock with common subexpressions eliminated and
  472. pulled out as assignments.
  473. Examples
  474. ========
  475. >>> from sympy import symbols, ccode
  476. >>> from sympy.codegen.ast import CodeBlock, Assignment
  477. >>> x, y = symbols('x y')
  478. >>> c = CodeBlock(Assignment(x, 1), Assignment(y, x + 1))
  479. >>> print(ccode(c))
  480. x = 1;
  481. y = x + 1;
  482. """
  483. def __new__(cls, *args):
  484. left_hand_sides = []
  485. right_hand_sides = []
  486. for i in args:
  487. if isinstance(i, Assignment):
  488. lhs, rhs = i.args
  489. left_hand_sides.append(lhs)
  490. right_hand_sides.append(rhs)
  491. obj = CodegenAST.__new__(cls, *args)
  492. obj.left_hand_sides = Tuple(*left_hand_sides)
  493. obj.right_hand_sides = Tuple(*right_hand_sides)
  494. return obj
  495. def __iter__(self):
  496. return iter(self.args)
  497. def _sympyrepr(self, printer, *args, **kwargs):
  498. il = printer._context.get('indent_level', 0)
  499. joiner = ',\n' + ' '*il
  500. joined = joiner.join(map(printer._print, self.args))
  501. return ('{}(\n'.format(' '*(il-4) + self.__class__.__name__,) +
  502. ' '*il + joined + '\n' + ' '*(il - 4) + ')')
  503. _sympystr = _sympyrepr
  504. @property
  505. def free_symbols(self):
  506. return super().free_symbols - set(self.left_hand_sides)
  507. @classmethod
  508. def topological_sort(cls, assignments):
  509. """
  510. Return a CodeBlock with topologically sorted assignments so that
  511. variables are assigned before they are used.
  512. Examples
  513. ========
  514. The existing order of assignments is preserved as much as possible.
  515. This function assumes that variables are assigned to only once.
  516. This is a class constructor so that the default constructor for
  517. CodeBlock can error when variables are used before they are assigned.
  518. Examples
  519. ========
  520. >>> from sympy import symbols
  521. >>> from sympy.codegen.ast import CodeBlock, Assignment
  522. >>> x, y, z = symbols('x y z')
  523. >>> assignments = [
  524. ... Assignment(x, y + z),
  525. ... Assignment(y, z + 1),
  526. ... Assignment(z, 2),
  527. ... ]
  528. >>> CodeBlock.topological_sort(assignments)
  529. CodeBlock(
  530. Assignment(z, 2),
  531. Assignment(y, z + 1),
  532. Assignment(x, y + z)
  533. )
  534. """
  535. if not all(isinstance(i, Assignment) for i in assignments):
  536. # Will support more things later
  537. raise NotImplementedError("CodeBlock.topological_sort only supports Assignments")
  538. if any(isinstance(i, AugmentedAssignment) for i in assignments):
  539. raise NotImplementedError("CodeBlock.topological_sort doesn't yet work with AugmentedAssignments")
  540. # Create a graph where the nodes are assignments and there is a directed edge
  541. # between nodes that use a variable and nodes that assign that
  542. # variable, like
  543. # [(x := 1, y := x + 1), (x := 1, z := y + z), (y := x + 1, z := y + z)]
  544. # If we then topologically sort these nodes, they will be in
  545. # assignment order, like
  546. # x := 1
  547. # y := x + 1
  548. # z := y + z
  549. # A = The nodes
  550. #
  551. # enumerate keeps nodes in the same order they are already in if
  552. # possible. It will also allow us to handle duplicate assignments to
  553. # the same variable when those are implemented.
  554. A = list(enumerate(assignments))
  555. # var_map = {variable: [nodes for which this variable is assigned to]}
  556. # like {x: [(1, x := y + z), (4, x := 2 * w)], ...}
  557. var_map = defaultdict(list)
  558. for node in A:
  559. i, a = node
  560. var_map[a.lhs].append(node)
  561. # E = Edges in the graph
  562. E = []
  563. for dst_node in A:
  564. i, a = dst_node
  565. for s in a.rhs.free_symbols:
  566. for src_node in var_map[s]:
  567. E.append((src_node, dst_node))
  568. ordered_assignments = topological_sort([A, E])
  569. # De-enumerate the result
  570. return cls(*[a for i, a in ordered_assignments])
  571. def cse(self, symbols=None, optimizations=None, postprocess=None,
  572. order='canonical'):
  573. """
  574. Return a new code block with common subexpressions eliminated.
  575. Explanation
  576. ===========
  577. See the docstring of :func:`sympy.simplify.cse_main.cse` for more
  578. information.
  579. Examples
  580. ========
  581. >>> from sympy import symbols, sin
  582. >>> from sympy.codegen.ast import CodeBlock, Assignment
  583. >>> x, y, z = symbols('x y z')
  584. >>> c = CodeBlock(
  585. ... Assignment(x, 1),
  586. ... Assignment(y, sin(x) + 1),
  587. ... Assignment(z, sin(x) - 1),
  588. ... )
  589. ...
  590. >>> c.cse()
  591. CodeBlock(
  592. Assignment(x, 1),
  593. Assignment(x0, sin(x)),
  594. Assignment(y, x0 + 1),
  595. Assignment(z, x0 - 1)
  596. )
  597. """
  598. from sympy.simplify.cse_main import cse
  599. # Check that the CodeBlock only contains assignments to unique variables
  600. if not all(isinstance(i, Assignment) for i in self.args):
  601. # Will support more things later
  602. raise NotImplementedError("CodeBlock.cse only supports Assignments")
  603. if any(isinstance(i, AugmentedAssignment) for i in self.args):
  604. raise NotImplementedError("CodeBlock.cse doesn't yet work with AugmentedAssignments")
  605. for i, lhs in enumerate(self.left_hand_sides):
  606. if lhs in self.left_hand_sides[:i]:
  607. raise NotImplementedError("Duplicate assignments to the same "
  608. "variable are not yet supported (%s)" % lhs)
  609. # Ensure new symbols for subexpressions do not conflict with existing
  610. existing_symbols = self.atoms(Symbol)
  611. if symbols is None:
  612. symbols = numbered_symbols()
  613. symbols = filter_symbols(symbols, existing_symbols)
  614. replacements, reduced_exprs = cse(list(self.right_hand_sides),
  615. symbols=symbols, optimizations=optimizations, postprocess=postprocess,
  616. order=order)
  617. new_block = [Assignment(var, expr) for var, expr in
  618. zip(self.left_hand_sides, reduced_exprs)]
  619. new_assignments = [Assignment(var, expr) for var, expr in replacements]
  620. return self.topological_sort(new_assignments + new_block)
  621. class For(Token):
  622. """Represents a 'for-loop' in the code.
  623. Expressions are of the form:
  624. "for target in iter:
  625. body..."
  626. Parameters
  627. ==========
  628. target : symbol
  629. iter : iterable
  630. body : CodeBlock or iterable
  631. ! When passed an iterable it is used to instantiate a CodeBlock.
  632. Examples
  633. ========
  634. >>> from sympy import symbols, Range
  635. >>> from sympy.codegen.ast import aug_assign, For
  636. >>> x, i, j, k = symbols('x i j k')
  637. >>> for_i = For(i, Range(10), [aug_assign(x, '+', i*j*k)])
  638. >>> for_i # doctest: -NORMALIZE_WHITESPACE
  639. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  640. AddAugmentedAssignment(x, i*j*k)
  641. ))
  642. >>> for_ji = For(j, Range(7), [for_i])
  643. >>> for_ji # doctest: -NORMALIZE_WHITESPACE
  644. For(j, iterable=Range(0, 7, 1), body=CodeBlock(
  645. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  646. AddAugmentedAssignment(x, i*j*k)
  647. ))
  648. ))
  649. >>> for_kji =For(k, Range(5), [for_ji])
  650. >>> for_kji # doctest: -NORMALIZE_WHITESPACE
  651. For(k, iterable=Range(0, 5, 1), body=CodeBlock(
  652. For(j, iterable=Range(0, 7, 1), body=CodeBlock(
  653. For(i, iterable=Range(0, 10, 1), body=CodeBlock(
  654. AddAugmentedAssignment(x, i*j*k)
  655. ))
  656. ))
  657. ))
  658. """
  659. __slots__ = ('target', 'iterable', 'body')
  660. _construct_target = staticmethod(_sympify)
  661. @classmethod
  662. def _construct_body(cls, itr):
  663. if isinstance(itr, CodeBlock):
  664. return itr
  665. else:
  666. return CodeBlock(*itr)
  667. @classmethod
  668. def _construct_iterable(cls, itr):
  669. if not iterable(itr):
  670. raise TypeError("iterable must be an iterable")
  671. if isinstance(itr, list): # _sympify errors on lists because they are mutable
  672. itr = tuple(itr)
  673. return _sympify(itr)
  674. class String(Atom, Token):
  675. """ SymPy object representing a string.
  676. Atomic object which is not an expression (as opposed to Symbol).
  677. Parameters
  678. ==========
  679. text : str
  680. Examples
  681. ========
  682. >>> from sympy.codegen.ast import String
  683. >>> f = String('foo')
  684. >>> f
  685. foo
  686. >>> str(f)
  687. 'foo'
  688. >>> f.text
  689. 'foo'
  690. >>> print(repr(f))
  691. String('foo')
  692. """
  693. __slots__ = ('text',)
  694. not_in_args = ['text']
  695. is_Atom = True
  696. @classmethod
  697. def _construct_text(cls, text):
  698. if not isinstance(text, str):
  699. raise TypeError("Argument text is not a string type.")
  700. return text
  701. def _sympystr(self, printer, *args, **kwargs):
  702. return self.text
  703. def kwargs(self, exclude = (), apply = None):
  704. return {}
  705. #to be removed when Atom is given a suitable func
  706. @property
  707. def func(self):
  708. return lambda: self
  709. def _latex(self, printer):
  710. from sympy.printing.latex import latex_escape
  711. return r'\texttt{{"{}"}}'.format(latex_escape(self.text))
  712. class QuotedString(String):
  713. """ Represents a string which should be printed with quotes. """
  714. class Comment(String):
  715. """ Represents a comment. """
  716. class Node(Token):
  717. """ Subclass of Token, carrying the attribute 'attrs' (Tuple)
  718. Examples
  719. ========
  720. >>> from sympy.codegen.ast import Node, value_const, pointer_const
  721. >>> n1 = Node([value_const])
  722. >>> n1.attr_params('value_const') # get the parameters of attribute (by name)
  723. ()
  724. >>> from sympy.codegen.fnodes import dimension
  725. >>> n2 = Node([value_const, dimension(5, 3)])
  726. >>> n2.attr_params(value_const) # get the parameters of attribute (by Attribute instance)
  727. ()
  728. >>> n2.attr_params('dimension') # get the parameters of attribute (by name)
  729. (5, 3)
  730. >>> n2.attr_params(pointer_const) is None
  731. True
  732. """
  733. __slots__ = ('attrs',)
  734. defaults = {'attrs': Tuple()} # type: tDict[str, Any]
  735. _construct_attrs = staticmethod(_mk_Tuple)
  736. def attr_params(self, looking_for):
  737. """ Returns the parameters of the Attribute with name ``looking_for`` in self.attrs """
  738. for attr in self.attrs:
  739. if str(attr.name) == str(looking_for):
  740. return attr.parameters
  741. class Type(Token):
  742. """ Represents a type.
  743. Explanation
  744. ===========
  745. The naming is a super-set of NumPy naming. Type has a classmethod
  746. ``from_expr`` which offer type deduction. It also has a method
  747. ``cast_check`` which casts the argument to its type, possibly raising an
  748. exception if rounding error is not within tolerances, or if the value is not
  749. representable by the underlying data type (e.g. unsigned integers).
  750. Parameters
  751. ==========
  752. name : str
  753. Name of the type, e.g. ``object``, ``int16``, ``float16`` (where the latter two
  754. would use the ``Type`` sub-classes ``IntType`` and ``FloatType`` respectively).
  755. If a ``Type`` instance is given, the said instance is returned.
  756. Examples
  757. ========
  758. >>> from sympy.codegen.ast import Type
  759. >>> t = Type.from_expr(42)
  760. >>> t
  761. integer
  762. >>> print(repr(t))
  763. IntBaseType(String('integer'))
  764. >>> from sympy.codegen.ast import uint8
  765. >>> uint8.cast_check(-1) # doctest: +ELLIPSIS
  766. Traceback (most recent call last):
  767. ...
  768. ValueError: Minimum value for data type bigger than new value.
  769. >>> from sympy.codegen.ast import float32
  770. >>> v6 = 0.123456
  771. >>> float32.cast_check(v6)
  772. 0.123456
  773. >>> v10 = 12345.67894
  774. >>> float32.cast_check(v10) # doctest: +ELLIPSIS
  775. Traceback (most recent call last):
  776. ...
  777. ValueError: Casting gives a significantly different value.
  778. >>> boost_mp50 = Type('boost::multiprecision::cpp_dec_float_50')
  779. >>> from sympy import cxxcode
  780. >>> from sympy.codegen.ast import Declaration, Variable
  781. >>> cxxcode(Declaration(Variable('x', type=boost_mp50)))
  782. 'boost::multiprecision::cpp_dec_float_50 x'
  783. References
  784. ==========
  785. .. [1] https://docs.scipy.org/doc/numpy/user/basics.types.html
  786. """
  787. __slots__ = ('name',)
  788. _construct_name = String
  789. def _sympystr(self, printer, *args, **kwargs):
  790. return str(self.name)
  791. @classmethod
  792. def from_expr(cls, expr):
  793. """ Deduces type from an expression or a ``Symbol``.
  794. Parameters
  795. ==========
  796. expr : number or SymPy object
  797. The type will be deduced from type or properties.
  798. Examples
  799. ========
  800. >>> from sympy.codegen.ast import Type, integer, complex_
  801. >>> Type.from_expr(2) == integer
  802. True
  803. >>> from sympy import Symbol
  804. >>> Type.from_expr(Symbol('z', complex=True)) == complex_
  805. True
  806. >>> Type.from_expr(sum) # doctest: +ELLIPSIS
  807. Traceback (most recent call last):
  808. ...
  809. ValueError: Could not deduce type from expr.
  810. Raises
  811. ======
  812. ValueError when type deduction fails.
  813. """
  814. if isinstance(expr, (float, Float)):
  815. return real
  816. if isinstance(expr, (int, Integer)) or getattr(expr, 'is_integer', False):
  817. return integer
  818. if getattr(expr, 'is_real', False):
  819. return real
  820. if isinstance(expr, complex) or getattr(expr, 'is_complex', False):
  821. return complex_
  822. if isinstance(expr, bool) or getattr(expr, 'is_Relational', False):
  823. return bool_
  824. else:
  825. raise ValueError("Could not deduce type from expr.")
  826. def _check(self, value):
  827. pass
  828. def cast_check(self, value, rtol=None, atol=0, precision_targets=None):
  829. """ Casts a value to the data type of the instance.
  830. Parameters
  831. ==========
  832. value : number
  833. rtol : floating point number
  834. Relative tolerance. (will be deduced if not given).
  835. atol : floating point number
  836. Absolute tolerance (in addition to ``rtol``).
  837. type_aliases : dict
  838. Maps substitutions for Type, e.g. {integer: int64, real: float32}
  839. Examples
  840. ========
  841. >>> from sympy.codegen.ast import integer, float32, int8
  842. >>> integer.cast_check(3.0) == 3
  843. True
  844. >>> float32.cast_check(1e-40) # doctest: +ELLIPSIS
  845. Traceback (most recent call last):
  846. ...
  847. ValueError: Minimum value for data type bigger than new value.
  848. >>> int8.cast_check(256) # doctest: +ELLIPSIS
  849. Traceback (most recent call last):
  850. ...
  851. ValueError: Maximum value for data type smaller than new value.
  852. >>> v10 = 12345.67894
  853. >>> float32.cast_check(v10) # doctest: +ELLIPSIS
  854. Traceback (most recent call last):
  855. ...
  856. ValueError: Casting gives a significantly different value.
  857. >>> from sympy.codegen.ast import float64
  858. >>> float64.cast_check(v10)
  859. 12345.67894
  860. >>> from sympy import Float
  861. >>> v18 = Float('0.123456789012345646')
  862. >>> float64.cast_check(v18)
  863. Traceback (most recent call last):
  864. ...
  865. ValueError: Casting gives a significantly different value.
  866. >>> from sympy.codegen.ast import float80
  867. >>> float80.cast_check(v18)
  868. 0.123456789012345649
  869. """
  870. val = sympify(value)
  871. ten = Integer(10)
  872. exp10 = getattr(self, 'decimal_dig', None)
  873. if rtol is None:
  874. rtol = 1e-15 if exp10 is None else 2.0*ten**(-exp10)
  875. def tol(num):
  876. return atol + rtol*abs(num)
  877. new_val = self.cast_nocheck(value)
  878. self._check(new_val)
  879. delta = new_val - val
  880. if abs(delta) > tol(val): # rounding, e.g. int(3.5) != 3.5
  881. raise ValueError("Casting gives a significantly different value.")
  882. return new_val
  883. def _latex(self, printer):
  884. from sympy.printing.latex import latex_escape
  885. type_name = latex_escape(self.__class__.__name__)
  886. name = latex_escape(self.name.text)
  887. return r"\text{{{}}}\left(\texttt{{{}}}\right)".format(type_name, name)
  888. class IntBaseType(Type):
  889. """ Integer base type, contains no size information. """
  890. __slots__ = ('name',)
  891. cast_nocheck = lambda self, i: Integer(int(i))
  892. class _SizedIntType(IntBaseType):
  893. __slots__ = ('name', 'nbits',)
  894. _construct_nbits = Integer
  895. def _check(self, value):
  896. if value < self.min:
  897. raise ValueError("Value is too small: %d < %d" % (value, self.min))
  898. if value > self.max:
  899. raise ValueError("Value is too big: %d > %d" % (value, self.max))
  900. class SignedIntType(_SizedIntType):
  901. """ Represents a signed integer type. """
  902. @property
  903. def min(self):
  904. return -2**(self.nbits-1)
  905. @property
  906. def max(self):
  907. return 2**(self.nbits-1) - 1
  908. class UnsignedIntType(_SizedIntType):
  909. """ Represents an unsigned integer type. """
  910. @property
  911. def min(self):
  912. return 0
  913. @property
  914. def max(self):
  915. return 2**self.nbits - 1
  916. two = Integer(2)
  917. class FloatBaseType(Type):
  918. """ Represents a floating point number type. """
  919. cast_nocheck = Float
  920. class FloatType(FloatBaseType):
  921. """ Represents a floating point type with fixed bit width.
  922. Base 2 & one sign bit is assumed.
  923. Parameters
  924. ==========
  925. name : str
  926. Name of the type.
  927. nbits : integer
  928. Number of bits used (storage).
  929. nmant : integer
  930. Number of bits used to represent the mantissa.
  931. nexp : integer
  932. Number of bits used to represent the mantissa.
  933. Examples
  934. ========
  935. >>> from sympy import S
  936. >>> from sympy.codegen.ast import FloatType
  937. >>> half_precision = FloatType('f16', nbits=16, nmant=10, nexp=5)
  938. >>> half_precision.max
  939. 65504
  940. >>> half_precision.tiny == S(2)**-14
  941. True
  942. >>> half_precision.eps == S(2)**-10
  943. True
  944. >>> half_precision.dig == 3
  945. True
  946. >>> half_precision.decimal_dig == 5
  947. True
  948. >>> half_precision.cast_check(1.0)
  949. 1.0
  950. >>> half_precision.cast_check(1e5) # doctest: +ELLIPSIS
  951. Traceback (most recent call last):
  952. ...
  953. ValueError: Maximum value for data type smaller than new value.
  954. """
  955. __slots__ = ('name', 'nbits', 'nmant', 'nexp',)
  956. _construct_nbits = _construct_nmant = _construct_nexp = Integer
  957. @property
  958. def max_exponent(self):
  959. """ The largest positive number n, such that 2**(n - 1) is a representable finite value. """
  960. # cf. C++'s ``std::numeric_limits::max_exponent``
  961. return two**(self.nexp - 1)
  962. @property
  963. def min_exponent(self):
  964. """ The lowest negative number n, such that 2**(n - 1) is a valid normalized number. """
  965. # cf. C++'s ``std::numeric_limits::min_exponent``
  966. return 3 - self.max_exponent
  967. @property
  968. def max(self):
  969. """ Maximum value representable. """
  970. return (1 - two**-(self.nmant+1))*two**self.max_exponent
  971. @property
  972. def tiny(self):
  973. """ The minimum positive normalized value. """
  974. # See C macros: FLT_MIN, DBL_MIN, LDBL_MIN
  975. # or C++'s ``std::numeric_limits::min``
  976. # or numpy.finfo(dtype).tiny
  977. return two**(self.min_exponent - 1)
  978. @property
  979. def eps(self):
  980. """ Difference between 1.0 and the next representable value. """
  981. return two**(-self.nmant)
  982. @property
  983. def dig(self):
  984. """ Number of decimal digits that are guaranteed to be preserved in text.
  985. When converting text -> float -> text, you are guaranteed that at least ``dig``
  986. number of digits are preserved with respect to rounding or overflow.
  987. """
  988. from sympy.functions import floor, log
  989. return floor(self.nmant * log(2)/log(10))
  990. @property
  991. def decimal_dig(self):
  992. """ Number of digits needed to store & load without loss.
  993. Explanation
  994. ===========
  995. Number of decimal digits needed to guarantee that two consecutive conversions
  996. (float -> text -> float) to be idempotent. This is useful when one do not want
  997. to loose precision due to rounding errors when storing a floating point value
  998. as text.
  999. """
  1000. from sympy.functions import ceiling, log
  1001. return ceiling((self.nmant + 1) * log(2)/log(10) + 1)
  1002. def cast_nocheck(self, value):
  1003. """ Casts without checking if out of bounds or subnormal. """
  1004. if value == oo: # float(oo) or oo
  1005. return float(oo)
  1006. elif value == -oo: # float(-oo) or -oo
  1007. return float(-oo)
  1008. return Float(str(sympify(value).evalf(self.decimal_dig)), self.decimal_dig)
  1009. def _check(self, value):
  1010. if value < -self.max:
  1011. raise ValueError("Value is too small: %d < %d" % (value, -self.max))
  1012. if value > self.max:
  1013. raise ValueError("Value is too big: %d > %d" % (value, self.max))
  1014. if abs(value) < self.tiny:
  1015. raise ValueError("Smallest (absolute) value for data type bigger than new value.")
  1016. class ComplexBaseType(FloatBaseType):
  1017. def cast_nocheck(self, value):
  1018. """ Casts without checking if out of bounds or subnormal. """
  1019. from sympy.functions import re, im
  1020. return (
  1021. super().cast_nocheck(re(value)) +
  1022. super().cast_nocheck(im(value))*1j
  1023. )
  1024. def _check(self, value):
  1025. from sympy.functions import re, im
  1026. super()._check(re(value))
  1027. super()._check(im(value))
  1028. class ComplexType(ComplexBaseType, FloatType):
  1029. """ Represents a complex floating point number. """
  1030. # NumPy types:
  1031. intc = IntBaseType('intc')
  1032. intp = IntBaseType('intp')
  1033. int8 = SignedIntType('int8', 8)
  1034. int16 = SignedIntType('int16', 16)
  1035. int32 = SignedIntType('int32', 32)
  1036. int64 = SignedIntType('int64', 64)
  1037. uint8 = UnsignedIntType('uint8', 8)
  1038. uint16 = UnsignedIntType('uint16', 16)
  1039. uint32 = UnsignedIntType('uint32', 32)
  1040. uint64 = UnsignedIntType('uint64', 64)
  1041. float16 = FloatType('float16', 16, nexp=5, nmant=10) # IEEE 754 binary16, Half precision
  1042. float32 = FloatType('float32', 32, nexp=8, nmant=23) # IEEE 754 binary32, Single precision
  1043. float64 = FloatType('float64', 64, nexp=11, nmant=52) # IEEE 754 binary64, Double precision
  1044. float80 = FloatType('float80', 80, nexp=15, nmant=63) # x86 extended precision (1 integer part bit), "long double"
  1045. float128 = FloatType('float128', 128, nexp=15, nmant=112) # IEEE 754 binary128, Quadruple precision
  1046. float256 = FloatType('float256', 256, nexp=19, nmant=236) # IEEE 754 binary256, Octuple precision
  1047. complex64 = ComplexType('complex64', nbits=64, **float32.kwargs(exclude=('name', 'nbits')))
  1048. complex128 = ComplexType('complex128', nbits=128, **float64.kwargs(exclude=('name', 'nbits')))
  1049. # Generic types (precision may be chosen by code printers):
  1050. untyped = Type('untyped')
  1051. real = FloatBaseType('real')
  1052. integer = IntBaseType('integer')
  1053. complex_ = ComplexBaseType('complex')
  1054. bool_ = Type('bool')
  1055. class Attribute(Token):
  1056. """ Attribute (possibly parametrized)
  1057. For use with :class:`sympy.codegen.ast.Node` (which takes instances of
  1058. ``Attribute`` as ``attrs``).
  1059. Parameters
  1060. ==========
  1061. name : str
  1062. parameters : Tuple
  1063. Examples
  1064. ========
  1065. >>> from sympy.codegen.ast import Attribute
  1066. >>> volatile = Attribute('volatile')
  1067. >>> volatile
  1068. volatile
  1069. >>> print(repr(volatile))
  1070. Attribute(String('volatile'))
  1071. >>> a = Attribute('foo', [1, 2, 3])
  1072. >>> a
  1073. foo(1, 2, 3)
  1074. >>> a.parameters == (1, 2, 3)
  1075. True
  1076. """
  1077. __slots__ = ('name', 'parameters')
  1078. defaults = {'parameters': Tuple()}
  1079. _construct_name = String
  1080. _construct_parameters = staticmethod(_mk_Tuple)
  1081. def _sympystr(self, printer, *args, **kwargs):
  1082. result = str(self.name)
  1083. if self.parameters:
  1084. result += '(%s)' % ', '.join(map(lambda arg: printer._print(
  1085. arg, *args, **kwargs), self.parameters))
  1086. return result
  1087. value_const = Attribute('value_const')
  1088. pointer_const = Attribute('pointer_const')
  1089. class Variable(Node):
  1090. """ Represents a variable.
  1091. Parameters
  1092. ==========
  1093. symbol : Symbol
  1094. type : Type (optional)
  1095. Type of the variable.
  1096. attrs : iterable of Attribute instances
  1097. Will be stored as a Tuple.
  1098. Examples
  1099. ========
  1100. >>> from sympy import Symbol
  1101. >>> from sympy.codegen.ast import Variable, float32, integer
  1102. >>> x = Symbol('x')
  1103. >>> v = Variable(x, type=float32)
  1104. >>> v.attrs
  1105. ()
  1106. >>> v == Variable('x')
  1107. False
  1108. >>> v == Variable('x', type=float32)
  1109. True
  1110. >>> v
  1111. Variable(x, type=float32)
  1112. One may also construct a ``Variable`` instance with the type deduced from
  1113. assumptions about the symbol using the ``deduced`` classmethod:
  1114. >>> i = Symbol('i', integer=True)
  1115. >>> v = Variable.deduced(i)
  1116. >>> v.type == integer
  1117. True
  1118. >>> v == Variable('i')
  1119. False
  1120. >>> from sympy.codegen.ast import value_const
  1121. >>> value_const in v.attrs
  1122. False
  1123. >>> w = Variable('w', attrs=[value_const])
  1124. >>> w
  1125. Variable(w, attrs=(value_const,))
  1126. >>> value_const in w.attrs
  1127. True
  1128. >>> w.as_Declaration(value=42)
  1129. Declaration(Variable(w, value=42, attrs=(value_const,)))
  1130. """
  1131. __slots__ = ('symbol', 'type', 'value') + Node.__slots__
  1132. defaults = Node.defaults.copy()
  1133. defaults.update({'type': untyped, 'value': none})
  1134. _construct_symbol = staticmethod(sympify)
  1135. _construct_value = staticmethod(sympify)
  1136. @classmethod
  1137. def deduced(cls, symbol, value=None, attrs=Tuple(), cast_check=True):
  1138. """ Alt. constructor with type deduction from ``Type.from_expr``.
  1139. Deduces type primarily from ``symbol``, secondarily from ``value``.
  1140. Parameters
  1141. ==========
  1142. symbol : Symbol
  1143. value : expr
  1144. (optional) value of the variable.
  1145. attrs : iterable of Attribute instances
  1146. cast_check : bool
  1147. Whether to apply ``Type.cast_check`` on ``value``.
  1148. Examples
  1149. ========
  1150. >>> from sympy import Symbol
  1151. >>> from sympy.codegen.ast import Variable, complex_
  1152. >>> n = Symbol('n', integer=True)
  1153. >>> str(Variable.deduced(n).type)
  1154. 'integer'
  1155. >>> x = Symbol('x', real=True)
  1156. >>> v = Variable.deduced(x)
  1157. >>> v.type
  1158. real
  1159. >>> z = Symbol('z', complex=True)
  1160. >>> Variable.deduced(z).type == complex_
  1161. True
  1162. """
  1163. if isinstance(symbol, Variable):
  1164. return symbol
  1165. try:
  1166. type_ = Type.from_expr(symbol)
  1167. except ValueError:
  1168. type_ = Type.from_expr(value)
  1169. if value is not None and cast_check:
  1170. value = type_.cast_check(value)
  1171. return cls(symbol, type=type_, value=value, attrs=attrs)
  1172. def as_Declaration(self, **kwargs):
  1173. """ Convenience method for creating a Declaration instance.
  1174. Explanation
  1175. ===========
  1176. If the variable of the Declaration need to wrap a modified
  1177. variable keyword arguments may be passed (overriding e.g.
  1178. the ``value`` of the Variable instance).
  1179. Examples
  1180. ========
  1181. >>> from sympy.codegen.ast import Variable, NoneToken
  1182. >>> x = Variable('x')
  1183. >>> decl1 = x.as_Declaration()
  1184. >>> # value is special NoneToken() which must be tested with == operator
  1185. >>> decl1.variable.value is None # won't work
  1186. False
  1187. >>> decl1.variable.value == None # not PEP-8 compliant
  1188. True
  1189. >>> decl1.variable.value == NoneToken() # OK
  1190. True
  1191. >>> decl2 = x.as_Declaration(value=42.0)
  1192. >>> decl2.variable.value == 42
  1193. True
  1194. """
  1195. kw = self.kwargs()
  1196. kw.update(kwargs)
  1197. return Declaration(self.func(**kw))
  1198. def _relation(self, rhs, op):
  1199. try:
  1200. rhs = _sympify(rhs)
  1201. except SympifyError:
  1202. raise TypeError("Invalid comparison %s < %s" % (self, rhs))
  1203. return op(self, rhs, evaluate=False)
  1204. __lt__ = lambda self, other: self._relation(other, Lt)
  1205. __le__ = lambda self, other: self._relation(other, Le)
  1206. __ge__ = lambda self, other: self._relation(other, Ge)
  1207. __gt__ = lambda self, other: self._relation(other, Gt)
  1208. class Pointer(Variable):
  1209. """ Represents a pointer. See ``Variable``.
  1210. Examples
  1211. ========
  1212. Can create instances of ``Element``:
  1213. >>> from sympy import Symbol
  1214. >>> from sympy.codegen.ast import Pointer
  1215. >>> i = Symbol('i', integer=True)
  1216. >>> p = Pointer('x')
  1217. >>> p[i+1]
  1218. Element(x, indices=(i + 1,))
  1219. """
  1220. def __getitem__(self, key):
  1221. try:
  1222. return Element(self.symbol, key)
  1223. except TypeError:
  1224. return Element(self.symbol, (key,))
  1225. class Element(Token):
  1226. """ Element in (a possibly N-dimensional) array.
  1227. Examples
  1228. ========
  1229. >>> from sympy.codegen.ast import Element
  1230. >>> elem = Element('x', 'ijk')
  1231. >>> elem.symbol.name == 'x'
  1232. True
  1233. >>> elem.indices
  1234. (i, j, k)
  1235. >>> from sympy import ccode
  1236. >>> ccode(elem)
  1237. 'x[i][j][k]'
  1238. >>> ccode(Element('x', 'ijk', strides='lmn', offset='o'))
  1239. 'x[i*l + j*m + k*n + o]'
  1240. """
  1241. __slots__ = ('symbol', 'indices', 'strides', 'offset')
  1242. defaults = {'strides': none, 'offset': none}
  1243. _construct_symbol = staticmethod(sympify)
  1244. _construct_indices = staticmethod(lambda arg: Tuple(*arg))
  1245. _construct_strides = staticmethod(lambda arg: Tuple(*arg))
  1246. _construct_offset = staticmethod(sympify)
  1247. class Declaration(Token):
  1248. """ Represents a variable declaration
  1249. Parameters
  1250. ==========
  1251. variable : Variable
  1252. Examples
  1253. ========
  1254. >>> from sympy.codegen.ast import Declaration, NoneToken, untyped
  1255. >>> z = Declaration('z')
  1256. >>> z.variable.type == untyped
  1257. True
  1258. >>> # value is special NoneToken() which must be tested with == operator
  1259. >>> z.variable.value is None # won't work
  1260. False
  1261. >>> z.variable.value == None # not PEP-8 compliant
  1262. True
  1263. >>> z.variable.value == NoneToken() # OK
  1264. True
  1265. """
  1266. __slots__ = ('variable',)
  1267. _construct_variable = Variable
  1268. class While(Token):
  1269. """ Represents a 'for-loop' in the code.
  1270. Expressions are of the form:
  1271. "while condition:
  1272. body..."
  1273. Parameters
  1274. ==========
  1275. condition : expression convertible to Boolean
  1276. body : CodeBlock or iterable
  1277. When passed an iterable it is used to instantiate a CodeBlock.
  1278. Examples
  1279. ========
  1280. >>> from sympy import symbols, Gt, Abs
  1281. >>> from sympy.codegen import aug_assign, Assignment, While
  1282. >>> x, dx = symbols('x dx')
  1283. >>> expr = 1 - x**2
  1284. >>> whl = While(Gt(Abs(dx), 1e-9), [
  1285. ... Assignment(dx, -expr/expr.diff(x)),
  1286. ... aug_assign(x, '+', dx)
  1287. ... ])
  1288. """
  1289. __slots__ = ('condition', 'body')
  1290. _construct_condition = staticmethod(lambda cond: _sympify(cond))
  1291. @classmethod
  1292. def _construct_body(cls, itr):
  1293. if isinstance(itr, CodeBlock):
  1294. return itr
  1295. else:
  1296. return CodeBlock(*itr)
  1297. class Scope(Token):
  1298. """ Represents a scope in the code.
  1299. Parameters
  1300. ==========
  1301. body : CodeBlock or iterable
  1302. When passed an iterable it is used to instantiate a CodeBlock.
  1303. """
  1304. __slots__ = ('body',)
  1305. @classmethod
  1306. def _construct_body(cls, itr):
  1307. if isinstance(itr, CodeBlock):
  1308. return itr
  1309. else:
  1310. return CodeBlock(*itr)
  1311. class Stream(Token):
  1312. """ Represents a stream.
  1313. There are two predefined Stream instances ``stdout`` & ``stderr``.
  1314. Parameters
  1315. ==========
  1316. name : str
  1317. Examples
  1318. ========
  1319. >>> from sympy import pycode, Symbol
  1320. >>> from sympy.codegen.ast import Print, stderr, QuotedString
  1321. >>> print(pycode(Print(['x'], file=stderr)))
  1322. print(x, file=sys.stderr)
  1323. >>> x = Symbol('x')
  1324. >>> print(pycode(Print([QuotedString('x')], file=stderr))) # print literally "x"
  1325. print("x", file=sys.stderr)
  1326. """
  1327. __slots__ = ('name',)
  1328. _construct_name = String
  1329. stdout = Stream('stdout')
  1330. stderr = Stream('stderr')
  1331. class Print(Token):
  1332. """ Represents print command in the code.
  1333. Parameters
  1334. ==========
  1335. formatstring : str
  1336. *args : Basic instances (or convertible to such through sympify)
  1337. Examples
  1338. ========
  1339. >>> from sympy.codegen.ast import Print
  1340. >>> from sympy import pycode
  1341. >>> print(pycode(Print('x y'.split(), "coordinate: %12.5g %12.5g")))
  1342. print("coordinate: %12.5g %12.5g" % (x, y))
  1343. """
  1344. __slots__ = ('print_args', 'format_string', 'file')
  1345. defaults = {'format_string': none, 'file': none}
  1346. _construct_print_args = staticmethod(_mk_Tuple)
  1347. _construct_format_string = QuotedString
  1348. _construct_file = Stream
  1349. class FunctionPrototype(Node):
  1350. """ Represents a function prototype
  1351. Allows the user to generate forward declaration in e.g. C/C++.
  1352. Parameters
  1353. ==========
  1354. return_type : Type
  1355. name : str
  1356. parameters: iterable of Variable instances
  1357. attrs : iterable of Attribute instances
  1358. Examples
  1359. ========
  1360. >>> from sympy import ccode, symbols
  1361. >>> from sympy.codegen.ast import real, FunctionPrototype
  1362. >>> x, y = symbols('x y', real=True)
  1363. >>> fp = FunctionPrototype(real, 'foo', [x, y])
  1364. >>> ccode(fp)
  1365. 'double foo(double x, double y)'
  1366. """
  1367. __slots__ = ('return_type', 'name', 'parameters', 'attrs')
  1368. _construct_return_type = Type
  1369. _construct_name = String
  1370. @staticmethod
  1371. def _construct_parameters(args):
  1372. def _var(arg):
  1373. if isinstance(arg, Declaration):
  1374. return arg.variable
  1375. elif isinstance(arg, Variable):
  1376. return arg
  1377. else:
  1378. return Variable.deduced(arg)
  1379. return Tuple(*map(_var, args))
  1380. @classmethod
  1381. def from_FunctionDefinition(cls, func_def):
  1382. if not isinstance(func_def, FunctionDefinition):
  1383. raise TypeError("func_def is not an instance of FunctionDefiniton")
  1384. return cls(**func_def.kwargs(exclude=('body',)))
  1385. class FunctionDefinition(FunctionPrototype):
  1386. """ Represents a function definition in the code.
  1387. Parameters
  1388. ==========
  1389. return_type : Type
  1390. name : str
  1391. parameters: iterable of Variable instances
  1392. body : CodeBlock or iterable
  1393. attrs : iterable of Attribute instances
  1394. Examples
  1395. ========
  1396. >>> from sympy import ccode, symbols
  1397. >>> from sympy.codegen.ast import real, FunctionPrototype
  1398. >>> x, y = symbols('x y', real=True)
  1399. >>> fp = FunctionPrototype(real, 'foo', [x, y])
  1400. >>> ccode(fp)
  1401. 'double foo(double x, double y)'
  1402. >>> from sympy.codegen.ast import FunctionDefinition, Return
  1403. >>> body = [Return(x*y)]
  1404. >>> fd = FunctionDefinition.from_FunctionPrototype(fp, body)
  1405. >>> print(ccode(fd))
  1406. double foo(double x, double y){
  1407. return x*y;
  1408. }
  1409. """
  1410. __slots__ = FunctionPrototype.__slots__[:-1] + ('body', 'attrs')
  1411. @classmethod
  1412. def _construct_body(cls, itr):
  1413. if isinstance(itr, CodeBlock):
  1414. return itr
  1415. else:
  1416. return CodeBlock(*itr)
  1417. @classmethod
  1418. def from_FunctionPrototype(cls, func_proto, body):
  1419. if not isinstance(func_proto, FunctionPrototype):
  1420. raise TypeError("func_proto is not an instance of FunctionPrototype")
  1421. return cls(body=body, **func_proto.kwargs())
  1422. class Return(Token):
  1423. """ Represents a return command in the code.
  1424. Parameters
  1425. ==========
  1426. return : Basic
  1427. Examples
  1428. ========
  1429. >>> from sympy.codegen.ast import Return
  1430. >>> from sympy.printing.pycode import pycode
  1431. >>> from sympy import Symbol
  1432. >>> x = Symbol('x')
  1433. >>> print(pycode(Return(x)))
  1434. return x
  1435. """
  1436. __slots__ = ('return',)
  1437. _construct_return=staticmethod(_sympify)
  1438. class FunctionCall(Token, Expr):
  1439. """ Represents a call to a function in the code.
  1440. Parameters
  1441. ==========
  1442. name : str
  1443. function_args : Tuple
  1444. Examples
  1445. ========
  1446. >>> from sympy.codegen.ast import FunctionCall
  1447. >>> from sympy import pycode
  1448. >>> fcall = FunctionCall('foo', 'bar baz'.split())
  1449. >>> print(pycode(fcall))
  1450. foo(bar, baz)
  1451. """
  1452. __slots__ = ('name', 'function_args')
  1453. _construct_name = String
  1454. _construct_function_args = staticmethod(lambda args: Tuple(*args))