rcode.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410
"""
R code printer

The RCodePrinter converts single SymPy expressions into single R expressions,
using R's built-in mathematical functions where possible.
"""
  6. from typing import Any, Dict as tDict
  7. from sympy.printing.codeprinter import CodePrinter
  8. from sympy.printing.precedence import precedence, PRECEDENCE
  9. from sympy.sets.fancysets import Range
  10. # dictionary mapping SymPy function to (argument_conditions, C_function).
  11. # Used in RCodePrinter._print_Function(self)
  12. known_functions = {
  13. #"Abs": [(lambda x: not x.is_integer, "fabs")],
  14. "Abs": "abs",
  15. "sin": "sin",
  16. "cos": "cos",
  17. "tan": "tan",
  18. "asin": "asin",
  19. "acos": "acos",
  20. "atan": "atan",
  21. "atan2": "atan2",
  22. "exp": "exp",
  23. "log": "log",
  24. "erf": "erf",
  25. "sinh": "sinh",
  26. "cosh": "cosh",
  27. "tanh": "tanh",
  28. "asinh": "asinh",
  29. "acosh": "acosh",
  30. "atanh": "atanh",
  31. "floor": "floor",
  32. "ceiling": "ceiling",
  33. "sign": "sign",
  34. "Max": "max",
  35. "Min": "min",
  36. "factorial": "factorial",
  37. "gamma": "gamma",
  38. "digamma": "digamma",
  39. "trigamma": "trigamma",
  40. "beta": "beta",
  41. "sqrt": "sqrt", # To enable automatic rewrite
  42. }
  43. # These are the core reserved words in the R language. Taken from:
  44. # https://cran.r-project.org/doc/manuals/r-release/R-lang.html#Reserved-words
  45. reserved_words = ['if',
  46. 'else',
  47. 'repeat',
  48. 'while',
  49. 'function',
  50. 'for',
  51. 'in',
  52. 'next',
  53. 'break',
  54. 'TRUE',
  55. 'FALSE',
  56. 'NULL',
  57. 'Inf',
  58. 'NaN',
  59. 'NA',
  60. 'NA_integer_',
  61. 'NA_real_',
  62. 'NA_complex_',
  63. 'NA_character_',
  64. 'volatile']
  65. class RCodePrinter(CodePrinter):
  66. """A printer to convert SymPy expressions to strings of R code"""
  67. printmethod = "_rcode"
  68. language = "R"
  69. _default_settings = {
  70. 'order': None,
  71. 'full_prec': 'auto',
  72. 'precision': 15,
  73. 'user_functions': {},
  74. 'human': True,
  75. 'contract': True,
  76. 'dereference': set(),
  77. 'error_on_reserved': False,
  78. 'reserved_word_suffix': '_',
  79. } # type: tDict[str, Any]
  80. _operators = {
  81. 'and': '&',
  82. 'or': '|',
  83. 'not': '!',
  84. }
  85. _relationals = {
  86. } # type: tDict[str, str]
  87. def __init__(self, settings={}):
  88. CodePrinter.__init__(self, settings)
  89. self.known_functions = dict(known_functions)
  90. userfuncs = settings.get('user_functions', {})
  91. self.known_functions.update(userfuncs)
  92. self._dereference = set(settings.get('dereference', []))
  93. self.reserved_words = set(reserved_words)
  94. def _rate_index_position(self, p):
  95. return p*5
  96. def _get_statement(self, codestring):
  97. return "%s;" % codestring
  98. def _get_comment(self, text):
  99. return "// {}".format(text)
  100. def _declare_number_const(self, name, value):
  101. return "{} = {};".format(name, value)
  102. def _format_code(self, lines):
  103. return self.indent_code(lines)
  104. def _traverse_matrix_indices(self, mat):
  105. rows, cols = mat.shape
  106. return ((i, j) for i in range(rows) for j in range(cols))
  107. def _get_loop_opening_ending(self, indices):
  108. """Returns a tuple (open_lines, close_lines) containing lists of codelines
  109. """
  110. open_lines = []
  111. close_lines = []
  112. loopstart = "for (%(var)s in %(start)s:%(end)s){"
  113. for i in indices:
  114. # R arrays start at 1 and end at dimension
  115. open_lines.append(loopstart % {
  116. 'var': self._print(i.label),
  117. 'start': self._print(i.lower+1),
  118. 'end': self._print(i.upper + 1)})
  119. close_lines.append("}")
  120. return open_lines, close_lines
  121. def _print_Pow(self, expr):
  122. if "Pow" in self.known_functions:
  123. return self._print_Function(expr)
  124. PREC = precedence(expr)
  125. if expr.exp == -1:
  126. return '1.0/%s' % (self.parenthesize(expr.base, PREC))
  127. elif expr.exp == 0.5:
  128. return 'sqrt(%s)' % self._print(expr.base)
  129. else:
  130. return '%s^%s' % (self.parenthesize(expr.base, PREC),
  131. self.parenthesize(expr.exp, PREC))
  132. def _print_Rational(self, expr):
  133. p, q = int(expr.p), int(expr.q)
  134. return '%d.0/%d.0' % (p, q)
  135. def _print_Indexed(self, expr):
  136. inds = [ self._print(i) for i in expr.indices ]
  137. return "%s[%s]" % (self._print(expr.base.label), ", ".join(inds))
  138. def _print_Idx(self, expr):
  139. return self._print(expr.label)
  140. def _print_Exp1(self, expr):
  141. return "exp(1)"
  142. def _print_Pi(self, expr):
  143. return 'pi'
  144. def _print_Infinity(self, expr):
  145. return 'Inf'
  146. def _print_NegativeInfinity(self, expr):
  147. return '-Inf'
  148. def _print_Assignment(self, expr):
  149. from sympy.codegen.ast import Assignment
  150. from sympy.matrices.expressions.matexpr import MatrixSymbol
  151. from sympy.tensor.indexed import IndexedBase
  152. lhs = expr.lhs
  153. rhs = expr.rhs
  154. # We special case assignments that take multiple lines
  155. #if isinstance(expr.rhs, Piecewise):
  156. # from sympy.functions.elementary.piecewise import Piecewise
  157. # # Here we modify Piecewise so each expression is now
  158. # # an Assignment, and then continue on the print.
  159. # expressions = []
  160. # conditions = []
  161. # for (e, c) in rhs.args:
  162. # expressions.append(Assignment(lhs, e))
  163. # conditions.append(c)
  164. # temp = Piecewise(*zip(expressions, conditions))
  165. # return self._print(temp)
  166. #elif isinstance(lhs, MatrixSymbol):
  167. if isinstance(lhs, MatrixSymbol):
  168. # Here we form an Assignment for each element in the array,
  169. # printing each one.
  170. lines = []
  171. for (i, j) in self._traverse_matrix_indices(lhs):
  172. temp = Assignment(lhs[i, j], rhs[i, j])
  173. code0 = self._print(temp)
  174. lines.append(code0)
  175. return "\n".join(lines)
  176. elif self._settings["contract"] and (lhs.has(IndexedBase) or
  177. rhs.has(IndexedBase)):
  178. # Here we check if there is looping to be done, and if so
  179. # print the required loops.
  180. return self._doprint_loops(rhs, lhs)
  181. else:
  182. lhs_code = self._print(lhs)
  183. rhs_code = self._print(rhs)
  184. return self._get_statement("%s = %s" % (lhs_code, rhs_code))
  185. def _print_Piecewise(self, expr):
  186. # This method is called only for inline if constructs
  187. # Top level piecewise is handled in doprint()
  188. if expr.args[-1].cond == True:
  189. last_line = "%s" % self._print(expr.args[-1].expr)
  190. else:
  191. last_line = "ifelse(%s,%s,NA)" % (self._print(expr.args[-1].cond), self._print(expr.args[-1].expr))
  192. code=last_line
  193. for e, c in reversed(expr.args[:-1]):
  194. code= "ifelse(%s,%s," % (self._print(c), self._print(e))+code+")"
  195. return(code)
  196. def _print_ITE(self, expr):
  197. from sympy.functions import Piecewise
  198. return self._print(expr.rewrite(Piecewise))
  199. def _print_MatrixElement(self, expr):
  200. return "{}[{}]".format(self.parenthesize(expr.parent, PRECEDENCE["Atom"],
  201. strict=True), expr.j + expr.i*expr.parent.shape[1])
  202. def _print_Symbol(self, expr):
  203. name = super()._print_Symbol(expr)
  204. if expr in self._dereference:
  205. return '(*{})'.format(name)
  206. else:
  207. return name
  208. def _print_Relational(self, expr):
  209. lhs_code = self._print(expr.lhs)
  210. rhs_code = self._print(expr.rhs)
  211. op = expr.rel_op
  212. return "{} {} {}".format(lhs_code, op, rhs_code)
  213. def _print_AugmentedAssignment(self, expr):
  214. lhs_code = self._print(expr.lhs)
  215. op = expr.op
  216. rhs_code = self._print(expr.rhs)
  217. return "{} {} {};".format(lhs_code, op, rhs_code)
  218. def _print_For(self, expr):
  219. target = self._print(expr.target)
  220. if isinstance(expr.iterable, Range):
  221. start, stop, step = expr.iterable.args
  222. else:
  223. raise NotImplementedError("Only iterable currently supported is Range")
  224. body = self._print(expr.body)
  225. return ('for ({target} = {start}; {target} < {stop}; {target} += '
  226. '{step}) {{\n{body}\n}}').format(target=target, start=start,
  227. stop=stop, step=step, body=body)
  228. def indent_code(self, code):
  229. """Accepts a string of code or a list of code lines"""
  230. if isinstance(code, str):
  231. code_lines = self.indent_code(code.splitlines(True))
  232. return ''.join(code_lines)
  233. tab = " "
  234. inc_token = ('{', '(', '{\n', '(\n')
  235. dec_token = ('}', ')')
  236. code = [ line.lstrip(' \t') for line in code ]
  237. increase = [ int(any(map(line.endswith, inc_token))) for line in code ]
  238. decrease = [ int(any(map(line.startswith, dec_token)))
  239. for line in code ]
  240. pretty = []
  241. level = 0
  242. for n, line in enumerate(code):
  243. if line in ('', '\n'):
  244. pretty.append(line)
  245. continue
  246. level -= decrease[n]
  247. pretty.append("%s%s" % (tab*level, line))
  248. level += increase[n]
  249. return pretty
  250. def rcode(expr, assign_to=None, **settings):
  251. """Converts an expr to a string of r code
  252. Parameters
  253. ==========
  254. expr : Expr
  255. A SymPy expression to be converted.
  256. assign_to : optional
  257. When given, the argument is used as the name of the variable to which
  258. the expression is assigned. Can be a string, ``Symbol``,
  259. ``MatrixSymbol``, or ``Indexed`` type. This is helpful in case of
  260. line-wrapping, or for expressions that generate multi-line statements.
  261. precision : integer, optional
  262. The precision for numbers such as pi [default=15].
  263. user_functions : dict, optional
  264. A dictionary where the keys are string representations of either
  265. ``FunctionClass`` or ``UndefinedFunction`` instances and the values
  266. are their desired R string representations. Alternatively, the
  267. dictionary value can be a list of tuples i.e. [(argument_test,
  268. rfunction_string)] or [(argument_test, rfunction_formater)]. See below
  269. for examples.
  270. human : bool, optional
  271. If True, the result is a single string that may contain some constant
  272. declarations for the number symbols. If False, the same information is
  273. returned in a tuple of (symbols_to_declare, not_supported_functions,
  274. code_text). [default=True].
  275. contract: bool, optional
  276. If True, ``Indexed`` instances are assumed to obey tensor contraction
  277. rules and the corresponding nested loops over indices are generated.
  278. Setting contract=False will not generate loops, instead the user is
  279. responsible to provide values for the indices in the code.
  280. [default=True].
  281. Examples
  282. ========
  283. >>> from sympy import rcode, symbols, Rational, sin, ceiling, Abs, Function
  284. >>> x, tau = symbols("x, tau")
  285. >>> rcode((2*tau)**Rational(7, 2))
  286. '8*sqrt(2)*tau^(7.0/2.0)'
  287. >>> rcode(sin(x), assign_to="s")
  288. 's = sin(x);'
  289. Simple custom printing can be defined for certain types by passing a
  290. dictionary of {"type" : "function"} to the ``user_functions`` kwarg.
  291. Alternatively, the dictionary value can be a list of tuples i.e.
  292. [(argument_test, cfunction_string)].
  293. >>> custom_functions = {
  294. ... "ceiling": "CEIL",
  295. ... "Abs": [(lambda x: not x.is_integer, "fabs"),
  296. ... (lambda x: x.is_integer, "ABS")],
  297. ... "func": "f"
  298. ... }
  299. >>> func = Function('func')
  300. >>> rcode(func(Abs(x) + ceiling(x)), user_functions=custom_functions)
  301. 'f(fabs(x) + CEIL(x))'
  302. or if the R-function takes a subset of the original arguments:
  303. >>> rcode(2**x + 3**x, user_functions={'Pow': [
  304. ... (lambda b, e: b == 2, lambda b, e: 'exp2(%s)' % e),
  305. ... (lambda b, e: b != 2, 'pow')]})
  306. 'exp2(x) + pow(3, x)'
  307. ``Piecewise`` expressions are converted into conditionals. If an
  308. ``assign_to`` variable is provided an if statement is created, otherwise
  309. the ternary operator is used. Note that if the ``Piecewise`` lacks a
  310. default term, represented by ``(expr, True)`` then an error will be thrown.
  311. This is to prevent generating an expression that may not evaluate to
  312. anything.
  313. >>> from sympy import Piecewise
  314. >>> expr = Piecewise((x + 1, x > 0), (x, True))
  315. >>> print(rcode(expr, assign_to=tau))
  316. tau = ifelse(x > 0,x + 1,x);
  317. Support for loops is provided through ``Indexed`` types. With
  318. ``contract=True`` these expressions will be turned into loops, whereas
  319. ``contract=False`` will just print the assignment expression that should be
  320. looped over:
  321. >>> from sympy import Eq, IndexedBase, Idx
  322. >>> len_y = 5
  323. >>> y = IndexedBase('y', shape=(len_y,))
  324. >>> t = IndexedBase('t', shape=(len_y,))
  325. >>> Dy = IndexedBase('Dy', shape=(len_y-1,))
  326. >>> i = Idx('i', len_y-1)
  327. >>> e=Eq(Dy[i], (y[i+1]-y[i])/(t[i+1]-t[i]))
  328. >>> rcode(e.rhs, assign_to=e.lhs, contract=False)
  329. 'Dy[i] = (y[i + 1] - y[i])/(t[i + 1] - t[i]);'
  330. Matrices are also supported, but a ``MatrixSymbol`` of the same dimensions
  331. must be provided to ``assign_to``. Note that any expression that can be
  332. generated normally can also exist inside a Matrix:
  333. >>> from sympy import Matrix, MatrixSymbol
  334. >>> mat = Matrix([x**2, Piecewise((x + 1, x > 0), (x, True)), sin(x)])
  335. >>> A = MatrixSymbol('A', 3, 1)
  336. >>> print(rcode(mat, A))
  337. A[0] = x^2;
  338. A[1] = ifelse(x > 0,x + 1,x);
  339. A[2] = sin(x);
  340. """
  341. return RCodePrinter(settings).doprint(expr, assign_to)
  342. def print_rcode(expr, **settings):
  343. """Prints R representation of the given expression."""
  344. print(rcode(expr, **settings))