12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
3704705706707708709710711712713714715716717718719720721722723724725726727728729730731732733734735736737738739740741742743744745746747748 |
- """
- C code printer
- The C89CodePrinter & C99CodePrinter convert single SymPy expressions into
- single C expressions, using the functions defined in math.h where possible.
- A complete code generator, which uses ccode extensively, can be found in
- sympy.utilities.codegen. The codegen module can be used to generate complete
- source code files that are compilable without further modifications.
- """
- from typing import Any, Dict as tDict, Tuple as tTuple
- from functools import wraps
- from itertools import chain
- from sympy.core import S
- from sympy.codegen.ast import (
- Assignment, Pointer, Variable, Declaration, Type,
- real, complex_, integer, bool_, float32, float64, float80,
- complex64, complex128, intc, value_const, pointer_const,
- int8, int16, int32, int64, uint8, uint16, uint32, uint64, untyped,
- none
- )
- from sympy.printing.codeprinter import CodePrinter, requires
- from sympy.printing.precedence import precedence, PRECEDENCE
- from sympy.sets.fancysets import Range
- # These are defined in the other file so we can avoid importing sympy.codegen
- # from the top-level 'import sympy'. Export them here as well.
- from sympy.printing.codeprinter import ccode, print_ccode # noqa:F401
# Mapping from a SymPy function's class name to the C89 math.h function that
# implements it.  A value is either the C function name itself, or a list of
# (argument_condition, C_name) pairs where the first pair whose condition holds
# for the arguments selects the C name (e.g. integer ``Abs`` -> ``abs``).
# Copied into ``self.known_functions`` by C89CodePrinter (via its ``_kf``).
known_functions_C89 = dict(
    Abs=[(lambda x: not x.is_integer, "fabs"), (lambda x: x.is_integer, "abs")],
    sin="sin",
    cos="cos",
    tan="tan",
    asin="asin",
    acos="acos",
    atan="atan",
    atan2="atan2",
    exp="exp",
    log="log",
    sinh="sinh",
    cosh="cosh",
    tanh="tanh",
    floor="floor",
    ceiling="ceil",
    sqrt="sqrt",  # To enable automatic rewrites
)
# C99 keeps every C89 entry and adds the math.h functions introduced by C99.
known_functions_C99 = {
    **known_functions_C89,
    'exp2': 'exp2',
    'expm1': 'expm1',
    'log10': 'log10',
    'log2': 'log2',
    'log1p': 'log1p',
    'Cbrt': 'cbrt',
    'hypot': 'hypot',
    'fma': 'fma',
    'loggamma': 'lgamma',
    'erfc': 'erfc',
    'Max': 'fmax',
    'Min': 'fmin',
    "asinh": "asinh",
    "acosh": "acosh",
    "atanh": "atanh",
    "erf": "erf",
    "gamma": "tgamma",
}
# These are the core reserved words in the C language. Taken from:
# http://en.cppreference.com/w/c/keyword
# Symbols whose names collide with one of these are treated as reserved by the
# printers (see the ``reserved_words`` class attributes).
reserved_words = (
    'auto break case char const continue default do '
    'double else enum extern float for goto if int '
    'long register return short signed sizeof static '
    'struct entry '  # 'entry' was never standardized, we'll leave it here anyway
    'switch typedef union unsigned void volatile while'
).split()

# Additional keywords introduced by C99.
reserved_words_c99 = 'inline restrict'.split()
def get_math_macros():
    """ Returns a dictionary with math-related macros from math.h/cmath

    Note that these macros are not strictly required by the C/C++-standard.
    For MSVC they are enabled by defining "_USE_MATH_DEFINES" (preferably
    via a compilation flag).

    Several keys may map to the same macro, because the same constant can be
    spelled with different SymPy functions (e.g. ``sqrt`` and codegen's
    ``Sqrt``, or ``log2(E)`` and ``1/log(2)``).

    Returns
    =======

    Dictionary mapping SymPy expressions to strings (macro names)
    """
    from sympy.codegen.cfunctions import log2, Sqrt
    from sympy.functions.elementary.exponential import log
    from sympy.functions.elementary.miscellaneous import sqrt

    pi, e = S.Pi, S.Exp1
    expression_macro_pairs = [
        (e, 'M_E'),
        (log2(e), 'M_LOG2E'),
        (1/log(2), 'M_LOG2E'),
        (log(2), 'M_LN2'),
        (log(10), 'M_LN10'),
        (pi, 'M_PI'),
        (pi/2, 'M_PI_2'),
        (pi/4, 'M_PI_4'),
        (1/pi, 'M_1_PI'),
        (2/pi, 'M_2_PI'),
        (2/sqrt(pi), 'M_2_SQRTPI'),
        (2/Sqrt(pi), 'M_2_SQRTPI'),
        (sqrt(2), 'M_SQRT2'),
        (Sqrt(2), 'M_SQRT2'),
        (1/sqrt(2), 'M_SQRT1_2'),
        (1/Sqrt(2), 'M_SQRT1_2'),
    ]
    return dict(expression_macro_pairs)
def _as_macro_if_defined(meth):
    """ Decorator for printer methods

    When a Printer's method is decorated using this decorator the expressions printed
    will first be looked for in the attribute ``math_macros``, and if present it will
    print the macro name in ``math_macros`` followed by a type suffix for the type
    ``real``. e.g. printing ``sympy.pi`` would print ``M_PIl`` if real is mapped to float80.

    Expressions not found in ``math_macros`` are printed by the wrapped method,
    which receives any keyword arguments unchanged.
    """
    @wraps(meth)
    def _meth_wrapper(self, expr, **kwargs):
        if expr not in self.math_macros:
            # Not a known constant: defer to the decorated printer method.
            return meth(self, expr, **kwargs)
        macro_name = self.math_macros[expr]
        return '%s%s' % (macro_name, self._get_math_macro_suffix(real))
    return _meth_wrapper
class C89CodePrinter(CodePrinter):
    """A printer to convert SymPy expressions to strings of C89 code.

    Types are printed through the class-level tables ``type_aliases``,
    ``type_mappings``, ``type_headers``, ``type_macros``,
    ``type_func_suffixes``, ``type_literal_suffixes`` and
    ``type_math_macro_suffixes``.  Each instance merges every table with the
    settings entry of the same name passed to ``__init__``.  Headers,
    libraries and macros required by the printed code are collected in
    ``self.headers``, ``self.libraries`` and ``self.macros``.
    """
    printmethod = "_ccode"
    language = "C"
    standard = "C89"
    reserved_words = set(reserved_words)

    _default_settings = {
        'order': None,
        'full_prec': 'auto',
        'precision': 17,
        'user_functions': {},
        'human': True,
        'allow_unknown_functions': False,
        'contract': True,
        'dereference': set(),
        'error_on_reserved': False,
        'reserved_word_suffix': '_',
    }  # type: tDict[str, Any]

    # Generic types resolved to the concrete type used when printing.
    type_aliases = {
        real: float64,
        complex_: complex128,
        integer: intc
    }

    # C spelling of each type.
    type_mappings = {
        real: 'double',
        intc: 'int',
        float32: 'float',
        float64: 'double',
        # float80 already uses the long double suffixes ('l' / 'L') below.
        float80: 'long double',
        integer: 'int',
        bool_: 'bool',
        int8: 'int8_t',
        int16: 'int16_t',
        int32: 'int32_t',
        int64: 'int64_t',
        # Unsigned types must map to the unsigned <stdint.h> typedefs.
        uint8: 'uint8_t',
        uint16: 'uint16_t',
        uint32: 'uint32_t',
        uint64: 'uint64_t',
    }  # type: tDict[Type, Any]

    # Headers that must be included when a type is printed.
    type_headers = {
        bool_: {'stdbool.h'},
        int8: {'stdint.h'},
        int16: {'stdint.h'},
        int32: {'stdint.h'},
        int64: {'stdint.h'},
        uint8: {'stdint.h'},
        uint16: {'stdint.h'},
        uint32: {'stdint.h'},
        uint64: {'stdint.h'},
    }

    # Macros needed to be defined when using a Type
    type_macros = {}  # type: tDict[Type, tTuple[str, ...]]

    # Suffix selecting the math.h function variant for a type (sqrt/sqrtf/sqrtl).
    type_func_suffixes = {
        float32: 'f',
        float64: '',
        float80: 'l'
    }

    # Suffix of floating-point literals of a type (1.0F, 1.0L).
    type_literal_suffixes = {
        float32: 'F',
        float64: '',
        float80: 'L'
    }

    # Suffix appended to math macro names for a type (M_PI -> M_PIl).
    type_math_macro_suffixes = {
        float80: 'l'
    }

    math_macros = None

    _ns = ''  # namespace, C++ uses 'std::'

    # known_functions-dict to copy
    _kf = known_functions_C89  # type: tDict[str, Any]

    def __init__(self, settings=None):
        # Work on a copy: printer-specific keys are popped below before the
        # remaining settings reach CodePrinter.__init__, and that must not
        # mutate the dict the caller passed in.
        settings = dict(settings or {})
        if self.math_macros is None:
            self.math_macros = settings.pop('math_macros', get_math_macros())
        # Per-instance type tables: class defaults updated with user overrides.
        for table in ('type_aliases', 'type_mappings', 'type_headers',
                      'type_macros', 'type_func_suffixes',
                      'type_literal_suffixes', 'type_math_macro_suffixes'):
            merged = dict(getattr(self, table))
            merged.update(settings.pop(table, {}))
            setattr(self, table, merged)
        super().__init__(settings)
        self.known_functions = dict(self._kf, **settings.get('user_functions', {}))
        self._dereference = set(settings.get('dereference', []))
        self.headers = set()
        self.libraries = set()
        self.macros = set()

    def _rate_index_position(self, p):
        """Score for an index at position ``p`` (consumed by CodePrinter)."""
        return p*5

    def _get_statement(self, codestring):
        """ Get code string as a statement - i.e. ending with a semicolon. """
        return codestring if codestring.endswith(';') else codestring + ';'

    def _get_comment(self, text):
        """Return ``text`` as a C line comment."""
        return "// {}".format(text)

    def _declare_number_const(self, name, value):
        """Return a ``const`` declaration statement of ``name`` set to ``value``.

        ``value`` is evaluated to the decimal precision of the type ``real``
        is aliased to.
        """
        type_ = self.type_aliases[real]
        var = Variable(name, type=type_, value=value.evalf(type_.decimal_dig), attrs={value_const})
        decl = Declaration(var)
        return self._get_statement(self._print(decl))

    def _format_code(self, lines):
        return self.indent_code(lines)

    def _traverse_matrix_indices(self, mat):
        """Yield ``(i, j)`` index pairs of ``mat`` in row-major order."""
        rows, cols = mat.shape
        return ((i, j) for i in range(rows) for j in range(cols))

    @_as_macro_if_defined
    def _print_Mul(self, expr, **kwargs):
        return super()._print_Mul(expr, **kwargs)

    @_as_macro_if_defined
    def _print_Pow(self, expr):
        """Print a power, using ``sqrt``/``cbrt`` or a reciprocal when possible."""
        if "Pow" in self.known_functions:
            return self._print_Function(expr)
        PREC = precedence(expr)
        suffix = self._get_func_suffix(real)
        if expr.exp == -1:
            literal_suffix = self._get_literal_suffix(real)
            return '1.0%s/%s' % (literal_suffix, self.parenthesize(expr.base, PREC))
        elif expr.exp == 0.5:
            return '%ssqrt%s(%s)' % (self._ns, suffix, self._print(expr.base))
        elif expr.exp == S.One/3 and self.standard != 'C89':
            # cbrt is not available in C89's math.h
            return '%scbrt%s(%s)' % (self._ns, suffix, self._print(expr.base))
        else:
            return '%spow%s(%s, %s)' % (self._ns, suffix, self._print(expr.base),
                                   self._print(expr.exp))

    def _print_Mod(self, expr):
        """Print ``Mod`` with SymPy's sign convention (sign of the divisor)."""
        num, den = expr.args
        if num.is_integer and den.is_integer:
            PREC = precedence(expr)
            snum, sden = [self.parenthesize(arg, PREC) for arg in expr.args]
            # % is remainder (same sign as numerator), not modulo (same sign as
            # denominator), in C. Hence, % only works as modulo if both numbers
            # have the same sign
            if (num.is_nonnegative and den.is_nonnegative or
                num.is_nonpositive and den.is_nonpositive):
                return f"{snum} % {sden}"
            return f"(({snum} % {sden}) + {sden}) % {sden}"
        # Not guaranteed integer
        return self._print_math_func(expr, known='fmod')

    def _print_Rational(self, expr):
        """Print ``p/q`` as a floating-point division of literals."""
        p, q = int(expr.p), int(expr.q)
        suffix = self._get_literal_suffix(real)
        return '%d.0%s/%d.0%s' % (p, suffix, q, suffix)

    def _print_Indexed(self, expr):
        """Print an indexed access as a flat 1-d array access.

        ``strides`` on the base may be ``None``/``'C'`` (row-major), ``'F'``
        (column-major) or an explicit sequence of strides.

        Raises
        ======

        ValueError
            If ``strides`` is a string other than ``'C'`` or ``'F'``.
        """
        # calculate index for 1d array
        offset = getattr(expr.base, 'offset', S.Zero)
        strides = getattr(expr.base, 'strides', None)
        indices = expr.indices

        if strides is None or isinstance(strides, str):
            dims = expr.shape
            shift = S.One
            temp = tuple()
            if strides == 'C' or strides is None:
                traversal = reversed(range(expr.rank))
                indices = indices[::-1]
            elif strides == 'F':
                traversal = range(expr.rank)
            else:
                raise ValueError("Unknown strides %r; expected 'C', 'F', None "
                                 "or a sequence of strides" % (strides,))

            for i in traversal:
                temp += (shift,)
                shift *= dims[i]
            strides = temp
        flat_index = sum(idx*stride for idx, stride in zip(indices, strides)) + offset
        return "%s[%s]" % (self._print(expr.base.label),
                           self._print(flat_index))

    def _print_Idx(self, expr):
        return self._print(expr.label)

    @_as_macro_if_defined
    def _print_NumberSymbol(self, expr):
        return super()._print_NumberSymbol(expr)

    def _print_Infinity(self, expr):
        return 'HUGE_VAL'

    def _print_NegativeInfinity(self, expr):
        return '-HUGE_VAL'

    def _print_Piecewise(self, expr):
        """Print a Piecewise as if/else statements or nested ternaries.

        Raises
        ======

        ValueError
            If the last condition is not ``True``.
        """
        if expr.args[-1].cond != True:
            # We need the last conditional to be a True, otherwise the resulting
            # function may not return a result.
            raise ValueError("All Piecewise expressions must contain an "
                             "(expr, True) statement to be used as a default "
                             "condition. Without one, the generated "
                             "expression may not evaluate to anything under "
                             "some condition.")
        lines = []
        if expr.has(Assignment):
            for i, (e, c) in enumerate(expr.args):
                if i == 0:
                    lines.append("if (%s) {" % self._print(c))
                elif i == len(expr.args) - 1 and c == True:
                    lines.append("else {")
                else:
                    lines.append("else if (%s) {" % self._print(c))
                code0 = self._print(e)
                lines.append(code0)
                lines.append("}")
            return "\n".join(lines)
        else:
            # The piecewise was used in an expression, need to do inline
            # operators. This has the downside that inline operators will
            # not work for statements that span multiple lines (Matrix or
            # Indexed expressions).
            ecpairs = ["((%s) ? (\n%s\n)\n" % (self._print(c),
                                               self._print(e))
                       for e, c in expr.args[:-1]]
            last_line = ": (\n%s\n)" % self._print(expr.args[-1].expr)
            # one closing parenthesis per opened ternary
            return ": ".join(ecpairs) + last_line + ")" * len(ecpairs)

    def _print_ITE(self, expr):
        from sympy.functions import Piecewise
        return self._print(expr.rewrite(Piecewise, deep=False))

    def _print_MatrixElement(self, expr):
        """Print a matrix element as a row-major flat array access."""
        return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"],
            strict=True), expr.j + expr.i*expr.parent.shape[1])

    def _print_Symbol(self, expr):
        """Print a symbol, dereferencing it if listed in ``dereference``."""
        name = super()._print_Symbol(expr)
        if expr in self._settings['dereference']:
            return '(*{})'.format(name)
        else:
            return name

    def _print_Relational(self, expr):
        lhs_code = self._print(expr.lhs)
        rhs_code = self._print(expr.rhs)
        op = expr.rel_op
        return "{} {} {}".format(lhs_code, op, rhs_code)

    def _print_For(self, expr):
        """Print a ``For`` over a ``Range`` as a C ``for`` loop.

        Raises
        ======

        NotImplementedError
            If the iterable is not a ``Range``.
        """
        target = self._print(expr.target)
        if isinstance(expr.iterable, Range):
            start, stop, step = expr.iterable.args
        else:
            raise NotImplementedError("Only iterable currently supported is Range")
        body = self._print(expr.body)
        # A Range with a negative step counts down, so the loop must keep
        # running while the target is still above the (exclusive) stop value.
        cmp_op = '>' if step.is_negative else '<'
        return ('for ({target} = {start}; {target} {op} {stop}; {target} += '
                '{step}) {{\n{body}\n}}').format(target=target,
                                                 start=self._print(start),
                                                 op=cmp_op,
                                                 stop=self._print(stop),
                                                 step=self._print(step),
                                                 body=body)

    def _print_sign(self, func):
        """Print ``sign(x)`` as the branch-free ``(x > 0) - (x < 0)``."""
        return '((({0}) > 0) - (({0}) < 0))'.format(self._print(func.args[0]))

    def _print_Max(self, expr):
        """Print ``Max`` as a balanced tree of ternaries (C89 has no fmax)."""
        if "Max" in self.known_functions:
            return self._print_Function(expr)
        def inner_print_max(args): # The more natural abstraction of creating
            if len(args) == 1:     # and printing smaller Max objects is slow
                return self._print(args[0]) # when there are many arguments.
            half = len(args) // 2
            return "((%(a)s > %(b)s) ? %(a)s : %(b)s)" % {
                'a': inner_print_max(args[:half]),
                'b': inner_print_max(args[half:])
            }
        return inner_print_max(expr.args)

    def _print_Min(self, expr):
        """Print ``Min`` as a balanced tree of ternaries (C89 has no fmin)."""
        if "Min" in self.known_functions:
            return self._print_Function(expr)
        def inner_print_min(args): # The more natural abstraction of creating
            if len(args) == 1:     # and printing smaller Min objects is slow
                return self._print(args[0]) # when there are many arguments.
            half = len(args) // 2
            return "((%(a)s < %(b)s) ? %(a)s : %(b)s)" % {
                'a': inner_print_min(args[:half]),
                'b': inner_print_min(args[half:])
            }
        return inner_print_min(expr.args)

    def indent_code(self, code):
        """Accepts a string of code or a list of code lines"""

        if isinstance(code, str):
            code_lines = self.indent_code(code.splitlines(True))
            return ''.join(code_lines)

        tab = "   "
        inc_token = ('{', '(', '{\n', '(\n')
        dec_token = ('}', ')')

        code = [line.lstrip(' \t') for line in code]

        increase = [int(any(map(line.endswith, inc_token))) for line in code]
        decrease = [int(any(map(line.startswith, dec_token))) for line in code]

        pretty = []
        level = 0
        for n, line in enumerate(code):
            if line in ('', '\n'):
                pretty.append(line)
                continue
            level -= decrease[n]
            pretty.append("%s%s" % (tab*level, line))
            level += increase[n]
        return pretty

    def _get_func_suffix(self, type_):
        return self.type_func_suffixes[self.type_aliases.get(type_, type_)]

    def _get_literal_suffix(self, type_):
        return self.type_literal_suffixes[self.type_aliases.get(type_, type_)]

    def _get_math_macro_suffix(self, type_):
        # An entry for ``type_`` itself takes precedence over its alias's entry.
        alias = self.type_aliases.get(type_, type_)
        dflt = self.type_math_macro_suffixes.get(alias, '')
        return self.type_math_macro_suffixes.get(type_, dflt)

    def _print_Tuple(self, expr):
        """Print a tuple/list as a C brace initializer."""
        return '{'+', '.join(self._print(e) for e in expr)+'}'

    _print_List = _print_Tuple

    def _print_Type(self, type_):
        """Print a type name, recording the headers and macros it requires."""
        self.headers.update(self.type_headers.get(type_, set()))
        self.macros.update(self.type_macros.get(type_, set()))
        return self._print(self.type_mappings.get(type_, type_.name))

    def _print_Declaration(self, decl):
        """Print a variable or pointer declaration, with optional initializer.

        Raises
        ======

        ValueError
            For untyped variables.
        NotImplementedError
            For variable kinds other than ``Variable``/``Pointer``.
        """
        from sympy.codegen.cnodes import restrict
        var = decl.variable
        val = var.value
        if var.type == untyped:
            raise ValueError("C does not support untyped variables")

        if isinstance(var, Pointer):
            result = '{vc}{t} *{pc} {r}{s}'.format(
                vc='const ' if value_const in var.attrs else '',
                t=self._print(var.type),
                pc=' const' if pointer_const in var.attrs else '',
                r='restrict ' if restrict in var.attrs else '',
                s=self._print(var.symbol)
            )
        elif isinstance(var, Variable):
            result = '{vc}{t} {s}'.format(
                vc='const ' if value_const in var.attrs else '',
                t=self._print(var.type),
                s=self._print(var.symbol)
            )
        else:
            raise NotImplementedError("Unknown type of var: %s" % type(var))
        if val != None: # Must be "!= None", cannot be "is not None"
            result += ' = %s' % self._print(val)
        return result

    def _print_Float(self, flt):
        """Print a float at the precision of ``real``'s alias, with suffix."""
        type_ = self.type_aliases.get(real, real)
        self.macros.update(self.type_macros.get(type_, set()))
        suffix = self._get_literal_suffix(type_)
        num = str(flt.evalf(type_.decimal_dig))
        if 'e' not in num and '.' not in num:
            num += '.0'
        # Strip trailing zeros of the mantissa but keep at least one digit
        # after the decimal point.
        num_parts = num.split('e')
        num_parts[0] = num_parts[0].rstrip('0')
        if num_parts[0].endswith('.'):
            num_parts[0] += '0'
        return 'e'.join(num_parts) + suffix

    @requires(headers={'stdbool.h'})
    def _print_BooleanTrue(self, expr):
        return 'true'

    @requires(headers={'stdbool.h'})
    def _print_BooleanFalse(self, expr):
        return 'false'

    def _print_Element(self, elem):
        """Print an array element, either multi-indexed or via strides/offset.

        Raises
        ======

        ValueError
            If an offset is given without strides.
        """
        if elem.strides == None: # Must be "== None", cannot be "is None"
            if elem.offset != None: # Must be "!= None", cannot be "is not None"
                raise ValueError("Expected strides when offset is given")
            idxs = ']['.join(map(self._print, elem.indices))
        else:
            global_idx = sum(i*s for i, s in zip(elem.indices, elem.strides))
            if elem.offset != None: # Must be "!= None", cannot be "is not None"
                global_idx += elem.offset
            idxs = self._print(global_idx)

        return "{symb}[{idxs}]".format(
            symb=self._print(elem.symbol),
            idxs=idxs
        )

    def _print_CodeBlock(self, expr):
        """ Elements of code blocks printed as statements. """
        return '\n'.join([self._get_statement(self._print(i)) for i in expr.args])

    def _print_While(self, expr):
        return 'while ({condition}) {{\n{body}\n}}'.format(**expr.kwargs(
            apply=self._print))

    def _print_Scope(self, expr):
        return '{\n%s\n}' % self._print_CodeBlock(expr.body)

    @requires(headers={'stdio.h'})
    def _print_Print(self, expr):
        return 'printf({fmt}, {pargs})'.format(
            fmt=self._print(expr.format_string),
            pargs=', '.join(map(self._print, expr.print_args))
        )

    def _print_FunctionPrototype(self, expr):
        """Print ``<return type> <name>(<parameter declarations>)``."""
        pars = ', '.join(self._print(Declaration(arg)) for arg in expr.parameters)
        return "%s %s(%s)" % (self._print(expr.return_type),
                              self._print(expr.name), pars)

    def _print_FunctionDefinition(self, expr):
        return "%s%s" % (self._print_FunctionPrototype(expr),
                         self._print_Scope(expr))

    def _print_Return(self, expr):
        arg, = expr.args
        return 'return %s' % self._print(arg)

    def _print_CommaOperator(self, expr):
        return '(%s)' % ', '.join(map(self._print, expr.args))

    def _print_Label(self, expr):
        """Print a label, wrapping multi-statement bodies in braces."""
        if expr.body == none:
            return '%s:' % str(expr.name)
        if len(expr.body.args) == 1:
            return '%s:\n%s' % (str(expr.name), self._print_CodeBlock(expr.body))
        return '%s:\n{\n%s\n}' % (str(expr.name), self._print_CodeBlock(expr.body))

    def _print_goto(self, expr):
        return 'goto %s' % expr.label.name

    def _print_PreIncrement(self, expr):
        arg, = expr.args
        return '++(%s)' % self._print(arg)

    def _print_PostIncrement(self, expr):
        arg, = expr.args
        return '(%s)++' % self._print(arg)

    def _print_PreDecrement(self, expr):
        arg, = expr.args
        return '--(%s)' % self._print(arg)

    def _print_PostDecrement(self, expr):
        arg, = expr.args
        return '(%s)--' % self._print(arg)

    def _print_struct(self, expr):
        """Print a struct/union definition; the keyword is the node's class name."""
        return "%(keyword)s %(name)s {\n%(lines)s}" % dict(
            keyword=expr.__class__.__name__, name=expr.name, lines=';\n'.join(
                [self._print(decl) for decl in expr.declarations] + [''])
        )

    def _print_BreakToken(self, _):
        return 'break'

    def _print_ContinueToken(self, _):
        return 'continue'

    _print_union = _print_struct
class C99CodePrinter(C89CodePrinter):
    """Printer for C99 code.

    On top of the C89 printer this adds the C99 keywords, the complex types
    (``float complex``/``double complex`` from complex.h), the
    ``INFINITY``/``NAN`` macros, the functions of ``known_functions_C99`` and
    precision-suffixed math.h calls (e.g. ``sinf``) for the names listed in
    ``_prec_funcs``.
    """
    standard = 'C99'
    reserved_words = set(reserved_words + reserved_words_c99)

    type_mappings = {
        **C89CodePrinter.type_mappings,
        complex64: 'float complex',
        complex128: 'double complex',
    }

    type_headers = {
        **C89CodePrinter.type_headers,
        complex64: {'complex.h'},
        complex128: {'complex.h'},
    }

    # known_functions-dict to copy
    _kf = known_functions_C99  # type: tDict[str, Any]

    # functions with versions with 'f' and 'l' suffixes:
    _prec_funcs = ('fabs fmod remainder remquo fma fmax fmin fdim nan exp exp2'
                   ' expm1 log log10 log2 log1p pow sqrt cbrt hypot sin cos tan'
                   ' asin acos atan atan2 sinh cosh tanh asinh acosh atanh erf'
                   ' erfc tgamma lgamma ceil floor trunc round nearbyint rint'
                   ' frexp ldexp modf scalbn ilogb logb nextafter copysign').split()

    def _print_Infinity(self, expr):
        return 'INFINITY'

    def _print_NegativeInfinity(self, expr):
        return '-INFINITY'

    def _print_NaN(self, expr):
        return 'NAN'

    # tgamma was already covered by 'known_functions' dict

    @requires(headers={'math.h'}, libraries={'m'})
    @_as_macro_if_defined
    def _print_math_func(self, expr, nest=False, known=None):
        """Print ``expr`` as a call to a math.h function.

        ``known`` defaults to the ``known_functions`` entry for the class name
        of ``expr``; a list of (condition, name) pairs is resolved to the first
        name whose condition holds for the arguments.  A callable ``known`` is
        invoked as ``known(self, *args)``.  With ``nest=True`` a variadic call
        is emitted as nested binary calls, e.g. ``fmax(a, fmax(b, c))``.

        Raises
        ======

        ValueError
            If no (condition, name) pair matches the arguments.
        """
        if known is None:
            known = self.known_functions[expr.__class__.__name__]
        if not isinstance(known, str):
            no_match = object()
            known = next((c_name for condition, c_name in known
                          if condition(*expr.args)), no_match)
            if known is no_match:
                raise ValueError("No matching printer")
        try:
            return known(self, *expr.args)
        except TypeError:
            # ``known`` is a plain function name rather than a callable.
            has_prec_variant = self._ns + known in self._prec_funcs
            suffix = self._get_func_suffix(real) if has_prec_variant else ''
            if nest:
                args = self._print(expr.args[0])
                remaining = expr.args[1:]
                if remaining:
                    middle, last = remaining[:-1], remaining[-1]
                    for inner_arg in middle:
                        args += ', {ns}{name}{suffix}({next}'.format(
                            ns=self._ns,
                            name=known,
                            suffix=suffix,
                            next=self._print(inner_arg)
                        )
                    # close one parenthesis per nested call opened above
                    args += ', %s%s' % (
                        self._print(expr.func(last)),
                        ')' * len(middle)
                    )
            else:
                args = ', '.join(self._print(arg) for arg in expr.args)
            return '{ns}{name}{suffix}({args})'.format(
                ns=self._ns,
                name=known,
                suffix=suffix,
                args=args
            )

    def _print_Max(self, expr):
        return self._print_math_func(expr, nest=True)

    def _print_Min(self, expr):
        return self._print_math_func(expr, nest=True)

    def _get_loop_opening_ending(self, indices):
        """Return the ``for`` loop opening lines and matching closing braces."""
        loopstart = "for (int %(var)s=%(start)s; %(var)s<%(end)s; %(var)s++){"  # C99
        # C arrays start at 0 and end at dimension-1
        open_lines = [
            loopstart % {
                'var': self._print(idx.label),
                'start': self._print(idx.lower),
                'end': self._print(idx.upper + 1),
            }
            for idx in indices
        ]
        close_lines = ["}" for _ in open_lines]
        return open_lines, close_lines
# Route each of these SymPy functions through C99CodePrinter._print_math_func,
# i.e. define C99CodePrinter._print_<name> for every name below.
for k in ('Abs', 'Sqrt', 'exp', 'exp2', 'expm1', 'log', 'log10', 'log2',
          'log1p', 'Cbrt', 'hypot', 'fma', 'loggamma', 'sin', 'cos', 'tan',
          'asin', 'acos', 'atan', 'atan2', 'sinh', 'cosh', 'tanh', 'asinh',
          'acosh', 'atanh', 'erf', 'erfc', 'gamma', 'ceiling', 'floor'):
    setattr(C99CodePrinter, '_print_%s' % k, C99CodePrinter._print_math_func)
class C11CodePrinter(C99CodePrinter):
    """Printer for C11 code: the C99 printer plus ``alignof`` support."""

    # Without this the class would inherit standard = 'C99' and misreport
    # which C standard it targets.
    standard = 'C11'

    @requires(headers={'stdalign.h'})
    def _print_alignof(self, expr):
        """Print ``alignof(<arg>)``; ``expr`` must have exactly one argument."""
        arg, = expr.args
        return 'alignof(%s)' % self._print(arg)
# Lower-case C standard name -> printer class that emits code for it.
c_code_printers = dict(
    c89=C89CodePrinter,
    c99=C99CodePrinter,
    c11=C11CodePrinter,
)
|