12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748749750751752753754755756757758759760761762763764765766767768769770771772773774775 |
- """
- Fortran code printer
- The FCodePrinter converts single SymPy expressions into single Fortran
- expressions, using the functions defined in the Fortran 77 standard where
- possible. Some useful pointers to Fortran can be found on wikipedia:
- https://en.wikipedia.org/wiki/Fortran
- Most of the code below is based on the "Professional Programmer\'s Guide to
- Fortran77" by Clive G. Page:
- http://www.star.le.ac.uk/~cgp/prof77.html
- Fortran is a case-insensitive language. This might cause trouble because
- SymPy is case sensitive. So, fcode adds underscores to variable names when
- it is necessary to make them different for Fortran.
- """
- from typing import Dict as tDict, Any
- from collections import defaultdict
- from itertools import chain
- import string
- from sympy.codegen.ast import (
- Assignment, Declaration, Pointer, value_const,
- float32, float64, float80, complex64, complex128, int8, int16, int32,
- int64, intc, real, integer, bool_, complex_
- )
- from sympy.codegen.fnodes import (
- allocatable, isign, dsign, cmplx, merge, literal_dp, elemental, pure,
- intent_in, intent_out, intent_inout
- )
- from sympy.core import S, Add, N, Float, Symbol
- from sympy.core.function import Function
- from sympy.core.relational import Eq
- from sympy.sets import Range
- from sympy.printing.codeprinter import CodePrinter
- from sympy.printing.precedence import precedence, PRECEDENCE
- from sympy.printing.printer import printer_context
- # These are defined in the other file so we can avoid importing sympy.codegen
- # from the top-level 'import sympy'. Export them here as well.
- from sympy.printing.codeprinter import fcode, print_fcode # noqa:F401
# Mapping from SymPy function names to the Fortran intrinsic used to print
# them.  Most intrinsics share SymPy's (lower-case) name, so those are built
# from a single list; the remaining entries need an explicit Fortran spelling.
known_functions = {
    sympy_name: sympy_name
    for sympy_name in (
        "sin", "cos", "tan",
        "asin", "acos", "atan", "atan2",
        "sinh", "cosh", "tanh",
        "log", "exp", "erf",
    )
}
known_functions.update({
    "Abs": "abs",
    "conjugate": "conjg",
    "Max": "max",
    "Min": "min",
})
class FCodePrinter(CodePrinter):
    """A printer to convert SymPy expressions to strings of Fortran code.

    The dialect is selected with the ``standard`` setting (66, 77, 90, 95,
    2003 or 2008) and the layout with ``source_format`` ('fixed' or 'free').
    Because Fortran is case-insensitive, symbols whose names differ only in
    case are disambiguated by appending underscores when ``name_mangling``
    is enabled.
    """
    printmethod = "_fcode"
    language = "Fortran"

    # Generic codegen types are resolved to these concrete types before printing.
    type_aliases = {
        integer: int32,
        real: float64,
        complex_: complex128,
    }

    # Fortran spelling of each concrete codegen type.
    type_mappings = {
        intc: 'integer(c_int)',
        float32: 'real*4',  # real(kind(0.e0))
        float64: 'real*8',  # real(kind(0.d0))
        float80: 'real*10',  # real(kind(????))
        complex64: 'complex*8',
        complex128: 'complex*16',
        int8: 'integer*1',
        int16: 'integer*2',
        int32: 'integer*4',
        int64: 'integer*8',
        bool_: 'logical'
    }

    # Types whose use requires a module import: {type: {module: name}}.
    type_modules = {
        intc: {'iso_c_binding': 'c_int'}
    }

    _default_settings = {
        'order': None,
        'full_prec': 'auto',
        'precision': 17,
        'user_functions': {},
        'human': True,
        'allow_unknown_functions': False,
        'source_format': 'fixed',
        'contract': True,
        'standard': 77,
        'name_mangling': True,
    }  # type: tDict[str, Any]

    _operators = {
        'and': '.and.',
        'or': '.or.',
        'xor': '.neqv.',
        'equivalent': '.eqv.',
        'not': '.not. ',
    }

    _relationals = {
        '!=': '/=',
    }

    def __init__(self, settings=None):
        """Create a printer.

        ``settings`` may additionally contain ``type_aliases`` and
        ``type_mappings`` dicts that extend the class-level tables.

        Raises ValueError for an unsupported ``standard``.
        """
        # Work on a copy: the 'type_aliases'/'type_mappings' entries are
        # popped below and must not be removed from the caller's dictionary.
        settings = dict(settings) if settings else {}
        self.mangled_symbols = {}  # Dict showing mapping of all words
        self.used_name = []
        self.type_aliases = dict(chain(self.type_aliases.items(),
                                       settings.pop('type_aliases', {}).items()))
        self.type_mappings = dict(chain(self.type_mappings.items(),
                                        settings.pop('type_mappings', {}).items()))
        super().__init__(settings)
        self.known_functions = dict(known_functions)
        userfuncs = settings.get('user_functions', {})
        self.known_functions.update(userfuncs)
        # leading columns depend on fixed or free format
        standards = {66, 77, 90, 95, 2003, 2008}
        if self._settings['standard'] not in standards:
            raise ValueError("Unknown Fortran standard: %s" % self._settings[
                             'standard'])
        self.module_uses = defaultdict(set)  # e.g.: use iso_c_binding, only: c_int

    @property
    def _lead(self):
        """Leading text for code, continuation and comment lines.

        Depends on the ``source_format`` setting. Raises ValueError for an
        unknown source format.
        """
        if self._settings['source_format'] == 'fixed':
            return {'code': "      ", 'cont': "     @ ", 'comment': "C     "}
        elif self._settings['source_format'] == 'free':
            return {'code': "", 'cont': "      ", 'comment': "! "}
        else:
            raise ValueError("Unknown source format: %s" % self._settings['source_format'])

    def _print_Symbol(self, expr):
        """Print a symbol, renaming it on a case-insensitive name clash.

        When name mangling is on, the first symbol seen with a given
        lower-cased name keeps its name. Later symbols get underscores
        appended until the lower-cased name is unused. The mapping is
        remembered so the same symbol always prints the same way.
        """
        if self._settings['name_mangling'] == True:
            if expr not in self.mangled_symbols:
                name = expr.name
                while name.lower() in self.used_name:
                    name += '_'
                self.used_name.append(name.lower())
                if name == expr.name:
                    self.mangled_symbols[expr] = expr
                else:
                    self.mangled_symbols[expr] = Symbol(name)

            expr = expr.xreplace(self.mangled_symbols)

        name = super()._print_Symbol(expr)
        return name

    def _rate_index_position(self, p):
        # Weight consumed by the base CodePrinter when ordering loop indices.
        # NOTE(review): the negative sign presumably makes earlier index
        # positions the inner loops, matching Fortran's column-major storage.
        return -p*5

    def _get_statement(self, codestring):
        # Fortran statements need no terminator.
        return codestring

    def _get_comment(self, text):
        return "! {}".format(text)

    def _declare_number_const(self, name, value):
        return "parameter ({} = {})".format(name, self._print(value))

    def _print_NumberSymbol(self, expr):
        # A Number symbol that is not implemented here or with _printmethod
        # is registered and evaluated
        self._number_symbols.add((expr, Float(expr.evalf(self._settings['precision']))))
        return str(expr)

    def _format_code(self, lines):
        return self._wrap_fortran(self.indent_code(lines))

    def _traverse_matrix_indices(self, mat):
        # Row index varies fastest, i.e. the matrix is walked column by column.
        rows, cols = mat.shape
        return ((i, j) for j in range(cols) for i in range(rows))

    def _get_loop_opening_ending(self, indices):
        """Return the ``do``/``end do`` lines for looping over ``indices``."""
        open_lines = []
        close_lines = []
        for i in indices:
            # fortran arrays start at 1 and end at dimension
            var, start, stop = map(self._print,
                                   [i.label, i.lower + 1, i.upper + 1])
            open_lines.append("do %s = %s, %s" % (var, start, stop))
            close_lines.append("end do")
        return open_lines, close_lines

    def _print_sign(self, expr):
        """Print ``sign(x)`` via ``merge`` so that ``sign(0)`` gives 0.

        Integer arguments use ``isign``. Complex or infinite arguments use
        ``x/abs(x)``. Everything else uses double-precision ``dsign``.
        """
        from sympy.functions.elementary.complexes import Abs
        arg, = expr.args
        if arg.is_integer:
            new_expr = merge(0, isign(1, arg), Eq(arg, 0))
        elif (arg.is_complex or arg.is_infinite):
            new_expr = merge(cmplx(literal_dp(0), literal_dp(0)), arg/Abs(arg), Eq(Abs(arg), literal_dp(0)))
        else:
            new_expr = merge(literal_dp(0), dsign(literal_dp(1), arg), Eq(arg, literal_dp(0)))
        return self._print(new_expr)

    def _print_Piecewise(self, expr):
        """Print a Piecewise as an if-block or as nested ``merge`` calls.

        An if/else-if block is emitted when the branches contain assignments.
        Otherwise inline ``merge`` is used, which requires Fortran 95 or later.

        Raises ValueError when the last condition is not ``True``.
        Raises NotImplementedError for inline use before F95.
        """
        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        lines = []
        if expr.has(Assignment):
            last = len(expr.args) - 1
            for i, (e, c) in enumerate(expr.args):
                if i == 0:
                    lines.append("if (%s) then" % self._print(c))
                elif i == last and c == True:
                    lines.append("else")
                else:
                    lines.append("else if (%s) then" % self._print(c))
                lines.append(self._print(e))
            lines.append("end if")
            return "\n".join(lines)
        elif self._settings["standard"] >= 95:
            # Only supported in F95 and newer:
            # The piecewise was used in an expression, need to do inline
            # operators. This has the downside that inline operators will
            # not work for statements that span multiple lines (Matrix or
            # Indexed expressions).
            pattern = "merge({T}, {F}, {COND})"
            code = self._print(expr.args[-1].expr)
            terms = list(expr.args[:-1])
            # Build from the innermost (last) branch outwards so the first
            # condition ends up as the outermost merge.
            while terms:
                e, c = terms.pop()
                branch_code = self._print(e)
                cond = self._print(c)
                code = pattern.format(T=branch_code, F=code, COND=cond)
            return code
        else:
            # `merge` is not supported prior to F95
            raise NotImplementedError("Using Piecewise as an expression using "
                                      "inline operators is not supported in "
                                      "standards earlier than Fortran95.")

    def _print_MatrixElement(self, expr):
        # Fortran indices are 1-based.
        return "{}({}, {})".format(self.parenthesize(expr.parent,
                                   PRECEDENCE["Atom"], strict=True), expr.i + 1, expr.j + 1)

    def _print_Add(self, expr):
        """Print a sum, grouping numeric real/imaginary terms into ``cmplx``."""
        # purpose: print complex numbers nicely in Fortran.
        # collect the purely real and purely imaginary parts:
        pure_real = []
        pure_imaginary = []
        mixed = []
        for arg in expr.args:
            if arg.is_number and arg.is_real:
                pure_real.append(arg)
            elif arg.is_number and arg.is_imaginary:
                pure_imaginary.append(arg)
            else:
                mixed.append(arg)
        if pure_imaginary:
            if mixed:
                PREC = precedence(expr)
                term = Add(*mixed)
                t = self._print(term)
                if t.startswith('-'):
                    sign = "-"
                    t = t[1:]
                else:
                    sign = "+"
                if precedence(term) < PREC:
                    t = "(%s)" % t

                return "cmplx(%s,%s) %s %s" % (
                    self._print(Add(*pure_real)),
                    self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
                    sign, t,
                )
            else:
                return "cmplx(%s,%s)" % (
                    self._print(Add(*pure_real)),
                    self._print(-S.ImaginaryUnit*Add(*pure_imaginary)),
                )
        else:
            return CodePrinter._print_Add(self, expr)

    def _print_Function(self, expr):
        """Print a function call after evaluating its arguments numerically."""
        # All constant function args are evaluated as floats
        prec = self._settings['precision']
        args = [N(a, prec) for a in expr.args]
        eval_expr = expr.func(*args)
        if not isinstance(eval_expr, Function):
            # Evaluation collapsed the call (e.g. to a number).
            return self._print(eval_expr)
        else:
            return CodePrinter._print_Function(self, eval_expr)

    def _print_Mod(self, expr):
        # NOTE : Fortran has the functions mod() and modulo(). modulo() behaves
        # the same wrt to the sign of the arguments as Python and SymPy's
        # modulus computations (% and Mod()) but is not available in Fortran 66
        # or Fortran 77, thus we raise an error.
        if self._settings['standard'] in [66, 77]:
            msg = ("Python % operator and SymPy's Mod() function are not "
                   "supported by Fortran 66 or 77 standards.")
            raise NotImplementedError(msg)
        else:
            x, y = expr.args
            # No leading space: the call may be embedded in a larger expression.
            return "modulo({}, {})".format(self._print(x), self._print(y))

    def _print_ImaginaryUnit(self, expr):
        # purpose: print complex numbers nicely in Fortran.
        return "cmplx(0,1)"

    def _print_int(self, expr):
        return str(expr)

    def _print_Mul(self, expr):
        # purpose: print complex numbers nicely in Fortran.
        if expr.is_number and expr.is_imaginary:
            return "cmplx(0,%s)" % (
                self._print(-S.ImaginaryUnit*expr)
            )
        else:
            return CodePrinter._print_Mul(self, expr)

    def _print_Pow(self, expr):
        """Print powers, special-casing reciprocals and square roots."""
        PREC = precedence(expr)
        if expr.exp == -1:
            return '%s/%s' % (
                self._print(literal_dp(1)),
                self.parenthesize(expr.base, PREC)
            )
        elif expr.exp == 0.5:
            if expr.base.is_integer:
                # Fortran intrinsic sqrt() does not accept integer argument
                if expr.base.is_Number:
                    return 'sqrt(%s.0d0)' % self._print(expr.base)
                else:
                    return 'sqrt(dble(%s))' % self._print(expr.base)
            else:
                return 'sqrt(%s)' % self._print(expr.base)
        else:
            return CodePrinter._print_Pow(self, expr)

    def _print_Rational(self, expr):
        # Double-precision literals avoid integer division.
        p, q = int(expr.p), int(expr.q)
        return "%d.0d0/%d.0d0" % (p, q)

    def _print_Float(self, expr):
        # Use the double-precision exponent marker 'd' instead of 'e'.
        printed = CodePrinter._print_Float(self, expr)
        e = printed.find('e')
        if e > -1:
            return "%sd%s" % (printed[:e], printed[e + 1:])
        return "%sd0" % printed

    def _print_Relational(self, expr):
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        op = op if op not in self._relationals else self._relationals[op]
        return "{} {} {}".format(lhs_code, op, rhs_code)

    def _print_Indexed(self, expr):
        inds = [self._print(i) for i in expr.indices]
        return "%s(%s)" % (self._print(expr.base.label), ", ".join(inds))

    def _print_Idx(self, expr):
        return self._print(expr.label)

    def _print_AugmentedAssignment(self, expr):
        # Fortran has no augmented operators: expand "a += b" to "a = a + b".
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        return self._get_statement("{0} = {0} {1} {2}".format(
            lhs_code, expr.binop, rhs_code))

    def _print_sum_(self, sm):
        params = self._print(sm.array)
        if sm.dim != None:  # Must use '!= None', cannot use 'is not None'
            params += ', ' + self._print(sm.dim)
        if sm.mask != None:  # Must use '!= None', cannot use 'is not None'
            params += ', mask=' + self._print(sm.mask)
        return '%s(%s)' % (sm.__class__.__name__.rstrip('_'), params)

    def _print_product_(self, prod):
        return self._print_sum_(prod)

    def _print_Do(self, do):
        excl = ['concurrent']
        if do.step == 1:
            # A unit step is Fortran's default, so it is omitted.
            excl.append('step')
            step = ''
        else:
            step = ', {step}'

        return (
            'do {concurrent}{counter} = {first}, {last}'+step+'\n'
            '{body}\n'
            'end do\n'
        ).format(
            concurrent='concurrent ' if do.concurrent else '',
            **do.kwargs(apply=lambda arg: self._print(arg), exclude=excl)
        )

    def _print_ImpliedDoLoop(self, idl):
        step = '' if idl.step == 1 else ', {step}'
        return ('({expr}, {counter} = {first}, {last}'+step+')').format(
            **idl.kwargs(apply=lambda arg: self._print(arg))
        )

    def _print_For(self, expr):
        """Print a ``For`` over a ``Range`` as a Fortran ``do`` loop.

        SymPy's ``Range`` excludes its stop value, while the upper bound of a
        Fortran ``do`` loop is inclusive. The printed bound is therefore moved
        one unit back towards the start.

        Raises NotImplementedError for iterables other than ``Range``.
        """
        target = self._print(expr.target)
        if isinstance(expr.iterable, Range):
            start, stop, step = expr.iterable.args
        else:
            raise NotImplementedError("Only iterable currently supported is Range")
        # Range members are integers: "< stop" equals "<= stop - 1" for an
        # increasing range, and "> stop" equals ">= stop + 1" for a
        # decreasing one.
        last = stop + 1 if step.is_negative else stop - 1
        body = self._print(expr.body)
        return ('do {target} = {start}, {stop}, {step}\n'
                '{body}\n'
                'end do').format(target=target, start=self._print(start),
                                 stop=self._print(last), step=self._print(step),
                                 body=body)

    def _print_Type(self, type_):
        """Print a codegen type, recording any module import it requires."""
        type_ = self.type_aliases.get(type_, type_)
        type_str = self.type_mappings.get(type_, type_.name)
        module_uses = self.type_modules.get(type_)
        if module_uses:
            # type_modules values are {module: imported_name} dicts.
            for module, name in module_uses.items():
                self.module_uses[module].add(name)
        return type_str

    def _print_Element(self, elem):
        return '{symbol}({idxs})'.format(
            symbol=self._print(elem.symbol),
            idxs=', '.join(map(lambda arg: self._print(arg), elem.indices))
        )

    def _print_Extent(self, ext):
        return str(ext)

    def _print_Declaration(self, expr):
        """Print a variable declaration.

        Fortran 90 and later use ``type, attrs :: name [= value]``. Earlier
        standards only get ``type name``.

        Raises:
            ValueError: more than one intent attribute is set.
            NotImplementedError: the variable is a Pointer, or a pre-F90
                declaration needs an initial value or parameter attribute.
        """
        var = expr.variable
        val = var.value
        dim = var.attr_params('dimension')
        intents = [intent in var.attrs for intent in (intent_in, intent_out, intent_inout)]
        if intents.count(True) == 0:
            intent = ''
        elif intents.count(True) == 1:
            intent = ', intent(%s)' % ['in', 'out', 'inout'][intents.index(True)]
        else:
            raise ValueError("Multiple intents specified for %s" % var.symbol)

        if isinstance(var, Pointer):
            raise NotImplementedError("Pointers are not available by default in Fortran.")

        if self._settings["standard"] >= 90:
            result = '{t}{vc}{dim}{intent}{alloc} :: {s}'.format(
                t=self._print(var.type),
                vc=', parameter' if value_const in var.attrs else '',
                dim=', dimension(%s)' % ', '.join(map(lambda arg: self._print(arg), dim)) if dim else '',
                intent=intent,
                alloc=', allocatable' if allocatable in var.attrs else '',
                s=self._print(var.symbol)
            )
            if val != None:  # Must be "!= None", cannot be "is not None"
                result += ' = %s' % self._print(val)
        else:
            # Same '!= None' test as above: truthiness would wrongly treat a
            # zero initial value as "no value".
            if value_const in var.attrs or val != None:
                raise NotImplementedError("F77 init./parameter statem. req. multiple lines.")
            result = ' '.join(map(lambda arg: self._print(arg), [var.type, var.symbol]))

        return result

    def _print_Infinity(self, expr):
        return '(huge(%s) + 1)' % self._print(literal_dp(0))

    def _print_While(self, expr):
        return 'do while ({condition})\n{body}\nend do'.format(**expr.kwargs(
            apply=lambda arg: self._print(arg)))

    def _print_BooleanTrue(self, expr):
        return '.true.'

    def _print_BooleanFalse(self, expr):
        return '.false.'

    def _pad_leading_columns(self, lines):
        """Prefix each line with the comment or code lead of the source format."""
        result = []
        for line in lines:
            if line.startswith('!'):
                result.append(self._lead['comment'] + line[1:].lstrip())
            else:
                result.append(self._lead['code'] + line)
        return result

    def _wrap_fortran(self, lines):
        """Wrap long Fortran lines

           Argument:
             lines  --  a list of lines (without \\n character)

           A comment line is split at white space. Code lines are split with a more
           complex rule to give nice results.
        """
        # routine to find split point in a code line
        my_alnum = set("_+-." + string.digits + string.ascii_letters)
        my_white = set(" \t()")

        def split_pos_code(line, endpos):
            """Return a split index <= endpos at a token/whitespace boundary."""
            if len(line) <= endpos:
                return len(line)

            def is_boundary(pos):
                return (
                    (line[pos] in my_alnum and line[pos - 1] not in my_alnum) or
                    (line[pos] not in my_alnum and line[pos - 1] in my_alnum) or
                    (line[pos] in my_white and line[pos - 1] not in my_white) or
                    (line[pos] not in my_white and line[pos - 1] in my_white)
                )

            pos = endpos
            while not is_boundary(pos):
                pos -= 1
                if pos == 0:
                    # No boundary found: hard split at endpos.
                    return endpos
            return pos

        # split line by line and add the split lines to result
        result = []
        if self._settings['source_format'] == 'free':
            trailing = ' &'
        else:
            trailing = ''
        for line in lines:
            if line.startswith(self._lead['comment']):
                # comment line
                if len(line) > 72:
                    pos = line.rfind(" ", 6, 72)
                    if pos == -1:
                        pos = 72
                    hunk = line[:pos]
                    line = line[pos:].lstrip()
                    result.append(hunk)
                    while line:
                        pos = line.rfind(" ", 0, 66)
                        if pos == -1 or len(line) < 66:
                            pos = 66
                        hunk = line[:pos]
                        line = line[pos:].lstrip()
                        result.append("%s%s" % (self._lead['comment'], hunk))
                else:
                    result.append(line)
            elif line.startswith(self._lead['code']):
                # code line
                pos = split_pos_code(line, 72)
                hunk = line[:pos].rstrip()
                line = line[pos:].lstrip()
                if line:
                    hunk += trailing
                result.append(hunk)
                while line:
                    pos = split_pos_code(line, 65)
                    hunk = line[:pos].rstrip()
                    line = line[pos:].lstrip()
                    if line:
                        hunk += trailing
                    result.append("%s%s" % (self._lead['cont'], hunk))
            else:
                result.append(line)
        return result

    def indent_code(self, code):
        """Accepts a string of code or a list of code lines"""
        if isinstance(code, str):
            code_lines = self.indent_code(code.splitlines(True))
            return ''.join(code_lines)

        free = self._settings['source_format'] == 'free'
        code = [line.lstrip(' \t') for line in code]

        inc_keyword = ('do ', 'if(', 'if ', 'do\n', 'else', 'program', 'interface')
        dec_keyword = ('end do', 'enddo', 'end if', 'endif', 'else', 'end program', 'end interface')

        # Per-line flags: opens a block, closes a block, ends with '&'.
        increase = [int(line.startswith(inc_keyword)) for line in code]
        decrease = [int(line.startswith(dec_keyword)) for line in code]
        continuation = [int(line.endswith(('&', '&\n'))) for line in code]

        level = 0
        cont_padding = 0
        tabwidth = 3
        new_code = []
        for i, line in enumerate(code):
            if line in ('', '\n'):
                new_code.append(line)
                continue
            level -= decrease[i]

            if free:
                padding = " "*(level*tabwidth + cont_padding)
            else:
                padding = " "*level*tabwidth

            line = "%s%s" % (padding, line)
            if not free:
                line = self._pad_leading_columns([line])[0]

            new_code.append(line)

            # Lines following a '&' continuation get extra indentation.
            if continuation[i]:
                cont_padding = 2*tabwidth
            else:
                cont_padding = 0
            level += increase[i]

        if not free:
            return self._wrap_fortran(new_code)
        return new_code

    def _print_GoTo(self, goto):
        if goto.expr:  # computed goto
            return "go to ({labels}), {expr}".format(
                labels=', '.join(map(lambda arg: self._print(arg), goto.labels)),
                expr=self._print(goto.expr)
            )
        else:
            lbl, = goto.labels
            return "go to %s" % self._print(lbl)

    def _print_Program(self, prog):
        return (
            "program {name}\n"
            "{body}\n"
            "end program\n"
        ).format(**prog.kwargs(apply=lambda arg: self._print(arg)))

    def _print_Module(self, mod):
        return (
            "module {name}\n"
            "{declarations}\n"
            "\ncontains\n\n"
            "{definitions}\n"
            "end module\n"
        ).format(**mod.kwargs(apply=lambda arg: self._print(arg)))

    def _print_Stream(self, strm):
        """Print an I/O stream (unit).

        From Fortran 2003 on, stdout and stderr map to the ``output_unit`` and
        ``error_unit`` constants of the intrinsic ``iso_fortran_env`` module,
        which is recorded in ``module_uses``. Before that, stdout is the
        default unit ``*``. Any other stream prints as its name.
        """
        if self._settings["standard"] >= 2003:
            if strm.name == 'stdout':
                self.module_uses['iso_fortran_env'].add('output_unit')
                return 'output_unit'
            elif strm.name == 'stderr':
                self.module_uses['iso_fortran_env'].add('error_unit')
                return 'error_unit'
        if strm.name == 'stdout':
            return '*'
        return strm.name

    def _print_Print(self, ps):
        if ps.format_string != None:  # Must be '!= None', cannot be 'is not None'
            fmt = self._print(ps.format_string)
        else:
            fmt = "*"
        return "print {fmt}, {iolist}".format(fmt=fmt, iolist=', '.join(
            map(lambda arg: self._print(arg), ps.print_args)))

    def _print_Return(self, rs):
        # Fortran functions return by assigning to the result variable; its
        # name comes from the printer context set in _print_FunctionDefinition.
        arg, = rs.args
        return "{result_name} = {arg}".format(
            result_name=self._context.get('result_name', 'sympy_result'),
            arg=self._print(arg)
        )

    def _print_FortranReturn(self, frs):
        arg, = frs.args
        if arg:
            return 'return %s' % self._print(arg)
        else:
            return 'return'

    def _head(self, entity, fp, **kwargs):
        """Print the header line and argument declarations of a procedure."""
        bind_C_params = fp.attr_params('bind_C')
        if bind_C_params is None:
            bind = ''
        elif bind_C_params:
            bind = ' bind(C, name="%s")' % bind_C_params[0]
        else:
            bind = ' bind(C)'
        result_name = self._settings.get('result_name', None)
        return (
            "{entity}{name}({arg_names}){result}{bind}\n"
            "{arg_declarations}"
        ).format(
            entity=entity,
            name=self._print(fp.name),
            arg_names=', '.join([self._print(arg.symbol) for arg in fp.parameters]),
            result=(' result(%s)' % result_name) if result_name else '',
            bind=bind,
            arg_declarations='\n'.join(map(lambda arg: self._print(Declaration(arg)), fp.parameters))
        )

    def _print_FunctionPrototype(self, fp):
        entity = "{} function ".format(self._print(fp.return_type))
        return (
            "interface\n"
            "{function_head}\n"
            "end function\n"
            "end interface"
        ).format(function_head=self._head(entity, fp))

    def _print_FunctionDefinition(self, fd):
        if elemental in fd.attrs:
            prefix = 'elemental '
        elif pure in fd.attrs:
            prefix = 'pure '
        else:
            prefix = ''

        entity = "{} function ".format(self._print(fd.return_type))
        # Return statements inside the body assign to the function's name.
        with printer_context(self, result_name=fd.name):
            return (
                "{prefix}{function_head}\n"
                "{body}\n"
                "end function\n"
            ).format(
                prefix=prefix,
                function_head=self._head(entity, fd),
                body=self._print(fd.body)
            )

    def _print_Subroutine(self, sub):
        return (
            '{subroutine_head}\n'
            '{body}\n'
            'end subroutine\n'
        ).format(
            subroutine_head=self._head('subroutine ', sub),
            body=self._print(sub.body)
        )

    def _print_SubroutineCall(self, scall):
        return 'call {name}({args})'.format(
            name=self._print(scall.name),
            args=', '.join(map(lambda arg: self._print(arg), scall.subroutine_args))
        )

    def _print_use_rename(self, rnm):
        return "%s => %s" % tuple(map(lambda arg: self._print(arg), rnm.args))

    def _print_use(self, use):
        result = 'use %s' % self._print(use.namespace)
        if use.rename != None:  # Must be '!= None', cannot be 'is not None'
            result += ', ' + ', '.join([self._print(rnm) for rnm in use.rename])
        if use.only != None:  # Must be '!= None', cannot be 'is not None'
            result += ', only: ' + ', '.join([self._print(nly) for nly in use.only])
        return result

    def _print_BreakToken(self, _):
        return 'exit'

    def _print_ContinueToken(self, _):
        return 'cycle'

    def _print_ArrayConstructor(self, ac):
        # Square-bracket array constructors are only valid from F2003.
        fmtstr = "[%s]" if self._settings["standard"] >= 2003 else '(/%s/)'
        return fmtstr % ', '.join(map(lambda arg: self._print(arg), ac.elements))
|