codegen.py 80 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795179617971798179918001801180218031804180518061807180818091810181118121813181418151816181718181819182018211822182318241825182618271828182918301831183218331834183518361837183818391840184118421843184418451846184718481849185018511852185318541855185618571858185918601861186218631864186518661867186818691870187118721873187418751876187718781879188018811882188318841885188618871888188918901891189218931894189518961897189818991900190119021903190419051906190719081909191019111912191319141915191619171918191919201921192219231924192519261927192819291930193119321933193419351936193719381939194019411942194319441945194619471948194919501951195219531954195519561957195819591960196119621963196419651966196719681969197019711972197319741975197619771978197919801981198219831984198519861987198819891990199119921993199419951996199719981999200020012002200320042005200620072008200920102011201220132014201520162017201820192020202120222023202420252026202720282029203020312032203320342035203620372038203920402041204220432044204520462047204820492050205120522053205420552056205720582059206020612062206320642065206620672068206920702071207220732074207520762077207820792080208120822083208420852086208720882089209020912092209320942095209620972098209921002101210221032104210521062107210821092110211121122113211421152116211721182119212021212122212321242125212621272128212921302131213221332134213521362137213821392140214121422143214421452146214721482149215021512152215321542155215621572158215921602161216221632164216521662167216821692170217121722173217421752176217721782179218021812182218321842185218621872188218921902191219221932194219521962197219821992200220122022203220422052206220722082209221022112212221322142215221622172218221922202221222222232224222522262227222822292230223122322233223422352236
  1. """
  2. Module for generating C, C++, Fortran77, Fortran90, Julia, Rust
  3. and Octave/Matlab routines that evaluate SymPy expressions.
  4. This module is work in progress.
  5. Only the milestones with a '+' character in the list below have been completed.
  6. --- How is sympy.utilities.codegen different from sympy.printing.ccode? ---
  7. We considered the idea to extend the printing routines for SymPy functions in
  8. such a way that it prints complete compilable code, but this leads to a few
  9. insurmountable issues that can only be tackled with a dedicated code generator:
  10. - For C, one needs both a code and a header file, while the printing routines
  11. generate just one string. This code generator can be extended to support
  12. .pyf files for f2py.
  13. - SymPy functions are not concerned with programming-technical issues, such
  14. as input, output and input-output arguments. Other examples are contiguous
  15. or non-contiguous arrays, including headers of other libraries such as gsl
  16. or others.
  17. - It is highly interesting to evaluate several SymPy functions in one C
  18. routine, possibly sharing common intermediate results with the help
  19. of the cse routine. This is more than just printing.
  20. - From the programming perspective, expressions with constants should be
  21. evaluated in the code generator as much as possible. This is different
  22. from printing.
  23. --- Basic assumptions ---
  24. * A generic Routine data structure describes the routine that must be
  25. translated into C/Fortran/... code. This data structure covers all
  26. features present in one or more of the supported languages.
  27. * Descendants from the CodeGen class transform multiple Routine instances
  28. into compilable code. Each derived class translates into a specific
  29. language.
  30. * In many cases, one wants a simple workflow. The friendly functions in the
  31. last part are a simple api on top of the Routine/CodeGen stuff. They are
  32. easier to use, but are less powerful.
  33. --- Milestones ---
  34. + First working version with scalar input arguments, generating C code,
  35. tests
  36. + Friendly functions that are easier to use than the rigorous
  37. Routine/CodeGen workflow.
  38. + Integer and Real numbers as input and output
  39. + Output arguments
  40. + InputOutput arguments
  41. + Sort input/output arguments properly
  42. + Contiguous array arguments (numpy matrices)
  43. + Also generate .pyf code for f2py (in autowrap module)
  44. + Isolate constants and evaluate them beforehand in double precision
  45. + Fortran 90
  46. + Octave/Matlab
  47. - Common Subexpression Elimination
  48. - User defined comments in the generated code
  49. - Optional extra include lines for libraries/objects that can eval special
  50. functions
  51. - Test other C compilers and libraries: gcc, tcc, libtcc, gcc+gsl, ...
  52. - Contiguous array arguments (SymPy matrices)
  53. - Non-contiguous array arguments (SymPy matrices)
  54. - ccode must raise an error when it encounters something that cannot be
  55. translated into C. ccode(integrate(sin(x)/x, x)) does not make sense.
  56. - Complex numbers as input and output
  57. - A default complex datatype
  58. - Include extra information in the header: date, user, hostname, sha1
  59. hash, ...
  60. - Fortran 77
  61. - C++
  62. - Python
  63. - Julia
  64. - Rust
  65. - ...
  66. """
  67. import os
  68. import textwrap
  69. from io import StringIO
  70. from sympy import __version__ as sympy_version
  71. from sympy.core import Symbol, S, Tuple, Equality, Function, Basic
  72. from sympy.printing.c import c_code_printers
  73. from sympy.printing.codeprinter import AssignmentError
  74. from sympy.printing.fortran import FCodePrinter
  75. from sympy.printing.julia import JuliaCodePrinter
  76. from sympy.printing.octave import OctaveCodePrinter
  77. from sympy.printing.rust import RustCodePrinter
  78. from sympy.tensor import Idx, Indexed, IndexedBase
  79. from sympy.matrices import (MatrixSymbol, ImmutableMatrix, MatrixBase,
  80. MatrixExpr, MatrixSlice)
  81. from sympy.utilities.iterables import is_sequence
  82. __all__ = [
  83. # description of routines
  84. "Routine", "DataType", "default_datatypes", "get_default_datatype",
  85. "Argument", "InputArgument", "OutputArgument", "Result",
  86. # routines -> code
  87. "CodeGen", "CCodeGen", "FCodeGen", "JuliaCodeGen", "OctaveCodeGen",
  88. "RustCodeGen",
  89. # friendly functions
  90. "codegen", "make_routine",
  91. ]
  92. #
  93. # Description of routines
  94. #
  95. class Routine:
  96. """Generic description of evaluation routine for set of expressions.
  97. A CodeGen class can translate instances of this class into code in a
  98. particular language. The routine specification covers all the features
  99. present in these languages. The CodeGen part must raise an exception
  100. when certain features are not present in the target language. For
  101. example, multiple return values are possible in Python, but not in C or
  102. Fortran. Another example: Fortran and Python support complex numbers,
  103. while C does not.
  104. """
  105. def __init__(self, name, arguments, results, local_vars, global_vars):
  106. """Initialize a Routine instance.
  107. Parameters
  108. ==========
  109. name : string
  110. Name of the routine.
  111. arguments : list of Arguments
  112. These are things that appear in arguments of a routine, often
  113. appearing on the right-hand side of a function call. These are
  114. commonly InputArguments but in some languages, they can also be
  115. OutputArguments or InOutArguments (e.g., pass-by-reference in C
  116. code).
  117. results : list of Results
  118. These are the return values of the routine, often appearing on
  119. the left-hand side of a function call. The difference between
  120. Results and OutputArguments and when you should use each is
  121. language-specific.
  122. local_vars : list of Results
  123. These are variables that will be defined at the beginning of the
  124. function.
  125. global_vars : list of Symbols
  126. Variables which will not be passed into the function.
  127. """
  128. # extract all input symbols and all symbols appearing in an expression
  129. input_symbols = set()
  130. symbols = set()
  131. for arg in arguments:
  132. if isinstance(arg, OutputArgument):
  133. symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
  134. elif isinstance(arg, InputArgument):
  135. input_symbols.add(arg.name)
  136. elif isinstance(arg, InOutArgument):
  137. input_symbols.add(arg.name)
  138. symbols.update(arg.expr.free_symbols - arg.expr.atoms(Indexed))
  139. else:
  140. raise ValueError("Unknown Routine argument: %s" % arg)
  141. for r in results:
  142. if not isinstance(r, Result):
  143. raise ValueError("Unknown Routine result: %s" % r)
  144. symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
  145. local_symbols = set()
  146. for r in local_vars:
  147. if isinstance(r, Result):
  148. symbols.update(r.expr.free_symbols - r.expr.atoms(Indexed))
  149. local_symbols.add(r.name)
  150. else:
  151. local_symbols.add(r)
  152. symbols = {s.label if isinstance(s, Idx) else s for s in symbols}
  153. # Check that all symbols in the expressions are covered by
  154. # InputArguments/InOutArguments---subset because user could
  155. # specify additional (unused) InputArguments or local_vars.
  156. notcovered = symbols.difference(
  157. input_symbols.union(local_symbols).union(global_vars))
  158. if notcovered != set():
  159. raise ValueError("Symbols needed for output are not in input " +
  160. ", ".join([str(x) for x in notcovered]))
  161. self.name = name
  162. self.arguments = arguments
  163. self.results = results
  164. self.local_vars = local_vars
  165. self.global_vars = global_vars
  166. def __str__(self):
  167. return self.__class__.__name__ + "({name!r}, {arguments}, {results}, {local_vars}, {global_vars})".format(**self.__dict__)
  168. __repr__ = __str__
  169. @property
  170. def variables(self):
  171. """Returns a set of all variables possibly used in the routine.
  172. For routines with unnamed return values, the dummies that may or
  173. may not be used will be included in the set.
  174. """
  175. v = set(self.local_vars)
  176. for arg in self.arguments:
  177. v.add(arg.name)
  178. for res in self.results:
  179. v.add(res.result_var)
  180. return v
  181. @property
  182. def result_variables(self):
  183. """Returns a list of OutputArgument, InOutArgument and Result.
  184. If return values are present, they are at the end ot the list.
  185. """
  186. args = [arg for arg in self.arguments if isinstance(
  187. arg, (OutputArgument, InOutArgument))]
  188. args.extend(self.results)
  189. return args
  190. class DataType:
  191. """Holds strings for a certain datatype in different languages."""
  192. def __init__(self, cname, fname, pyname, jlname, octname, rsname):
  193. self.cname = cname
  194. self.fname = fname
  195. self.pyname = pyname
  196. self.jlname = jlname
  197. self.octname = octname
  198. self.rsname = rsname
  199. default_datatypes = {
  200. "int": DataType("int", "INTEGER*4", "int", "", "", "i32"),
  201. "float": DataType("double", "REAL*8", "float", "", "", "f64"),
  202. "complex": DataType("double", "COMPLEX*16", "complex", "", "", "float") #FIXME:
  203. # complex is only supported in fortran, python, julia, and octave.
  204. # So to not break c or rust code generation, we stick with double or
  205. # float, respecitvely (but actually should raise an exception for
  206. # explicitly complex variables (x.is_complex==True))
  207. }
  208. COMPLEX_ALLOWED = False
  209. def get_default_datatype(expr, complex_allowed=None):
  210. """Derives an appropriate datatype based on the expression."""
  211. if complex_allowed is None:
  212. complex_allowed = COMPLEX_ALLOWED
  213. if complex_allowed:
  214. final_dtype = "complex"
  215. else:
  216. final_dtype = "float"
  217. if expr.is_integer:
  218. return default_datatypes["int"]
  219. elif expr.is_real:
  220. return default_datatypes["float"]
  221. elif isinstance(expr, MatrixBase):
  222. #check all entries
  223. dt = "int"
  224. for element in expr:
  225. if dt == "int" and not element.is_integer:
  226. dt = "float"
  227. if dt == "float" and not element.is_real:
  228. return default_datatypes[final_dtype]
  229. return default_datatypes[dt]
  230. else:
  231. return default_datatypes[final_dtype]
  232. class Variable:
  233. """Represents a typed variable."""
  234. def __init__(self, name, datatype=None, dimensions=None, precision=None):
  235. """Return a new variable.
  236. Parameters
  237. ==========
  238. name : Symbol or MatrixSymbol
  239. datatype : optional
  240. When not given, the data type will be guessed based on the
  241. assumptions on the symbol argument.
  242. dimension : sequence containing tupes, optional
  243. If present, the argument is interpreted as an array, where this
  244. sequence of tuples specifies (lower, upper) bounds for each
  245. index of the array.
  246. precision : int, optional
  247. Controls the precision of floating point constants.
  248. """
  249. if not isinstance(name, (Symbol, MatrixSymbol)):
  250. raise TypeError("The first argument must be a SymPy symbol.")
  251. if datatype is None:
  252. datatype = get_default_datatype(name)
  253. elif not isinstance(datatype, DataType):
  254. raise TypeError("The (optional) `datatype' argument must be an "
  255. "instance of the DataType class.")
  256. if dimensions and not isinstance(dimensions, (tuple, list)):
  257. raise TypeError(
  258. "The dimension argument must be a sequence of tuples")
  259. self._name = name
  260. self._datatype = {
  261. 'C': datatype.cname,
  262. 'FORTRAN': datatype.fname,
  263. 'JULIA': datatype.jlname,
  264. 'OCTAVE': datatype.octname,
  265. 'PYTHON': datatype.pyname,
  266. 'RUST': datatype.rsname,
  267. }
  268. self.dimensions = dimensions
  269. self.precision = precision
  270. def __str__(self):
  271. return "%s(%r)" % (self.__class__.__name__, self.name)
  272. __repr__ = __str__
  273. @property
  274. def name(self):
  275. return self._name
  276. def get_datatype(self, language):
  277. """Returns the datatype string for the requested language.
  278. Examples
  279. ========
  280. >>> from sympy import Symbol
  281. >>> from sympy.utilities.codegen import Variable
  282. >>> x = Variable(Symbol('x'))
  283. >>> x.get_datatype('c')
  284. 'double'
  285. >>> x.get_datatype('fortran')
  286. 'REAL*8'
  287. """
  288. try:
  289. return self._datatype[language.upper()]
  290. except KeyError:
  291. raise CodeGenError("Has datatypes for languages: %s" %
  292. ", ".join(self._datatype))
  293. class Argument(Variable):
  294. """An abstract Argument data structure: a name and a data type.
  295. This structure is refined in the descendants below.
  296. """
  297. pass
  298. class InputArgument(Argument):
  299. pass
  300. class ResultBase:
  301. """Base class for all "outgoing" information from a routine.
  302. Objects of this class stores a SymPy expression, and a SymPy object
  303. representing a result variable that will be used in the generated code
  304. only if necessary.
  305. """
  306. def __init__(self, expr, result_var):
  307. self.expr = expr
  308. self.result_var = result_var
  309. def __str__(self):
  310. return "%s(%r, %r)" % (self.__class__.__name__, self.expr,
  311. self.result_var)
  312. __repr__ = __str__
  313. class OutputArgument(Argument, ResultBase):
  314. """OutputArgument are always initialized in the routine."""
  315. def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
  316. """Return a new variable.
  317. Parameters
  318. ==========
  319. name : Symbol, MatrixSymbol
  320. The name of this variable. When used for code generation, this
  321. might appear, for example, in the prototype of function in the
  322. argument list.
  323. result_var : Symbol, Indexed
  324. Something that can be used to assign a value to this variable.
  325. Typically the same as `name` but for Indexed this should be e.g.,
  326. "y[i]" whereas `name` should be the Symbol "y".
  327. expr : object
  328. The expression that should be output, typically a SymPy
  329. expression.
  330. datatype : optional
  331. When not given, the data type will be guessed based on the
  332. assumptions on the symbol argument.
  333. dimension : sequence containing tupes, optional
  334. If present, the argument is interpreted as an array, where this
  335. sequence of tuples specifies (lower, upper) bounds for each
  336. index of the array.
  337. precision : int, optional
  338. Controls the precision of floating point constants.
  339. """
  340. Argument.__init__(self, name, datatype, dimensions, precision)
  341. ResultBase.__init__(self, expr, result_var)
  342. def __str__(self):
  343. return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.result_var, self.expr)
  344. __repr__ = __str__
  345. class InOutArgument(Argument, ResultBase):
  346. """InOutArgument are never initialized in the routine."""
  347. def __init__(self, name, result_var, expr, datatype=None, dimensions=None, precision=None):
  348. if not datatype:
  349. datatype = get_default_datatype(expr)
  350. Argument.__init__(self, name, datatype, dimensions, precision)
  351. ResultBase.__init__(self, expr, result_var)
  352. __init__.__doc__ = OutputArgument.__init__.__doc__
  353. def __str__(self):
  354. return "%s(%r, %r, %r)" % (self.__class__.__name__, self.name, self.expr,
  355. self.result_var)
  356. __repr__ = __str__
  357. class Result(Variable, ResultBase):
  358. """An expression for a return value.
  359. The name result is used to avoid conflicts with the reserved word
  360. "return" in the Python language. It is also shorter than ReturnValue.
  361. These may or may not need a name in the destination (e.g., "return(x*y)"
  362. might return a value without ever naming it).
  363. """
  364. def __init__(self, expr, name=None, result_var=None, datatype=None,
  365. dimensions=None, precision=None):
  366. """Initialize a return value.
  367. Parameters
  368. ==========
  369. expr : SymPy expression
  370. name : Symbol, MatrixSymbol, optional
  371. The name of this return variable. When used for code generation,
  372. this might appear, for example, in the prototype of function in a
  373. list of return values. A dummy name is generated if omitted.
  374. result_var : Symbol, Indexed, optional
  375. Something that can be used to assign a value to this variable.
  376. Typically the same as `name` but for Indexed this should be e.g.,
  377. "y[i]" whereas `name` should be the Symbol "y". Defaults to
  378. `name` if omitted.
  379. datatype : optional
  380. When not given, the data type will be guessed based on the
  381. assumptions on the expr argument.
  382. dimension : sequence containing tupes, optional
  383. If present, this variable is interpreted as an array,
  384. where this sequence of tuples specifies (lower, upper)
  385. bounds for each index of the array.
  386. precision : int, optional
  387. Controls the precision of floating point constants.
  388. """
  389. # Basic because it is the base class for all types of expressions
  390. if not isinstance(expr, (Basic, MatrixBase)):
  391. raise TypeError("The first argument must be a SymPy expression.")
  392. if name is None:
  393. name = 'result_%d' % abs(hash(expr))
  394. if datatype is None:
  395. #try to infer data type from the expression
  396. datatype = get_default_datatype(expr)
  397. if isinstance(name, str):
  398. if isinstance(expr, (MatrixBase, MatrixExpr)):
  399. name = MatrixSymbol(name, *expr.shape)
  400. else:
  401. name = Symbol(name)
  402. if result_var is None:
  403. result_var = name
  404. Variable.__init__(self, name, datatype=datatype,
  405. dimensions=dimensions, precision=precision)
  406. ResultBase.__init__(self, expr, result_var)
  407. def __str__(self):
  408. return "%s(%r, %r, %r)" % (self.__class__.__name__, self.expr, self.name,
  409. self.result_var)
  410. __repr__ = __str__
  411. #
  412. # Transformation of routine objects into code
  413. #
  414. class CodeGen:
  415. """Abstract class for the code generators."""
  416. printer = None # will be set to an instance of a CodePrinter subclass
  417. def _indent_code(self, codelines):
  418. return self.printer.indent_code(codelines)
  419. def _printer_method_with_settings(self, method, settings=None, *args, **kwargs):
  420. settings = settings or {}
  421. ori = {k: self.printer._settings[k] for k in settings}
  422. for k, v in settings.items():
  423. self.printer._settings[k] = v
  424. result = getattr(self.printer, method)(*args, **kwargs)
  425. for k, v in ori.items():
  426. self.printer._settings[k] = v
  427. return result
  428. def _get_symbol(self, s):
  429. """Returns the symbol as fcode prints it."""
  430. if self.printer._settings['human']:
  431. expr_str = self.printer.doprint(s)
  432. else:
  433. constants, not_supported, expr_str = self.printer.doprint(s)
  434. if constants or not_supported:
  435. raise ValueError("Failed to print %s" % str(s))
  436. return expr_str.strip()
  437. def __init__(self, project="project", cse=False):
  438. """Initialize a code generator.
  439. Derived classes will offer more options that affect the generated
  440. code.
  441. """
  442. self.project = project
  443. self.cse = cse
    def routine(self, name, expr, argument_sequence=None, global_vars=None):
        """Create a Routine object that is appropriate for this language.

        This implementation is appropriate for at least C/Fortran. Subclasses
        can override this if necessary.

        Here, we assume at most one return value (the l-value) which must be
        scalar. Additional outputs are OutputArguments (e.g., pointers on
        right-hand-side or pass-by-reference). Matrices are always returned
        via OutputArguments. If ``argument_sequence`` is None, arguments will
        be ordered alphabetically, but with all InputArguments first, and then
        OutputArgument and InOutArguments.

        Parameters
        ==========

        name : str
            Name of the routine.

        expr : expression, Equality, matrix, or a sequence of these
            An Equality assigns its right-hand side to the left-hand side
            (an Indexed, Symbol or MatrixSymbol), which becomes an output or
            in-out argument. A bare ImmutableMatrix/MatrixSlice is returned
            through a generated output argument. Any other expression becomes
            a return value (Result).

        argument_sequence : sequence, optional
            Explicit argument order. IndexedBase entries are replaced by
            their labels. Names not otherwise needed become extra
            InputArguments.

        global_vars : iterable of Symbols, optional
            Symbols that are not turned into arguments.

        Raises
        ======

        ValueError
            If an empty sequence of expressions is given.
        CodeGenError
            If CSE is enabled and a list contains a non-Equality, or the
            expressions contain Idx objects. Also raised if the left-hand
            side of an Equality is not an Indexed, Symbol or MatrixSymbol.
        CodeGenArgumentListError
            If ``argument_sequence`` omits a required argument.
        """
        if self.cse:
            from sympy.simplify.cse_main import cse

            if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
                if not expr:
                    raise ValueError("No expression given")
                for e in expr:
                    if not e.is_Equality:
                        raise CodeGenError("Lists of expressions must all be Equalities. {} is not.".format(e))

                # create a list of right hand sides and simplify them
                rhs = [e.rhs for e in expr]
                common, simplified = cse(rhs)

                # pack the simplified expressions back up with their left hand sides
                expr = [Equality(e.lhs, rhs) for e, rhs in zip(expr, simplified)]
            else:
                if isinstance(expr, Equality):
                    # only the right-hand side is simplified
                    common, simplified = cse(expr.rhs)
                    expr = Equality(expr.lhs, simplified[0])
                else:
                    common, simplified = cse(expr)
                    expr = simplified

            # Each (symbol, subexpression) pair from cse() becomes a local
            # variable: a Result whose expr is the subexpression and whose
            # name is the replacement symbol.
            # NOTE(review): here local_vars is a list of Result objects,
            # whereas the non-CSE branch below makes it a set of index
            # labels; the `symbol not in local_vars` test further down
            # therefore compares symbols against Results in this case.
            local_vars = [Result(b,a) for a,b in common]
            local_symbols = {a for a,_ in common}
            local_expressions = Tuple(*[b for _,b in common])
        else:
            local_expressions = Tuple()

        if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
            if not expr:
                raise ValueError("No expression given")
            expressions = Tuple(*expr)
        else:
            expressions = Tuple(expr)

        if self.cse:
            if {i.label for i in expressions.atoms(Idx)} != set():
                raise CodeGenError("CSE and Indexed expressions do not play well together yet")
        else:
            # local variables for indexed expressions
            local_vars = {i.label for i in expressions.atoms(Idx)}
            local_symbols = local_vars

        # global variables
        global_vars = set() if global_vars is None else set(global_vars)

        # symbols that should be arguments
        symbols = (expressions.free_symbols | local_expressions.free_symbols) - local_symbols - global_vars
        new_symbols = set()
        new_symbols.update(symbols)

        # Idx objects are replaced by the free symbols of symbol.args[1]
        # (presumably the index range -- confirm), and Indexed objects are
        # dropped; iterate the original set while editing the copy.
        for symbol in symbols:
            if isinstance(symbol, Idx):
                new_symbols.remove(symbol)
                new_symbols.update(symbol.args[1].free_symbols)
            if isinstance(symbol, Indexed):
                new_symbols.remove(symbol)
        symbols = new_symbols

        # Decide whether to use output argument or return value
        return_val = []
        output_args = []
        for expr in expressions:
            if isinstance(expr, Equality):
                out_arg = expr.lhs
                expr = expr.rhs
                if isinstance(out_arg, Indexed):
                    dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
                    symbol = out_arg.base.label
                elif isinstance(out_arg, Symbol):
                    dims = []
                    symbol = out_arg
                elif isinstance(out_arg, MatrixSymbol):
                    dims = tuple([ (S.Zero, dim - 1) for dim in out_arg.shape])
                    symbol = out_arg
                else:
                    raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
                                       "can define output arguments.")

                # The target is in-out if the right-hand side reads it,
                # otherwise it is a pure output.
                if expr.has(symbol):
                    output_args.append(
                        InOutArgument(symbol, out_arg, expr, dimensions=dims))
                else:
                    output_args.append(
                        OutputArgument(symbol, out_arg, expr, dimensions=dims))

                # remove duplicate arguments when they are not local variables
                # NOTE(review): set.remove raises KeyError if the same target
                # is assigned by two Equalities (already removed the first
                # time).
                if symbol not in local_vars:
                    # avoid duplicate arguments
                    symbols.remove(symbol)
            elif isinstance(expr, (ImmutableMatrix, MatrixSlice)):
                # Create a "dummy" MatrixSymbol to use as the Output arg
                out_arg = MatrixSymbol('out_%s' % abs(hash(expr)), *expr.shape)
                dims = tuple([(S.Zero, dim - 1) for dim in out_arg.shape])
                output_args.append(
                    OutputArgument(out_arg, out_arg, expr, dimensions=dims))
            else:
                return_val.append(Result(expr))

        arg_list = []

        # setup input argument list

        # helper to get dimensions for data for array-like args:
        # one (0, dim - 1) bound pair per axis of s.shape
        def dimensions(s):
            return [(S.Zero, dim - 1) for dim in s.shape]

        # Map each array name (Indexed base label or MatrixSymbol) to an
        # object whose shape gives the array's dimensions.
        array_symbols = {}
        for array in expressions.atoms(Indexed) | local_expressions.atoms(Indexed):
            array_symbols[array.base.label] = array
        for array in expressions.atoms(MatrixSymbol) | local_expressions.atoms(MatrixSymbol):
            array_symbols[array] = array

        # Inputs in alphabetical order (by string form).
        for symbol in sorted(symbols, key=str):
            if symbol in array_symbols:
                array = array_symbols[symbol]
                metadata = {'dimensions': dimensions(array)}
            else:
                metadata = {}

            arg_list.append(InputArgument(symbol, **metadata))

        # Outputs follow the inputs, also in alphabetical order.
        output_args.sort(key=lambda x: str(x.name))
        arg_list.extend(output_args)

        if argument_sequence is not None:
            # if the user has supplied IndexedBase instances, we'll accept that
            new_sequence = []
            for arg in argument_sequence:
                if isinstance(arg, IndexedBase):
                    new_sequence.append(arg.label)
                else:
                    new_sequence.append(arg)
            argument_sequence = new_sequence

            missing = [x for x in arg_list if x.name not in argument_sequence]
            if missing:
                msg = "Argument list didn't specify: {0} "
                msg = msg.format(", ".join([str(m.name) for m in missing]))
                raise CodeGenArgumentListError(msg, missing)

            # create redundant arguments to produce the requested sequence
            name_arg_dict = {x.name: x for x in arg_list}
            new_args = []
            for symbol in argument_sequence:
                try:
                    new_args.append(name_arg_dict[symbol])
                except KeyError:
                    # NOTE(review): IndexedBase entries were replaced by their
                    # labels above, so the IndexedBase case here looks
                    # unreachable.
                    if isinstance(symbol, (IndexedBase, MatrixSymbol)):
                        metadata = {'dimensions': dimensions(symbol)}
                    else:
                        metadata = {}
                    new_args.append(InputArgument(symbol, **metadata))
            arg_list = new_args

        return Routine(name, arg_list, return_val, local_vars, global_vars)
  590. def write(self, routines, prefix, to_files=False, header=True, empty=True):
  591. """Writes all the source code files for the given routines.
  592. The generated source is returned as a list of (filename, contents)
  593. tuples, or is written to files (see below). Each filename consists
  594. of the given prefix, appended with an appropriate extension.
  595. Parameters
  596. ==========
  597. routines : list
  598. A list of Routine instances to be written
  599. prefix : string
  600. The prefix for the output files
  601. to_files : bool, optional
  602. When True, the output is written to files. Otherwise, a list
  603. of (filename, contents) tuples is returned. [default: False]
  604. header : bool, optional
  605. When True, a header comment is included on top of each source
  606. file. [default: True]
  607. empty : bool, optional
  608. When True, empty lines are included to structure the source
  609. files. [default: True]
  610. """
  611. if to_files:
  612. for dump_fn in self.dump_fns:
  613. filename = "%s.%s" % (prefix, dump_fn.extension)
  614. with open(filename, "w") as f:
  615. dump_fn(self, routines, f, prefix, header, empty)
  616. else:
  617. result = []
  618. for dump_fn in self.dump_fns:
  619. filename = "%s.%s" % (prefix, dump_fn.extension)
  620. contents = StringIO()
  621. dump_fn(self, routines, contents, prefix, header, empty)
  622. result.append((filename, contents.getvalue()))
  623. return result
  624. def dump_code(self, routines, f, prefix, header=True, empty=True):
  625. """Write the code by calling language specific methods.
  626. The generated file contains all the definitions of the routines in
  627. low-level code and refers to the header file if appropriate.
  628. Parameters
  629. ==========
  630. routines : list
  631. A list of Routine instances.
  632. f : file-like
  633. Where to write the file.
  634. prefix : string
  635. The filename prefix, used to refer to the proper header file.
  636. Only the basename of the prefix is used.
  637. header : bool, optional
  638. When True, a header comment is included on top of each source
  639. file. [default : True]
  640. empty : bool, optional
  641. When True, empty lines are included to structure the source
  642. files. [default : True]
  643. """
  644. code_lines = self._preprocessor_statements(prefix)
  645. for routine in routines:
  646. if empty:
  647. code_lines.append("\n")
  648. code_lines.extend(self._get_routine_opening(routine))
  649. code_lines.extend(self._declare_arguments(routine))
  650. code_lines.extend(self._declare_globals(routine))
  651. code_lines.extend(self._declare_locals(routine))
  652. if empty:
  653. code_lines.append("\n")
  654. code_lines.extend(self._call_printer(routine))
  655. if empty:
  656. code_lines.append("\n")
  657. code_lines.extend(self._get_routine_ending(routine))
  658. code_lines = self._indent_code(''.join(code_lines))
  659. if header:
  660. code_lines = ''.join(self._get_header() + [code_lines])
  661. if code_lines:
  662. f.write(code_lines)
  663. class CodeGenError(Exception):
  664. pass
  665. class CodeGenArgumentListError(Exception):
  666. @property
  667. def missing_args(self):
  668. return self.args[1]
  669. header_comment = """Code generated with SymPy %(version)s
  670. See http://www.sympy.org/ for more information.
  671. This file is part of '%(project)s'
  672. """
  673. class CCodeGen(CodeGen):
  674. """Generator for C code.
  675. The .write() method inherited from CodeGen will output a code file and
  676. an interface file, <prefix>.c and <prefix>.h respectively.
  677. """
  678. code_extension = "c"
  679. interface_extension = "h"
  680. standard = 'c99'
  681. def __init__(self, project="project", printer=None,
  682. preprocessor_statements=None, cse=False):
  683. super().__init__(project=project, cse=cse)
  684. self.printer = printer or c_code_printers[self.standard.lower()]()
  685. self.preprocessor_statements = preprocessor_statements
  686. if preprocessor_statements is None:
  687. self.preprocessor_statements = ['#include <math.h>']
  688. def _get_header(self):
  689. """Writes a common header for the generated files."""
  690. code_lines = []
  691. code_lines.append("/" + "*"*78 + '\n')
  692. tmp = header_comment % {"version": sympy_version,
  693. "project": self.project}
  694. for line in tmp.splitlines():
  695. code_lines.append(" *%s*\n" % line.center(76))
  696. code_lines.append(" " + "*"*78 + "/\n")
  697. return code_lines
  698. def get_prototype(self, routine):
  699. """Returns a string for the function prototype of the routine.
  700. If the routine has multiple result objects, an CodeGenError is
  701. raised.
  702. See: https://en.wikipedia.org/wiki/Function_prototype
  703. """
  704. if len(routine.results) > 1:
  705. raise CodeGenError("C only supports a single or no return value.")
  706. elif len(routine.results) == 1:
  707. ctype = routine.results[0].get_datatype('C')
  708. else:
  709. ctype = "void"
  710. type_args = []
  711. for arg in routine.arguments:
  712. name = self.printer.doprint(arg.name)
  713. if arg.dimensions or isinstance(arg, ResultBase):
  714. type_args.append((arg.get_datatype('C'), "*%s" % name))
  715. else:
  716. type_args.append((arg.get_datatype('C'), name))
  717. arguments = ", ".join([ "%s %s" % t for t in type_args])
  718. return "%s %s(%s)" % (ctype, routine.name, arguments)
  719. def _preprocessor_statements(self, prefix):
  720. code_lines = []
  721. code_lines.append('#include "{}.h"'.format(os.path.basename(prefix)))
  722. code_lines.extend(self.preprocessor_statements)
  723. code_lines = ['{}\n'.format(l) for l in code_lines]
  724. return code_lines
  725. def _get_routine_opening(self, routine):
  726. prototype = self.get_prototype(routine)
  727. return ["%s {\n" % prototype]
  728. def _declare_arguments(self, routine):
  729. # arguments are declared in prototype
  730. return []
  731. def _declare_globals(self, routine):
  732. # global variables are not explicitly declared within C functions
  733. return []
  734. def _declare_locals(self, routine):
  735. # Compose a list of symbols to be dereferenced in the function
  736. # body. These are the arguments that were passed by a reference
  737. # pointer, excluding arrays.
  738. dereference = []
  739. for arg in routine.arguments:
  740. if isinstance(arg, ResultBase) and not arg.dimensions:
  741. dereference.append(arg.name)
  742. code_lines = []
  743. for result in routine.local_vars:
  744. # local variables that are simple symbols such as those used as indices into
  745. # for loops are defined declared elsewhere.
  746. if not isinstance(result, Result):
  747. continue
  748. if result.name != result.result_var:
  749. raise CodeGen("Result variable and name should match: {}".format(result))
  750. assign_to = result.name
  751. t = result.get_datatype('c')
  752. if isinstance(result.expr, (MatrixBase, MatrixExpr)):
  753. dims = result.expr.shape
  754. code_lines.append("{} {}[{}];\n".format(t, str(assign_to), dims[0]*dims[1]))
  755. prefix = ""
  756. else:
  757. prefix = "const {} ".format(t)
  758. constants, not_c, c_expr = self._printer_method_with_settings(
  759. 'doprint', dict(human=False, dereference=dereference),
  760. result.expr, assign_to=assign_to)
  761. for name, value in sorted(constants, key=str):
  762. code_lines.append("double const %s = %s;\n" % (name, value))
  763. code_lines.append("{}{}\n".format(prefix, c_expr))
  764. return code_lines
  765. def _call_printer(self, routine):
  766. code_lines = []
  767. # Compose a list of symbols to be dereferenced in the function
  768. # body. These are the arguments that were passed by a reference
  769. # pointer, excluding arrays.
  770. dereference = []
  771. for arg in routine.arguments:
  772. if isinstance(arg, ResultBase) and not arg.dimensions:
  773. dereference.append(arg.name)
  774. return_val = None
  775. for result in routine.result_variables:
  776. if isinstance(result, Result):
  777. assign_to = routine.name + "_result"
  778. t = result.get_datatype('c')
  779. code_lines.append("{} {};\n".format(t, str(assign_to)))
  780. return_val = assign_to
  781. else:
  782. assign_to = result.result_var
  783. try:
  784. constants, not_c, c_expr = self._printer_method_with_settings(
  785. 'doprint', dict(human=False, dereference=dereference),
  786. result.expr, assign_to=assign_to)
  787. except AssignmentError:
  788. assign_to = result.result_var
  789. code_lines.append(
  790. "%s %s;\n" % (result.get_datatype('c'), str(assign_to)))
  791. constants, not_c, c_expr = self._printer_method_with_settings(
  792. 'doprint', dict(human=False, dereference=dereference),
  793. result.expr, assign_to=assign_to)
  794. for name, value in sorted(constants, key=str):
  795. code_lines.append("double const %s = %s;\n" % (name, value))
  796. code_lines.append("%s\n" % c_expr)
  797. if return_val:
  798. code_lines.append(" return %s;\n" % return_val)
  799. return code_lines
  800. def _get_routine_ending(self, routine):
  801. return ["}\n"]
  802. def dump_c(self, routines, f, prefix, header=True, empty=True):
  803. self.dump_code(routines, f, prefix, header, empty)
  804. dump_c.extension = code_extension # type: ignore
  805. dump_c.__doc__ = CodeGen.dump_code.__doc__
  806. def dump_h(self, routines, f, prefix, header=True, empty=True):
  807. """Writes the C header file.
  808. This file contains all the function declarations.
  809. Parameters
  810. ==========
  811. routines : list
  812. A list of Routine instances.
  813. f : file-like
  814. Where to write the file.
  815. prefix : string
  816. The filename prefix, used to construct the include guards.
  817. Only the basename of the prefix is used.
  818. header : bool, optional
  819. When True, a header comment is included on top of each source
  820. file. [default : True]
  821. empty : bool, optional
  822. When True, empty lines are included to structure the source
  823. files. [default : True]
  824. """
  825. if header:
  826. print(''.join(self._get_header()), file=f)
  827. guard_name = "%s__%s__H" % (self.project.replace(
  828. " ", "_").upper(), prefix.replace("/", "_").upper())
  829. # include guards
  830. if empty:
  831. print(file=f)
  832. print("#ifndef %s" % guard_name, file=f)
  833. print("#define %s" % guard_name, file=f)
  834. if empty:
  835. print(file=f)
  836. # declaration of the function prototypes
  837. for routine in routines:
  838. prototype = self.get_prototype(routine)
  839. print("%s;" % prototype, file=f)
  840. # end if include guards
  841. if empty:
  842. print(file=f)
  843. print("#endif", file=f)
  844. if empty:
  845. print(file=f)
  846. dump_h.extension = interface_extension # type: ignore
  847. # This list of dump functions is used by CodeGen.write to know which dump
  848. # functions it has to call.
  849. dump_fns = [dump_c, dump_h]
  850. class C89CodeGen(CCodeGen):
  851. standard = 'C89'
  852. class C99CodeGen(CCodeGen):
  853. standard = 'C99'
  854. class FCodeGen(CodeGen):
  855. """Generator for Fortran 95 code
  856. The .write() method inherited from CodeGen will output a code file and
  857. an interface file, <prefix>.f90 and <prefix>.h respectively.
  858. """
  859. code_extension = "f90"
  860. interface_extension = "h"
  861. def __init__(self, project='project', printer=None):
  862. super().__init__(project)
  863. self.printer = printer or FCodePrinter()
  864. def _get_header(self):
  865. """Writes a common header for the generated files."""
  866. code_lines = []
  867. code_lines.append("!" + "*"*78 + '\n')
  868. tmp = header_comment % {"version": sympy_version,
  869. "project": self.project}
  870. for line in tmp.splitlines():
  871. code_lines.append("!*%s*\n" % line.center(76))
  872. code_lines.append("!" + "*"*78 + '\n')
  873. return code_lines
  874. def _preprocessor_statements(self, prefix):
  875. return []
  876. def _get_routine_opening(self, routine):
  877. """Returns the opening statements of the fortran routine."""
  878. code_list = []
  879. if len(routine.results) > 1:
  880. raise CodeGenError(
  881. "Fortran only supports a single or no return value.")
  882. elif len(routine.results) == 1:
  883. result = routine.results[0]
  884. code_list.append(result.get_datatype('fortran'))
  885. code_list.append("function")
  886. else:
  887. code_list.append("subroutine")
  888. args = ", ".join("%s" % self._get_symbol(arg.name)
  889. for arg in routine.arguments)
  890. call_sig = "{}({})\n".format(routine.name, args)
  891. # Fortran 95 requires all lines be less than 132 characters, so wrap
  892. # this line before appending.
  893. call_sig = ' &\n'.join(textwrap.wrap(call_sig,
  894. width=60,
  895. break_long_words=False)) + '\n'
  896. code_list.append(call_sig)
  897. code_list = [' '.join(code_list)]
  898. code_list.append('implicit none\n')
  899. return code_list
  900. def _declare_arguments(self, routine):
  901. # argument type declarations
  902. code_list = []
  903. array_list = []
  904. scalar_list = []
  905. for arg in routine.arguments:
  906. if isinstance(arg, InputArgument):
  907. typeinfo = "%s, intent(in)" % arg.get_datatype('fortran')
  908. elif isinstance(arg, InOutArgument):
  909. typeinfo = "%s, intent(inout)" % arg.get_datatype('fortran')
  910. elif isinstance(arg, OutputArgument):
  911. typeinfo = "%s, intent(out)" % arg.get_datatype('fortran')
  912. else:
  913. raise CodeGenError("Unknown Argument type: %s" % type(arg))
  914. fprint = self._get_symbol
  915. if arg.dimensions:
  916. # fortran arrays start at 1
  917. dimstr = ", ".join(["%s:%s" % (
  918. fprint(dim[0] + 1), fprint(dim[1] + 1))
  919. for dim in arg.dimensions])
  920. typeinfo += ", dimension(%s)" % dimstr
  921. array_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
  922. else:
  923. scalar_list.append("%s :: %s\n" % (typeinfo, fprint(arg.name)))
  924. # scalars first, because they can be used in array declarations
  925. code_list.extend(scalar_list)
  926. code_list.extend(array_list)
  927. return code_list
  928. def _declare_globals(self, routine):
  929. # Global variables not explicitly declared within Fortran 90 functions.
  930. # Note: a future F77 mode may need to generate "common" blocks.
  931. return []
  932. def _declare_locals(self, routine):
  933. code_list = []
  934. for var in sorted(routine.local_vars, key=str):
  935. typeinfo = get_default_datatype(var)
  936. code_list.append("%s :: %s\n" % (
  937. typeinfo.fname, self._get_symbol(var)))
  938. return code_list
  939. def _get_routine_ending(self, routine):
  940. """Returns the closing statements of the fortran routine."""
  941. if len(routine.results) == 1:
  942. return ["end function\n"]
  943. else:
  944. return ["end subroutine\n"]
  945. def get_interface(self, routine):
  946. """Returns a string for the function interface.
  947. The routine should have a single result object, which can be None.
  948. If the routine has multiple result objects, a CodeGenError is
  949. raised.
  950. See: https://en.wikipedia.org/wiki/Function_prototype
  951. """
  952. prototype = [ "interface\n" ]
  953. prototype.extend(self._get_routine_opening(routine))
  954. prototype.extend(self._declare_arguments(routine))
  955. prototype.extend(self._get_routine_ending(routine))
  956. prototype.append("end interface\n")
  957. return "".join(prototype)
  958. def _call_printer(self, routine):
  959. declarations = []
  960. code_lines = []
  961. for result in routine.result_variables:
  962. if isinstance(result, Result):
  963. assign_to = routine.name
  964. elif isinstance(result, (OutputArgument, InOutArgument)):
  965. assign_to = result.result_var
  966. constants, not_fortran, f_expr = self._printer_method_with_settings(
  967. 'doprint', dict(human=False, source_format='free', standard=95),
  968. result.expr, assign_to=assign_to)
  969. for obj, v in sorted(constants, key=str):
  970. t = get_default_datatype(obj)
  971. declarations.append(
  972. "%s, parameter :: %s = %s\n" % (t.fname, obj, v))
  973. for obj in sorted(not_fortran, key=str):
  974. t = get_default_datatype(obj)
  975. if isinstance(obj, Function):
  976. name = obj.func
  977. else:
  978. name = obj
  979. declarations.append("%s :: %s\n" % (t.fname, name))
  980. code_lines.append("%s\n" % f_expr)
  981. return declarations + code_lines
  982. def _indent_code(self, codelines):
  983. return self._printer_method_with_settings(
  984. 'indent_code', dict(human=False, source_format='free'), codelines)
  985. def dump_f95(self, routines, f, prefix, header=True, empty=True):
  986. # check that symbols are unique with ignorecase
  987. for r in routines:
  988. lowercase = {str(x).lower() for x in r.variables}
  989. orig_case = {str(x) for x in r.variables}
  990. if len(lowercase) < len(orig_case):
  991. raise CodeGenError("Fortran ignores case. Got symbols: %s" %
  992. (", ".join([str(var) for var in r.variables])))
  993. self.dump_code(routines, f, prefix, header, empty)
  994. dump_f95.extension = code_extension # type: ignore
  995. dump_f95.__doc__ = CodeGen.dump_code.__doc__
  996. def dump_h(self, routines, f, prefix, header=True, empty=True):
  997. """Writes the interface to a header file.
  998. This file contains all the function declarations.
  999. Parameters
  1000. ==========
  1001. routines : list
  1002. A list of Routine instances.
  1003. f : file-like
  1004. Where to write the file.
  1005. prefix : string
  1006. The filename prefix.
  1007. header : bool, optional
  1008. When True, a header comment is included on top of each source
  1009. file. [default : True]
  1010. empty : bool, optional
  1011. When True, empty lines are included to structure the source
  1012. files. [default : True]
  1013. """
  1014. if header:
  1015. print(''.join(self._get_header()), file=f)
  1016. if empty:
  1017. print(file=f)
  1018. # declaration of the function prototypes
  1019. for routine in routines:
  1020. prototype = self.get_interface(routine)
  1021. f.write(prototype)
  1022. if empty:
  1023. print(file=f)
  1024. dump_h.extension = interface_extension # type: ignore
  1025. # This list of dump functions is used by CodeGen.write to know which dump
  1026. # functions it has to call.
  1027. dump_fns = [dump_f95, dump_h]
  1028. class JuliaCodeGen(CodeGen):
  1029. """Generator for Julia code.
  1030. The .write() method inherited from CodeGen will output a code file
  1031. <prefix>.jl.
  1032. """
  1033. code_extension = "jl"
  1034. def __init__(self, project='project', printer=None):
  1035. super().__init__(project)
  1036. self.printer = printer or JuliaCodePrinter()
  1037. def routine(self, name, expr, argument_sequence, global_vars):
  1038. """Specialized Routine creation for Julia."""
  1039. if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
  1040. if not expr:
  1041. raise ValueError("No expression given")
  1042. expressions = Tuple(*expr)
  1043. else:
  1044. expressions = Tuple(expr)
  1045. # local variables
  1046. local_vars = {i.label for i in expressions.atoms(Idx)}
  1047. # global variables
  1048. global_vars = set() if global_vars is None else set(global_vars)
  1049. # symbols that should be arguments
  1050. old_symbols = expressions.free_symbols - local_vars - global_vars
  1051. symbols = set()
  1052. for s in old_symbols:
  1053. if isinstance(s, Idx):
  1054. symbols.update(s.args[1].free_symbols)
  1055. elif not isinstance(s, Indexed):
  1056. symbols.add(s)
  1057. # Julia supports multiple return values
  1058. return_vals = []
  1059. output_args = []
  1060. for (i, expr) in enumerate(expressions):
  1061. if isinstance(expr, Equality):
  1062. out_arg = expr.lhs
  1063. expr = expr.rhs
  1064. symbol = out_arg
  1065. if isinstance(out_arg, Indexed):
  1066. dims = tuple([ (S.One, dim) for dim in out_arg.shape])
  1067. symbol = out_arg.base.label
  1068. output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
  1069. if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
  1070. raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
  1071. "can define output arguments.")
  1072. return_vals.append(Result(expr, name=symbol, result_var=out_arg))
  1073. if not expr.has(symbol):
  1074. # this is a pure output: remove from the symbols list, so
  1075. # it doesn't become an input.
  1076. symbols.remove(symbol)
  1077. else:
  1078. # we have no name for this output
  1079. return_vals.append(Result(expr, name='out%d' % (i+1)))
  1080. # setup input argument list
  1081. output_args.sort(key=lambda x: str(x.name))
  1082. arg_list = list(output_args)
  1083. array_symbols = {}
  1084. for array in expressions.atoms(Indexed):
  1085. array_symbols[array.base.label] = array
  1086. for array in expressions.atoms(MatrixSymbol):
  1087. array_symbols[array] = array
  1088. for symbol in sorted(symbols, key=str):
  1089. arg_list.append(InputArgument(symbol))
  1090. if argument_sequence is not None:
  1091. # if the user has supplied IndexedBase instances, we'll accept that
  1092. new_sequence = []
  1093. for arg in argument_sequence:
  1094. if isinstance(arg, IndexedBase):
  1095. new_sequence.append(arg.label)
  1096. else:
  1097. new_sequence.append(arg)
  1098. argument_sequence = new_sequence
  1099. missing = [x for x in arg_list if x.name not in argument_sequence]
  1100. if missing:
  1101. msg = "Argument list didn't specify: {0} "
  1102. msg = msg.format(", ".join([str(m.name) for m in missing]))
  1103. raise CodeGenArgumentListError(msg, missing)
  1104. # create redundant arguments to produce the requested sequence
  1105. name_arg_dict = {x.name: x for x in arg_list}
  1106. new_args = []
  1107. for symbol in argument_sequence:
  1108. try:
  1109. new_args.append(name_arg_dict[symbol])
  1110. except KeyError:
  1111. new_args.append(InputArgument(symbol))
  1112. arg_list = new_args
  1113. return Routine(name, arg_list, return_vals, local_vars, global_vars)
  1114. def _get_header(self):
  1115. """Writes a common header for the generated files."""
  1116. code_lines = []
  1117. tmp = header_comment % {"version": sympy_version,
  1118. "project": self.project}
  1119. for line in tmp.splitlines():
  1120. if line == '':
  1121. code_lines.append("#\n")
  1122. else:
  1123. code_lines.append("# %s\n" % line)
  1124. return code_lines
  1125. def _preprocessor_statements(self, prefix):
  1126. return []
  1127. def _get_routine_opening(self, routine):
  1128. """Returns the opening statements of the routine."""
  1129. code_list = []
  1130. code_list.append("function ")
  1131. # Inputs
  1132. args = []
  1133. for i, arg in enumerate(routine.arguments):
  1134. if isinstance(arg, OutputArgument):
  1135. raise CodeGenError("Julia: invalid argument of type %s" %
  1136. str(type(arg)))
  1137. if isinstance(arg, (InputArgument, InOutArgument)):
  1138. args.append("%s" % self._get_symbol(arg.name))
  1139. args = ", ".join(args)
  1140. code_list.append("%s(%s)\n" % (routine.name, args))
  1141. code_list = [ "".join(code_list) ]
  1142. return code_list
  1143. def _declare_arguments(self, routine):
  1144. return []
  1145. def _declare_globals(self, routine):
  1146. return []
  1147. def _declare_locals(self, routine):
  1148. return []
  1149. def _get_routine_ending(self, routine):
  1150. outs = []
  1151. for result in routine.results:
  1152. if isinstance(result, Result):
  1153. # Note: name not result_var; want `y` not `y[i]` for Indexed
  1154. s = self._get_symbol(result.name)
  1155. else:
  1156. raise CodeGenError("unexpected object in Routine results")
  1157. outs.append(s)
  1158. return ["return " + ", ".join(outs) + "\nend\n"]
  1159. def _call_printer(self, routine):
  1160. declarations = []
  1161. code_lines = []
  1162. for i, result in enumerate(routine.results):
  1163. if isinstance(result, Result):
  1164. assign_to = result.result_var
  1165. else:
  1166. raise CodeGenError("unexpected object in Routine results")
  1167. constants, not_supported, jl_expr = self._printer_method_with_settings(
  1168. 'doprint', dict(human=False), result.expr, assign_to=assign_to)
  1169. for obj, v in sorted(constants, key=str):
  1170. declarations.append(
  1171. "%s = %s\n" % (obj, v))
  1172. for obj in sorted(not_supported, key=str):
  1173. if isinstance(obj, Function):
  1174. name = obj.func
  1175. else:
  1176. name = obj
  1177. declarations.append(
  1178. "# unsupported: %s\n" % (name))
  1179. code_lines.append("%s\n" % (jl_expr))
  1180. return declarations + code_lines
  1181. def _indent_code(self, codelines):
  1182. # Note that indenting seems to happen twice, first
  1183. # statement-by-statement by JuliaPrinter then again here.
  1184. p = JuliaCodePrinter({'human': False})
  1185. return p.indent_code(codelines)
  1186. def dump_jl(self, routines, f, prefix, header=True, empty=True):
  1187. self.dump_code(routines, f, prefix, header, empty)
  1188. dump_jl.extension = code_extension # type: ignore
  1189. dump_jl.__doc__ = CodeGen.dump_code.__doc__
  1190. # This list of dump functions is used by CodeGen.write to know which dump
  1191. # functions it has to call.
  1192. dump_fns = [dump_jl]
  1193. class OctaveCodeGen(CodeGen):
  1194. """Generator for Octave code.
  1195. The .write() method inherited from CodeGen will output a code file
  1196. <prefix>.m.
  1197. Octave .m files usually contain one function. That function name should
  1198. match the filename (``prefix``). If you pass multiple ``name_expr`` pairs,
  1199. the latter ones are presumed to be private functions accessed by the
  1200. primary function.
  1201. You should only pass inputs to ``argument_sequence``: outputs are ordered
  1202. according to their order in ``name_expr``.
  1203. """
  1204. code_extension = "m"
  1205. def __init__(self, project='project', printer=None):
  1206. super().__init__(project)
  1207. self.printer = printer or OctaveCodePrinter()
  1208. def routine(self, name, expr, argument_sequence, global_vars):
  1209. """Specialized Routine creation for Octave."""
  1210. # FIXME: this is probably general enough for other high-level
  1211. # languages, perhaps its the C/Fortran one that is specialized!
  1212. if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
  1213. if not expr:
  1214. raise ValueError("No expression given")
  1215. expressions = Tuple(*expr)
  1216. else:
  1217. expressions = Tuple(expr)
  1218. # local variables
  1219. local_vars = {i.label for i in expressions.atoms(Idx)}
  1220. # global variables
  1221. global_vars = set() if global_vars is None else set(global_vars)
  1222. # symbols that should be arguments
  1223. old_symbols = expressions.free_symbols - local_vars - global_vars
  1224. symbols = set()
  1225. for s in old_symbols:
  1226. if isinstance(s, Idx):
  1227. symbols.update(s.args[1].free_symbols)
  1228. elif not isinstance(s, Indexed):
  1229. symbols.add(s)
  1230. # Octave supports multiple return values
  1231. return_vals = []
  1232. for (i, expr) in enumerate(expressions):
  1233. if isinstance(expr, Equality):
  1234. out_arg = expr.lhs
  1235. expr = expr.rhs
  1236. symbol = out_arg
  1237. if isinstance(out_arg, Indexed):
  1238. symbol = out_arg.base.label
  1239. if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
  1240. raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
  1241. "can define output arguments.")
  1242. return_vals.append(Result(expr, name=symbol, result_var=out_arg))
  1243. if not expr.has(symbol):
  1244. # this is a pure output: remove from the symbols list, so
  1245. # it doesn't become an input.
  1246. symbols.remove(symbol)
  1247. else:
  1248. # we have no name for this output
  1249. return_vals.append(Result(expr, name='out%d' % (i+1)))
  1250. # setup input argument list
  1251. arg_list = []
  1252. array_symbols = {}
  1253. for array in expressions.atoms(Indexed):
  1254. array_symbols[array.base.label] = array
  1255. for array in expressions.atoms(MatrixSymbol):
  1256. array_symbols[array] = array
  1257. for symbol in sorted(symbols, key=str):
  1258. arg_list.append(InputArgument(symbol))
  1259. if argument_sequence is not None:
  1260. # if the user has supplied IndexedBase instances, we'll accept that
  1261. new_sequence = []
  1262. for arg in argument_sequence:
  1263. if isinstance(arg, IndexedBase):
  1264. new_sequence.append(arg.label)
  1265. else:
  1266. new_sequence.append(arg)
  1267. argument_sequence = new_sequence
  1268. missing = [x for x in arg_list if x.name not in argument_sequence]
  1269. if missing:
  1270. msg = "Argument list didn't specify: {0} "
  1271. msg = msg.format(", ".join([str(m.name) for m in missing]))
  1272. raise CodeGenArgumentListError(msg, missing)
  1273. # create redundant arguments to produce the requested sequence
  1274. name_arg_dict = {x.name: x for x in arg_list}
  1275. new_args = []
  1276. for symbol in argument_sequence:
  1277. try:
  1278. new_args.append(name_arg_dict[symbol])
  1279. except KeyError:
  1280. new_args.append(InputArgument(symbol))
  1281. arg_list = new_args
  1282. return Routine(name, arg_list, return_vals, local_vars, global_vars)
  1283. def _get_header(self):
  1284. """Writes a common header for the generated files."""
  1285. code_lines = []
  1286. tmp = header_comment % {"version": sympy_version,
  1287. "project": self.project}
  1288. for line in tmp.splitlines():
  1289. if line == '':
  1290. code_lines.append("%\n")
  1291. else:
  1292. code_lines.append("%% %s\n" % line)
  1293. return code_lines
  1294. def _preprocessor_statements(self, prefix):
  1295. return []
  1296. def _get_routine_opening(self, routine):
  1297. """Returns the opening statements of the routine."""
  1298. code_list = []
  1299. code_list.append("function ")
  1300. # Outputs
  1301. outs = []
  1302. for i, result in enumerate(routine.results):
  1303. if isinstance(result, Result):
  1304. # Note: name not result_var; want `y` not `y(i)` for Indexed
  1305. s = self._get_symbol(result.name)
  1306. else:
  1307. raise CodeGenError("unexpected object in Routine results")
  1308. outs.append(s)
  1309. if len(outs) > 1:
  1310. code_list.append("[" + (", ".join(outs)) + "]")
  1311. else:
  1312. code_list.append("".join(outs))
  1313. code_list.append(" = ")
  1314. # Inputs
  1315. args = []
  1316. for i, arg in enumerate(routine.arguments):
  1317. if isinstance(arg, (OutputArgument, InOutArgument)):
  1318. raise CodeGenError("Octave: invalid argument of type %s" %
  1319. str(type(arg)))
  1320. if isinstance(arg, InputArgument):
  1321. args.append("%s" % self._get_symbol(arg.name))
  1322. args = ", ".join(args)
  1323. code_list.append("%s(%s)\n" % (routine.name, args))
  1324. code_list = [ "".join(code_list) ]
  1325. return code_list
  1326. def _declare_arguments(self, routine):
  1327. return []
  1328. def _declare_globals(self, routine):
  1329. if not routine.global_vars:
  1330. return []
  1331. s = " ".join(sorted([self._get_symbol(g) for g in routine.global_vars]))
  1332. return ["global " + s + "\n"]
  1333. def _declare_locals(self, routine):
  1334. return []
  1335. def _get_routine_ending(self, routine):
  1336. return ["end\n"]
  1337. def _call_printer(self, routine):
  1338. declarations = []
  1339. code_lines = []
  1340. for i, result in enumerate(routine.results):
  1341. if isinstance(result, Result):
  1342. assign_to = result.result_var
  1343. else:
  1344. raise CodeGenError("unexpected object in Routine results")
  1345. constants, not_supported, oct_expr = self._printer_method_with_settings(
  1346. 'doprint', dict(human=False), result.expr, assign_to=assign_to)
  1347. for obj, v in sorted(constants, key=str):
  1348. declarations.append(
  1349. " %s = %s; %% constant\n" % (obj, v))
  1350. for obj in sorted(not_supported, key=str):
  1351. if isinstance(obj, Function):
  1352. name = obj.func
  1353. else:
  1354. name = obj
  1355. declarations.append(
  1356. " %% unsupported: %s\n" % (name))
  1357. code_lines.append("%s\n" % (oct_expr))
  1358. return declarations + code_lines
  1359. def _indent_code(self, codelines):
  1360. return self._printer_method_with_settings(
  1361. 'indent_code', dict(human=False), codelines)
  1362. def dump_m(self, routines, f, prefix, header=True, empty=True, inline=True):
  1363. # Note used to call self.dump_code() but we need more control for header
  1364. code_lines = self._preprocessor_statements(prefix)
  1365. for i, routine in enumerate(routines):
  1366. if i > 0:
  1367. if empty:
  1368. code_lines.append("\n")
  1369. code_lines.extend(self._get_routine_opening(routine))
  1370. if i == 0:
  1371. if routine.name != prefix:
  1372. raise ValueError('Octave function name should match prefix')
  1373. if header:
  1374. code_lines.append("%" + prefix.upper() +
  1375. " Autogenerated by SymPy\n")
  1376. code_lines.append(''.join(self._get_header()))
  1377. code_lines.extend(self._declare_arguments(routine))
  1378. code_lines.extend(self._declare_globals(routine))
  1379. code_lines.extend(self._declare_locals(routine))
  1380. if empty:
  1381. code_lines.append("\n")
  1382. code_lines.extend(self._call_printer(routine))
  1383. if empty:
  1384. code_lines.append("\n")
  1385. code_lines.extend(self._get_routine_ending(routine))
  1386. code_lines = self._indent_code(''.join(code_lines))
  1387. if code_lines:
  1388. f.write(code_lines)
  1389. dump_m.extension = code_extension # type: ignore
  1390. dump_m.__doc__ = CodeGen.dump_code.__doc__
  1391. # This list of dump functions is used by CodeGen.write to know which dump
  1392. # functions it has to call.
  1393. dump_fns = [dump_m]
  1394. class RustCodeGen(CodeGen):
  1395. """Generator for Rust code.
  1396. The .write() method inherited from CodeGen will output a code file
  1397. <prefix>.rs
  1398. """
  1399. code_extension = "rs"
  1400. def __init__(self, project="project", printer=None):
  1401. super().__init__(project=project)
  1402. self.printer = printer or RustCodePrinter()
  1403. def routine(self, name, expr, argument_sequence, global_vars):
  1404. """Specialized Routine creation for Rust."""
  1405. if is_sequence(expr) and not isinstance(expr, (MatrixBase, MatrixExpr)):
  1406. if not expr:
  1407. raise ValueError("No expression given")
  1408. expressions = Tuple(*expr)
  1409. else:
  1410. expressions = Tuple(expr)
  1411. # local variables
  1412. local_vars = {i.label for i in expressions.atoms(Idx)}
  1413. # global variables
  1414. global_vars = set() if global_vars is None else set(global_vars)
  1415. # symbols that should be arguments
  1416. symbols = expressions.free_symbols - local_vars - global_vars - expressions.atoms(Indexed)
  1417. # Rust supports multiple return values
  1418. return_vals = []
  1419. output_args = []
  1420. for (i, expr) in enumerate(expressions):
  1421. if isinstance(expr, Equality):
  1422. out_arg = expr.lhs
  1423. expr = expr.rhs
  1424. symbol = out_arg
  1425. if isinstance(out_arg, Indexed):
  1426. dims = tuple([ (S.One, dim) for dim in out_arg.shape])
  1427. symbol = out_arg.base.label
  1428. output_args.append(InOutArgument(symbol, out_arg, expr, dimensions=dims))
  1429. if not isinstance(out_arg, (Indexed, Symbol, MatrixSymbol)):
  1430. raise CodeGenError("Only Indexed, Symbol, or MatrixSymbol "
  1431. "can define output arguments.")
  1432. return_vals.append(Result(expr, name=symbol, result_var=out_arg))
  1433. if not expr.has(symbol):
  1434. # this is a pure output: remove from the symbols list, so
  1435. # it doesn't become an input.
  1436. symbols.remove(symbol)
  1437. else:
  1438. # we have no name for this output
  1439. return_vals.append(Result(expr, name='out%d' % (i+1)))
  1440. # setup input argument list
  1441. output_args.sort(key=lambda x: str(x.name))
  1442. arg_list = list(output_args)
  1443. array_symbols = {}
  1444. for array in expressions.atoms(Indexed):
  1445. array_symbols[array.base.label] = array
  1446. for array in expressions.atoms(MatrixSymbol):
  1447. array_symbols[array] = array
  1448. for symbol in sorted(symbols, key=str):
  1449. arg_list.append(InputArgument(symbol))
  1450. if argument_sequence is not None:
  1451. # if the user has supplied IndexedBase instances, we'll accept that
  1452. new_sequence = []
  1453. for arg in argument_sequence:
  1454. if isinstance(arg, IndexedBase):
  1455. new_sequence.append(arg.label)
  1456. else:
  1457. new_sequence.append(arg)
  1458. argument_sequence = new_sequence
  1459. missing = [x for x in arg_list if x.name not in argument_sequence]
  1460. if missing:
  1461. msg = "Argument list didn't specify: {0} "
  1462. msg = msg.format(", ".join([str(m.name) for m in missing]))
  1463. raise CodeGenArgumentListError(msg, missing)
  1464. # create redundant arguments to produce the requested sequence
  1465. name_arg_dict = {x.name: x for x in arg_list}
  1466. new_args = []
  1467. for symbol in argument_sequence:
  1468. try:
  1469. new_args.append(name_arg_dict[symbol])
  1470. except KeyError:
  1471. new_args.append(InputArgument(symbol))
  1472. arg_list = new_args
  1473. return Routine(name, arg_list, return_vals, local_vars, global_vars)
  1474. def _get_header(self):
  1475. """Writes a common header for the generated files."""
  1476. code_lines = []
  1477. code_lines.append("/*\n")
  1478. tmp = header_comment % {"version": sympy_version,
  1479. "project": self.project}
  1480. for line in tmp.splitlines():
  1481. code_lines.append((" *%s" % line.center(76)).rstrip() + "\n")
  1482. code_lines.append(" */\n")
  1483. return code_lines
  1484. def get_prototype(self, routine):
  1485. """Returns a string for the function prototype of the routine.
  1486. If the routine has multiple result objects, an CodeGenError is
  1487. raised.
  1488. See: https://en.wikipedia.org/wiki/Function_prototype
  1489. """
  1490. results = [i.get_datatype('Rust') for i in routine.results]
  1491. if len(results) == 1:
  1492. rstype = " -> " + results[0]
  1493. elif len(routine.results) > 1:
  1494. rstype = " -> (" + ", ".join(results) + ")"
  1495. else:
  1496. rstype = ""
  1497. type_args = []
  1498. for arg in routine.arguments:
  1499. name = self.printer.doprint(arg.name)
  1500. if arg.dimensions or isinstance(arg, ResultBase):
  1501. type_args.append(("*%s" % name, arg.get_datatype('Rust')))
  1502. else:
  1503. type_args.append((name, arg.get_datatype('Rust')))
  1504. arguments = ", ".join([ "%s: %s" % t for t in type_args])
  1505. return "fn %s(%s)%s" % (routine.name, arguments, rstype)
  1506. def _preprocessor_statements(self, prefix):
  1507. code_lines = []
  1508. # code_lines.append("use std::f64::consts::*;\n")
  1509. return code_lines
  1510. def _get_routine_opening(self, routine):
  1511. prototype = self.get_prototype(routine)
  1512. return ["%s {\n" % prototype]
  1513. def _declare_arguments(self, routine):
  1514. # arguments are declared in prototype
  1515. return []
  1516. def _declare_globals(self, routine):
  1517. # global variables are not explicitly declared within C functions
  1518. return []
  1519. def _declare_locals(self, routine):
  1520. # loop variables are declared in loop statement
  1521. return []
  1522. def _call_printer(self, routine):
  1523. code_lines = []
  1524. declarations = []
  1525. returns = []
  1526. # Compose a list of symbols to be dereferenced in the function
  1527. # body. These are the arguments that were passed by a reference
  1528. # pointer, excluding arrays.
  1529. dereference = []
  1530. for arg in routine.arguments:
  1531. if isinstance(arg, ResultBase) and not arg.dimensions:
  1532. dereference.append(arg.name)
  1533. for i, result in enumerate(routine.results):
  1534. if isinstance(result, Result):
  1535. assign_to = result.result_var
  1536. returns.append(str(result.result_var))
  1537. else:
  1538. raise CodeGenError("unexpected object in Routine results")
  1539. constants, not_supported, rs_expr = self._printer_method_with_settings(
  1540. 'doprint', dict(human=False), result.expr, assign_to=assign_to)
  1541. for name, value in sorted(constants, key=str):
  1542. declarations.append("const %s: f64 = %s;\n" % (name, value))
  1543. for obj in sorted(not_supported, key=str):
  1544. if isinstance(obj, Function):
  1545. name = obj.func
  1546. else:
  1547. name = obj
  1548. declarations.append("// unsupported: %s\n" % (name))
  1549. code_lines.append("let %s\n" % rs_expr);
  1550. if len(returns) > 1:
  1551. returns = ['(' + ', '.join(returns) + ')']
  1552. returns.append('\n')
  1553. return declarations + code_lines + returns
  1554. def _get_routine_ending(self, routine):
  1555. return ["}\n"]
  1556. def dump_rs(self, routines, f, prefix, header=True, empty=True):
  1557. self.dump_code(routines, f, prefix, header, empty)
  1558. dump_rs.extension = code_extension # type: ignore
  1559. dump_rs.__doc__ = CodeGen.dump_code.__doc__
  1560. # This list of dump functions is used by CodeGen.write to know which dump
  1561. # functions it has to call.
  1562. dump_fns = [dump_rs]
  1563. def get_code_generator(language, project=None, standard=None, printer = None):
  1564. if language == 'C':
  1565. if standard is None:
  1566. pass
  1567. elif standard.lower() == 'c89':
  1568. language = 'C89'
  1569. elif standard.lower() == 'c99':
  1570. language = 'C99'
  1571. CodeGenClass = {"C": CCodeGen, "C89": C89CodeGen, "C99": C99CodeGen,
  1572. "F95": FCodeGen, "JULIA": JuliaCodeGen,
  1573. "OCTAVE": OctaveCodeGen,
  1574. "RUST": RustCodeGen}.get(language.upper())
  1575. if CodeGenClass is None:
  1576. raise ValueError("Language '%s' is not supported." % language)
  1577. return CodeGenClass(project, printer)
  1578. #
  1579. # Friendly functions
  1580. #
  1581. def codegen(name_expr, language=None, prefix=None, project="project",
  1582. to_files=False, header=True, empty=True, argument_sequence=None,
  1583. global_vars=None, standard=None, code_gen=None, printer = None):
  1584. """Generate source code for expressions in a given language.
  1585. Parameters
  1586. ==========
  1587. name_expr : tuple, or list of tuples
  1588. A single (name, expression) tuple or a list of (name, expression)
  1589. tuples. Each tuple corresponds to a routine. If the expression is
  1590. an equality (an instance of class Equality) the left hand side is
  1591. considered an output argument. If expression is an iterable, then
  1592. the routine will have multiple outputs.
  1593. language : string,
  1594. A string that indicates the source code language. This is case
  1595. insensitive. Currently, 'C', 'F95' and 'Octave' are supported.
  1596. 'Octave' generates code compatible with both Octave and Matlab.
  1597. prefix : string, optional
  1598. A prefix for the names of the files that contain the source code.
  1599. Language-dependent suffixes will be appended. If omitted, the name
  1600. of the first name_expr tuple is used.
  1601. project : string, optional
  1602. A project name, used for making unique preprocessor instructions.
  1603. [default: "project"]
  1604. to_files : bool, optional
  1605. When True, the code will be written to one or more files with the
  1606. given prefix, otherwise strings with the names and contents of
  1607. these files are returned. [default: False]
  1608. header : bool, optional
  1609. When True, a header is written on top of each source file.
  1610. [default: True]
  1611. empty : bool, optional
  1612. When True, empty lines are used to structure the code.
  1613. [default: True]
  1614. argument_sequence : iterable, optional
  1615. Sequence of arguments for the routine in a preferred order. A
  1616. CodeGenError is raised if required arguments are missing.
  1617. Redundant arguments are used without warning. If omitted,
  1618. arguments will be ordered alphabetically, but with all input
  1619. arguments first, and then output or in-out arguments.
  1620. global_vars : iterable, optional
  1621. Sequence of global variables used by the routine. Variables
  1622. listed here will not show up as function arguments.
  1623. standard : string
  1624. code_gen : CodeGen instance
  1625. An instance of a CodeGen subclass. Overrides ``language``.
  1626. Examples
  1627. ========
  1628. >>> from sympy.utilities.codegen import codegen
  1629. >>> from sympy.abc import x, y, z
  1630. >>> [(c_name, c_code), (h_name, c_header)] = codegen(
  1631. ... ("f", x+y*z), "C89", "test", header=False, empty=False)
  1632. >>> print(c_name)
  1633. test.c
  1634. >>> print(c_code)
  1635. #include "test.h"
  1636. #include <math.h>
  1637. double f(double x, double y, double z) {
  1638. double f_result;
  1639. f_result = x + y*z;
  1640. return f_result;
  1641. }
  1642. <BLANKLINE>
  1643. >>> print(h_name)
  1644. test.h
  1645. >>> print(c_header)
  1646. #ifndef PROJECT__TEST__H
  1647. #define PROJECT__TEST__H
  1648. double f(double x, double y, double z);
  1649. #endif
  1650. <BLANKLINE>
  1651. Another example using Equality objects to give named outputs. Here the
  1652. filename (prefix) is taken from the first (name, expr) pair.
  1653. >>> from sympy.abc import f, g
  1654. >>> from sympy import Eq
  1655. >>> [(c_name, c_code), (h_name, c_header)] = codegen(
  1656. ... [("myfcn", x + y), ("fcn2", [Eq(f, 2*x), Eq(g, y)])],
  1657. ... "C99", header=False, empty=False)
  1658. >>> print(c_name)
  1659. myfcn.c
  1660. >>> print(c_code)
  1661. #include "myfcn.h"
  1662. #include <math.h>
  1663. double myfcn(double x, double y) {
  1664. double myfcn_result;
  1665. myfcn_result = x + y;
  1666. return myfcn_result;
  1667. }
  1668. void fcn2(double x, double y, double *f, double *g) {
  1669. (*f) = 2*x;
  1670. (*g) = y;
  1671. }
  1672. <BLANKLINE>
  1673. If the generated function(s) will be part of a larger project where various
  1674. global variables have been defined, the 'global_vars' option can be used
  1675. to remove the specified variables from the function signature
  1676. >>> from sympy.utilities.codegen import codegen
  1677. >>> from sympy.abc import x, y, z
  1678. >>> [(f_name, f_code), header] = codegen(
  1679. ... ("f", x+y*z), "F95", header=False, empty=False,
  1680. ... argument_sequence=(x, y), global_vars=(z,))
  1681. >>> print(f_code)
  1682. REAL*8 function f(x, y)
  1683. implicit none
  1684. REAL*8, intent(in) :: x
  1685. REAL*8, intent(in) :: y
  1686. f = x + y*z
  1687. end function
  1688. <BLANKLINE>
  1689. """
  1690. # Initialize the code generator.
  1691. if language is None:
  1692. if code_gen is None:
  1693. raise ValueError("Need either language or code_gen")
  1694. else:
  1695. if code_gen is not None:
  1696. raise ValueError("You cannot specify both language and code_gen.")
  1697. code_gen = get_code_generator(language, project, standard, printer)
  1698. if isinstance(name_expr[0], str):
  1699. # single tuple is given, turn it into a singleton list with a tuple.
  1700. name_expr = [name_expr]
  1701. if prefix is None:
  1702. prefix = name_expr[0][0]
  1703. # Construct Routines appropriate for this code_gen from (name, expr) pairs.
  1704. routines = []
  1705. for name, expr in name_expr:
  1706. routines.append(code_gen.routine(name, expr, argument_sequence,
  1707. global_vars))
  1708. # Write the code.
  1709. return code_gen.write(routines, prefix, to_files, header, empty)
  1710. def make_routine(name, expr, argument_sequence=None,
  1711. global_vars=None, language="F95"):
  1712. """A factory that makes an appropriate Routine from an expression.
  1713. Parameters
  1714. ==========
  1715. name : string
  1716. The name of this routine in the generated code.
  1717. expr : expression or list/tuple of expressions
  1718. A SymPy expression that the Routine instance will represent. If
  1719. given a list or tuple of expressions, the routine will be
  1720. considered to have multiple return values and/or output arguments.
  1721. argument_sequence : list or tuple, optional
  1722. List arguments for the routine in a preferred order. If omitted,
  1723. the results are language dependent, for example, alphabetical order
  1724. or in the same order as the given expressions.
  1725. global_vars : iterable, optional
  1726. Sequence of global variables used by the routine. Variables
  1727. listed here will not show up as function arguments.
  1728. language : string, optional
  1729. Specify a target language. The Routine itself should be
  1730. language-agnostic but the precise way one is created, error
  1731. checking, etc depend on the language. [default: "F95"].
  1732. Notes
  1733. =====
  1734. A decision about whether to use output arguments or return values is made
  1735. depending on both the language and the particular mathematical expressions.
  1736. For an expression of type Equality, the left hand side is typically made
  1737. into an OutputArgument (or perhaps an InOutArgument if appropriate).
  1738. Otherwise, typically, the calculated expression is made a return values of
  1739. the routine.
  1740. Examples
  1741. ========
  1742. >>> from sympy.utilities.codegen import make_routine
  1743. >>> from sympy.abc import x, y, f, g
  1744. >>> from sympy import Eq
  1745. >>> r = make_routine('test', [Eq(f, 2*x), Eq(g, x + y)])
  1746. >>> [arg.result_var for arg in r.results]
  1747. []
  1748. >>> [arg.name for arg in r.arguments]
  1749. [x, y, f, g]
  1750. >>> [arg.name for arg in r.result_variables]
  1751. [f, g]
  1752. >>> r.local_vars
  1753. set()
  1754. Another more complicated example with a mixture of specified and
  1755. automatically-assigned names. Also has Matrix output.
  1756. >>> from sympy import Matrix
  1757. >>> r = make_routine('fcn', [x*y, Eq(f, 1), Eq(g, x + g), Matrix([[x, 2]])])
  1758. >>> [arg.result_var for arg in r.results] # doctest: +SKIP
  1759. [result_5397460570204848505]
  1760. >>> [arg.expr for arg in r.results]
  1761. [x*y]
  1762. >>> [arg.name for arg in r.arguments] # doctest: +SKIP
  1763. [x, y, f, g, out_8598435338387848786]
  1764. We can examine the various arguments more closely:
  1765. >>> from sympy.utilities.codegen import (InputArgument, OutputArgument,
  1766. ... InOutArgument)
  1767. >>> [a.name for a in r.arguments if isinstance(a, InputArgument)]
  1768. [x, y]
  1769. >>> [a.name for a in r.arguments if isinstance(a, OutputArgument)] # doctest: +SKIP
  1770. [f, out_8598435338387848786]
  1771. >>> [a.expr for a in r.arguments if isinstance(a, OutputArgument)]
  1772. [1, Matrix([[x, 2]])]
  1773. >>> [a.name for a in r.arguments if isinstance(a, InOutArgument)]
  1774. [g]
  1775. >>> [a.expr for a in r.arguments if isinstance(a, InOutArgument)]
  1776. [g + x]
  1777. """
  1778. # initialize a new code generator
  1779. code_gen = get_code_generator(language)
  1780. return code_gen.routine(name, expr, argument_sequence, global_vars)