fnodes.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650
  1. """
  2. AST nodes specific to Fortran.
  3. The functions defined in this module allow the user to express functions such as ``dsign``
  4. as a SymPy function for symbolic manipulation.
  5. """
  6. from sympy.codegen.ast import (
  7. Attribute, CodeBlock, FunctionCall, Node, none, String,
  8. Token, _mk_Tuple, Variable
  9. )
  10. from sympy.core.basic import Basic
  11. from sympy.core.containers import Tuple
  12. from sympy.core.expr import Expr
  13. from sympy.core.function import Function
  14. from sympy.core.numbers import Float, Integer
  15. from sympy.core.sympify import sympify
  16. from sympy.logic import true, false
  17. from sympy.utilities.iterables import iterable
  18. pure = Attribute('pure')
  19. elemental = Attribute('elemental') # (all elemental procedures are also pure)
  20. intent_in = Attribute('intent_in')
  21. intent_out = Attribute('intent_out')
  22. intent_inout = Attribute('intent_inout')
  23. allocatable = Attribute('allocatable')
  24. class Program(Token):
  25. """ Represents a 'program' block in Fortran.
  26. Examples
  27. ========
  28. >>> from sympy.codegen.ast import Print
  29. >>> from sympy.codegen.fnodes import Program
  30. >>> prog = Program('myprogram', [Print([42])])
  31. >>> from sympy import fcode
  32. >>> print(fcode(prog, source_format='free'))
  33. program myprogram
  34. print *, 42
  35. end program
  36. """
  37. __slots__ = ('name', 'body')
  38. _construct_name = String
  39. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  40. class use_rename(Token):
  41. """ Represents a renaming in a use statement in Fortran.
  42. Examples
  43. ========
  44. >>> from sympy.codegen.fnodes import use_rename, use
  45. >>> from sympy import fcode
  46. >>> ren = use_rename("thingy", "convolution2d")
  47. >>> print(fcode(ren, source_format='free'))
  48. thingy => convolution2d
  49. >>> full = use('signallib', only=['snr', ren])
  50. >>> print(fcode(full, source_format='free'))
  51. use signallib, only: snr, thingy => convolution2d
  52. """
  53. __slots__ = ('local', 'original')
  54. _construct_local = String
  55. _construct_original = String
  56. def _name(arg):
  57. if hasattr(arg, 'name'):
  58. return arg.name
  59. else:
  60. return String(arg)
  61. class use(Token):
  62. """ Represents a use statement in Fortran.
  63. Examples
  64. ========
  65. >>> from sympy.codegen.fnodes import use
  66. >>> from sympy import fcode
  67. >>> fcode(use('signallib'), source_format='free')
  68. 'use signallib'
  69. >>> fcode(use('signallib', [('metric', 'snr')]), source_format='free')
  70. 'use signallib, metric => snr'
  71. >>> fcode(use('signallib', only=['snr', 'convolution2d']), source_format='free')
  72. 'use signallib, only: snr, convolution2d'
  73. """
  74. __slots__ = ('namespace', 'rename', 'only')
  75. defaults = {'rename': none, 'only': none}
  76. _construct_namespace = staticmethod(_name)
  77. _construct_rename = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else use_rename(*arg) for arg in args]))
  78. _construct_only = staticmethod(lambda args: Tuple(*[arg if isinstance(arg, use_rename) else _name(arg) for arg in args]))
  79. class Module(Token):
  80. """ Represents a module in Fortran.
  81. Examples
  82. ========
  83. >>> from sympy.codegen.fnodes import Module
  84. >>> from sympy import fcode
  85. >>> print(fcode(Module('signallib', ['implicit none'], []), source_format='free'))
  86. module signallib
  87. implicit none
  88. <BLANKLINE>
  89. contains
  90. <BLANKLINE>
  91. <BLANKLINE>
  92. end module
  93. """
  94. __slots__ = ('name', 'declarations', 'definitions')
  95. defaults = {'declarations': Tuple()}
  96. _construct_name = String
  97. _construct_declarations = staticmethod(lambda arg: CodeBlock(*arg))
  98. _construct_definitions = staticmethod(lambda arg: CodeBlock(*arg))
  99. class Subroutine(Node):
  100. """ Represents a subroutine in Fortran.
  101. Examples
  102. ========
  103. >>> from sympy import fcode, symbols
  104. >>> from sympy.codegen.ast import Print
  105. >>> from sympy.codegen.fnodes import Subroutine
  106. >>> x, y = symbols('x y', real=True)
  107. >>> sub = Subroutine('mysub', [x, y], [Print([x**2 + y**2, x*y])])
  108. >>> print(fcode(sub, source_format='free', standard=2003))
  109. subroutine mysub(x, y)
  110. real*8 :: x
  111. real*8 :: y
  112. print *, x**2 + y**2, x*y
  113. end subroutine
  114. """
  115. __slots__ = ('name', 'parameters', 'body', 'attrs')
  116. _construct_name = String
  117. _construct_parameters = staticmethod(lambda params: Tuple(*map(Variable.deduced, params)))
  118. @classmethod
  119. def _construct_body(cls, itr):
  120. if isinstance(itr, CodeBlock):
  121. return itr
  122. else:
  123. return CodeBlock(*itr)
  124. class SubroutineCall(Token):
  125. """ Represents a call to a subroutine in Fortran.
  126. Examples
  127. ========
  128. >>> from sympy.codegen.fnodes import SubroutineCall
  129. >>> from sympy import fcode
  130. >>> fcode(SubroutineCall('mysub', 'x y'.split()))
  131. ' call mysub(x, y)'
  132. """
  133. __slots__ = ('name', 'subroutine_args')
  134. _construct_name = staticmethod(_name)
  135. _construct_subroutine_args = staticmethod(_mk_Tuple)
  136. class Do(Token):
  137. """ Represents a Do loop in in Fortran.
  138. Examples
  139. ========
  140. >>> from sympy import fcode, symbols
  141. >>> from sympy.codegen.ast import aug_assign, Print
  142. >>> from sympy.codegen.fnodes import Do
  143. >>> i, n = symbols('i n', integer=True)
  144. >>> r = symbols('r', real=True)
  145. >>> body = [aug_assign(r, '+', 1/i), Print([i, r])]
  146. >>> do1 = Do(body, i, 1, n)
  147. >>> print(fcode(do1, source_format='free'))
  148. do i = 1, n
  149. r = r + 1d0/i
  150. print *, i, r
  151. end do
  152. >>> do2 = Do(body, i, 1, n, 2)
  153. >>> print(fcode(do2, source_format='free'))
  154. do i = 1, n, 2
  155. r = r + 1d0/i
  156. print *, i, r
  157. end do
  158. """
  159. __slots__ = ('body', 'counter', 'first', 'last', 'step', 'concurrent')
  160. defaults = {'step': Integer(1), 'concurrent': false}
  161. _construct_body = staticmethod(lambda body: CodeBlock(*body))
  162. _construct_counter = staticmethod(sympify)
  163. _construct_first = staticmethod(sympify)
  164. _construct_last = staticmethod(sympify)
  165. _construct_step = staticmethod(sympify)
  166. _construct_concurrent = staticmethod(lambda arg: true if arg else false)
  167. class ArrayConstructor(Token):
  168. """ Represents an array constructor.
  169. Examples
  170. ========
  171. >>> from sympy import fcode
  172. >>> from sympy.codegen.fnodes import ArrayConstructor
  173. >>> ac = ArrayConstructor([1, 2, 3])
  174. >>> fcode(ac, standard=95, source_format='free')
  175. '(/1, 2, 3/)'
  176. >>> fcode(ac, standard=2003, source_format='free')
  177. '[1, 2, 3]'
  178. """
  179. __slots__ = ('elements',)
  180. _construct_elements = staticmethod(_mk_Tuple)
  181. class ImpliedDoLoop(Token):
  182. """ Represents an implied do loop in Fortran.
  183. Examples
  184. ========
  185. >>> from sympy import Symbol, fcode
  186. >>> from sympy.codegen.fnodes import ImpliedDoLoop, ArrayConstructor
  187. >>> i = Symbol('i', integer=True)
  188. >>> idl = ImpliedDoLoop(i**3, i, -3, 3, 2) # -27, -1, 1, 27
  189. >>> ac = ArrayConstructor([-28, idl, 28]) # -28, -27, -1, 1, 27, 28
  190. >>> fcode(ac, standard=2003, source_format='free')
  191. '[-28, (i**3, i = -3, 3, 2), 28]'
  192. """
  193. __slots__ = ('expr', 'counter', 'first', 'last', 'step')
  194. defaults = {'step': Integer(1)}
  195. _construct_expr = staticmethod(sympify)
  196. _construct_counter = staticmethod(sympify)
  197. _construct_first = staticmethod(sympify)
  198. _construct_last = staticmethod(sympify)
  199. _construct_step = staticmethod(sympify)
  200. class Extent(Basic):
  201. """ Represents a dimension extent.
  202. Examples
  203. ========
  204. >>> from sympy.codegen.fnodes import Extent
  205. >>> e = Extent(-3, 3) # -3, -2, -1, 0, 1, 2, 3
  206. >>> from sympy import fcode
  207. >>> fcode(e, source_format='free')
  208. '-3:3'
  209. >>> from sympy.codegen.ast import Variable, real
  210. >>> from sympy.codegen.fnodes import dimension, intent_out
  211. >>> dim = dimension(e, e)
  212. >>> arr = Variable('x', real, attrs=[dim, intent_out])
  213. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  214. 'real*8, dimension(-3:3, -3:3), intent(out) :: x'
  215. """
  216. def __new__(cls, *args):
  217. if len(args) == 2:
  218. low, high = args
  219. return Basic.__new__(cls, sympify(low), sympify(high))
  220. elif len(args) == 0 or (len(args) == 1 and args[0] in (':', None)):
  221. return Basic.__new__(cls) # assumed shape
  222. else:
  223. raise ValueError("Expected 0 or 2 args (or one argument == None or ':')")
  224. def _sympystr(self, printer):
  225. if len(self.args) == 0:
  226. return ':'
  227. return ":".join(str(arg) for arg in self.args)
  228. assumed_extent = Extent() # or Extent(':'), Extent(None)
  229. def dimension(*args):
  230. """ Creates a 'dimension' Attribute with (up to 7) extents.
  231. Examples
  232. ========
  233. >>> from sympy import fcode
  234. >>> from sympy.codegen.fnodes import dimension, intent_in
  235. >>> dim = dimension('2', ':') # 2 rows, runtime determined number of columns
  236. >>> from sympy.codegen.ast import Variable, integer
  237. >>> arr = Variable('a', integer, attrs=[dim, intent_in])
  238. >>> fcode(arr.as_Declaration(), source_format='free', standard=2003)
  239. 'integer*4, dimension(2, :), intent(in) :: a'
  240. """
  241. if len(args) > 7:
  242. raise ValueError("Fortran only supports up to 7 dimensional arrays")
  243. parameters = []
  244. for arg in args:
  245. if isinstance(arg, Extent):
  246. parameters.append(arg)
  247. elif isinstance(arg, str):
  248. if arg == ':':
  249. parameters.append(Extent())
  250. else:
  251. parameters.append(String(arg))
  252. elif iterable(arg):
  253. parameters.append(Extent(*arg))
  254. else:
  255. parameters.append(sympify(arg))
  256. if len(args) == 0:
  257. raise ValueError("Need at least one dimension")
  258. return Attribute('dimension', parameters)
  259. assumed_size = dimension('*')
  260. def array(symbol, dim, intent=None, *, attrs=(), value=None, type=None):
  261. """ Convenience function for creating a Variable instance for a Fortran array.
  262. Parameters
  263. ==========
  264. symbol : symbol
  265. dim : Attribute or iterable
  266. If dim is an ``Attribute`` it need to have the name 'dimension'. If it is
  267. not an ``Attribute``, then it is passsed to :func:`dimension` as ``*dim``
  268. intent : str
  269. One of: 'in', 'out', 'inout' or None
  270. \\*\\*kwargs:
  271. Keyword arguments for ``Variable`` ('type' & 'value')
  272. Examples
  273. ========
  274. >>> from sympy import fcode
  275. >>> from sympy.codegen.ast import integer, real
  276. >>> from sympy.codegen.fnodes import array
  277. >>> arr = array('a', '*', 'in', type=integer)
  278. >>> print(fcode(arr.as_Declaration(), source_format='free', standard=2003))
  279. integer*4, dimension(*), intent(in) :: a
  280. >>> x = array('x', [3, ':', ':'], intent='out', type=real)
  281. >>> print(fcode(x.as_Declaration(value=1), source_format='free', standard=2003))
  282. real*8, dimension(3, :, :), intent(out) :: x = 1
  283. """
  284. if isinstance(dim, Attribute):
  285. if str(dim.name) != 'dimension':
  286. raise ValueError("Got an unexpected Attribute argument as dim: %s" % str(dim))
  287. else:
  288. dim = dimension(*dim)
  289. attrs = list(attrs) + [dim]
  290. if intent is not None:
  291. if intent not in (intent_in, intent_out, intent_inout):
  292. intent = {'in': intent_in, 'out': intent_out, 'inout': intent_inout}[intent]
  293. attrs.append(intent)
  294. if type is None:
  295. return Variable.deduced(symbol, value=value, attrs=attrs)
  296. else:
  297. return Variable(symbol, type, value=value, attrs=attrs)
  298. def _printable(arg):
  299. return String(arg) if isinstance(arg, str) else sympify(arg)
  300. def allocated(array):
  301. """ Creates an AST node for a function call to Fortran's "allocated(...)"
  302. Examples
  303. ========
  304. >>> from sympy import fcode
  305. >>> from sympy.codegen.fnodes import allocated
  306. >>> alloc = allocated('x')
  307. >>> fcode(alloc, source_format='free')
  308. 'allocated(x)'
  309. """
  310. return FunctionCall('allocated', [_printable(array)])
  311. def lbound(array, dim=None, kind=None):
  312. """ Creates an AST node for a function call to Fortran's "lbound(...)"
  313. Parameters
  314. ==========
  315. array : Symbol or String
  316. dim : expr
  317. kind : expr
  318. Examples
  319. ========
  320. >>> from sympy import fcode
  321. >>> from sympy.codegen.fnodes import lbound
  322. >>> lb = lbound('arr', dim=2)
  323. >>> fcode(lb, source_format='free')
  324. 'lbound(arr, 2)'
  325. """
  326. return FunctionCall(
  327. 'lbound',
  328. [_printable(array)] +
  329. ([_printable(dim)] if dim else []) +
  330. ([_printable(kind)] if kind else [])
  331. )
  332. def ubound(array, dim=None, kind=None):
  333. return FunctionCall(
  334. 'ubound',
  335. [_printable(array)] +
  336. ([_printable(dim)] if dim else []) +
  337. ([_printable(kind)] if kind else [])
  338. )
  339. def shape(source, kind=None):
  340. """ Creates an AST node for a function call to Fortran's "shape(...)"
  341. Parameters
  342. ==========
  343. source : Symbol or String
  344. kind : expr
  345. Examples
  346. ========
  347. >>> from sympy import fcode
  348. >>> from sympy.codegen.fnodes import shape
  349. >>> shp = shape('x')
  350. >>> fcode(shp, source_format='free')
  351. 'shape(x)'
  352. """
  353. return FunctionCall(
  354. 'shape',
  355. [_printable(source)] +
  356. ([_printable(kind)] if kind else [])
  357. )
  358. def size(array, dim=None, kind=None):
  359. """ Creates an AST node for a function call to Fortran's "size(...)"
  360. Examples
  361. ========
  362. >>> from sympy import fcode, Symbol
  363. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  364. >>> from sympy.codegen.fnodes import array, sum_, size
  365. >>> a = Symbol('a', real=True)
  366. >>> body = [Return((sum_(a**2)/size(a))**.5)]
  367. >>> arr = array(a, dim=[':'], intent='in')
  368. >>> fd = FunctionDefinition(real, 'rms', [arr], body)
  369. >>> print(fcode(fd, source_format='free', standard=2003))
  370. real*8 function rms(a)
  371. real*8, dimension(:), intent(in) :: a
  372. rms = sqrt(sum(a**2)*1d0/size(a))
  373. end function
  374. """
  375. return FunctionCall(
  376. 'size',
  377. [_printable(array)] +
  378. ([_printable(dim)] if dim else []) +
  379. ([_printable(kind)] if kind else [])
  380. )
  381. def reshape(source, shape, pad=None, order=None):
  382. """ Creates an AST node for a function call to Fortran's "reshape(...)"
  383. Parameters
  384. ==========
  385. source : Symbol or String
  386. shape : ArrayExpr
  387. """
  388. return FunctionCall(
  389. 'reshape',
  390. [_printable(source), _printable(shape)] +
  391. ([_printable(pad)] if pad else []) +
  392. ([_printable(order)] if pad else [])
  393. )
  394. def bind_C(name=None):
  395. """ Creates an Attribute ``bind_C`` with a name.
  396. Parameters
  397. ==========
  398. name : str
  399. Examples
  400. ========
  401. >>> from sympy import fcode, Symbol
  402. >>> from sympy.codegen.ast import FunctionDefinition, real, Return
  403. >>> from sympy.codegen.fnodes import array, sum_, bind_C
  404. >>> a = Symbol('a', real=True)
  405. >>> s = Symbol('s', integer=True)
  406. >>> arr = array(a, dim=[s], intent='in')
  407. >>> body = [Return((sum_(a**2)/s)**.5)]
  408. >>> fd = FunctionDefinition(real, 'rms', [arr, s], body, attrs=[bind_C('rms')])
  409. >>> print(fcode(fd, source_format='free', standard=2003))
  410. real*8 function rms(a, s) bind(C, name="rms")
  411. real*8, dimension(s), intent(in) :: a
  412. integer*4 :: s
  413. rms = sqrt(sum(a**2)/s)
  414. end function
  415. """
  416. return Attribute('bind_C', [String(name)] if name else [])
  417. class GoTo(Token):
  418. """ Represents a goto statement in Fortran
  419. Examples
  420. ========
  421. >>> from sympy.codegen.fnodes import GoTo
  422. >>> go = GoTo([10, 20, 30], 'i')
  423. >>> from sympy import fcode
  424. >>> fcode(go, source_format='free')
  425. 'go to (10, 20, 30), i'
  426. """
  427. __slots__ = ('labels', 'expr')
  428. defaults = {'expr': none}
  429. _construct_labels = staticmethod(_mk_Tuple)
  430. _construct_expr = staticmethod(sympify)
  431. class FortranReturn(Token):
  432. """ AST node explicitly mapped to a fortran "return".
  433. Explanation
  434. ===========
  435. Because a return statement in fortran is different from C, and
  436. in order to aid reuse of our codegen ASTs the ordinary
  437. ``.codegen.ast.Return`` is interpreted as assignment to
  438. the result variable of the function. If one for some reason needs
  439. to generate a fortran RETURN statement, this node should be used.
  440. Examples
  441. ========
  442. >>> from sympy.codegen.fnodes import FortranReturn
  443. >>> from sympy import fcode
  444. >>> fcode(FortranReturn('x'))
  445. ' return x'
  446. """
  447. __slots__ = ('return_value',)
  448. defaults = {'return_value': none}
  449. _construct_return_value = staticmethod(sympify)
  450. class FFunction(Function):
  451. _required_standard = 77
  452. def _fcode(self, printer):
  453. name = self.__class__.__name__
  454. if printer._settings['standard'] < self._required_standard:
  455. raise NotImplementedError("%s requires Fortran %d or newer" %
  456. (name, self._required_standard))
  457. return '{}({})'.format(name, ', '.join(map(printer._print, self.args)))
  458. class F95Function(FFunction):
  459. _required_standard = 95
  460. class isign(FFunction):
  461. """ Fortran sign intrinsic for integer arguments. """
  462. nargs = 2
  463. class dsign(FFunction):
  464. """ Fortran sign intrinsic for double precision arguments. """
  465. nargs = 2
  466. class cmplx(FFunction):
  467. """ Fortran complex conversion function. """
  468. nargs = 2 # may be extended to (2, 3) at a later point
  469. class kind(FFunction):
  470. """ Fortran kind function. """
  471. nargs = 1
  472. class merge(F95Function):
  473. """ Fortran merge function """
  474. nargs = 3
  475. class _literal(Float):
  476. _token = None # type: str
  477. _decimals = None # type: int
  478. def _fcode(self, printer, *args, **kwargs):
  479. mantissa, sgnd_ex = ('%.{}e'.format(self._decimals) % self).split('e')
  480. mantissa = mantissa.strip('0').rstrip('.')
  481. ex_sgn, ex_num = sgnd_ex[0], sgnd_ex[1:].lstrip('0')
  482. ex_sgn = '' if ex_sgn == '+' else ex_sgn
  483. return (mantissa or '0') + self._token + ex_sgn + (ex_num or '0')
  484. class literal_sp(_literal):
  485. """ Fortran single precision real literal """
  486. _token = 'e'
  487. _decimals = 9
  488. class literal_dp(_literal):
  489. """ Fortran double precision real literal """
  490. _token = 'd'
  491. _decimals = 17
  492. class sum_(Token, Expr):
  493. __slots__ = ('array', 'dim', 'mask')
  494. defaults = {'dim': none, 'mask': none}
  495. _construct_array = staticmethod(sympify)
  496. _construct_dim = staticmethod(sympify)
  497. class product_(Token, Expr):
  498. __slots__ = ('array', 'dim', 'mask')
  499. defaults = {'dim': none, 'mask': none}
  500. _construct_array = staticmethod(sympify)
  501. _construct_dim = staticmethod(sympify)