numpy.py 20 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499
  1. from sympy.core import S
  2. from .pycode import PythonCodePrinter, _known_functions_math, _print_known_const, _print_known_func, _unpack_integral_limits
  3. from .codeprinter import CodePrinter
  4. _not_in_numpy = 'erf erfc factorial gamma loggamma'.split()
  5. _in_numpy = [(k, v) for k, v in _known_functions_math.items() if k not in _not_in_numpy]
  6. _known_functions_numpy = dict(_in_numpy, **{
  7. 'acos': 'arccos',
  8. 'acosh': 'arccosh',
  9. 'asin': 'arcsin',
  10. 'asinh': 'arcsinh',
  11. 'atan': 'arctan',
  12. 'atan2': 'arctan2',
  13. 'atanh': 'arctanh',
  14. 'exp2': 'exp2',
  15. 'sign': 'sign',
  16. 'logaddexp': 'logaddexp',
  17. 'logaddexp2': 'logaddexp2',
  18. })
  19. _known_constants_numpy = {
  20. 'Exp1': 'e',
  21. 'Pi': 'pi',
  22. 'EulerGamma': 'euler_gamma',
  23. 'NaN': 'nan',
  24. 'Infinity': 'PINF',
  25. 'NegativeInfinity': 'NINF'
  26. }
  27. _numpy_known_functions = {k: 'numpy.' + v for k, v in _known_functions_numpy.items()}
  28. _numpy_known_constants = {k: 'numpy.' + v for k, v in _known_constants_numpy.items()}
  29. class NumPyPrinter(PythonCodePrinter):
  30. """
  31. Numpy printer which handles vectorized piecewise functions,
  32. logical operators, etc.
  33. """
  34. _module = 'numpy'
  35. _kf = _numpy_known_functions
  36. _kc = _numpy_known_constants
  37. def __init__(self, settings=None):
  38. """
  39. `settings` is passed to CodePrinter.__init__()
  40. `module` specifies the array module to use, currently 'NumPy' or 'CuPy'
  41. """
  42. self.language = "Python with {}".format(self._module)
  43. self.printmethod = "_{}code".format(self._module)
  44. self._kf = {**PythonCodePrinter._kf, **self._kf}
  45. super().__init__(settings=settings)
  46. def _print_seq(self, seq):
  47. "General sequence printer: converts to tuple"
  48. # Print tuples here instead of lists because numba supports
  49. # tuples in nopython mode.
  50. delimiter=', '
  51. return '({},)'.format(delimiter.join(self._print(item) for item in seq))
  52. def _print_MatMul(self, expr):
  53. "Matrix multiplication printer"
  54. if expr.as_coeff_matrices()[0] is not S.One:
  55. expr_list = expr.as_coeff_matrices()[1]+[(expr.as_coeff_matrices()[0])]
  56. return '({})'.format(').dot('.join(self._print(i) for i in expr_list))
  57. return '({})'.format(').dot('.join(self._print(i) for i in expr.args))
  58. def _print_MatPow(self, expr):
  59. "Matrix power printer"
  60. return '{}({}, {})'.format(self._module_format(self._module + '.linalg.matrix_power'),
  61. self._print(expr.args[0]), self._print(expr.args[1]))
  62. def _print_Inverse(self, expr):
  63. "Matrix inverse printer"
  64. return '{}({})'.format(self._module_format(self._module + '.linalg.inv'),
  65. self._print(expr.args[0]))
  66. def _print_DotProduct(self, expr):
  67. # DotProduct allows any shape order, but numpy.dot does matrix
  68. # multiplication, so we have to make sure it gets 1 x n by n x 1.
  69. arg1, arg2 = expr.args
  70. if arg1.shape[0] != 1:
  71. arg1 = arg1.T
  72. if arg2.shape[1] != 1:
  73. arg2 = arg2.T
  74. return "%s(%s, %s)" % (self._module_format(self._module + '.dot'),
  75. self._print(arg1),
  76. self._print(arg2))
  77. def _print_MatrixSolve(self, expr):
  78. return "%s(%s, %s)" % (self._module_format(self._module + '.linalg.solve'),
  79. self._print(expr.matrix),
  80. self._print(expr.vector))
  81. def _print_ZeroMatrix(self, expr):
  82. return '{}({})'.format(self._module_format(self._module + '.zeros'),
  83. self._print(expr.shape))
  84. def _print_OneMatrix(self, expr):
  85. return '{}({})'.format(self._module_format(self._module + '.ones'),
  86. self._print(expr.shape))
  87. def _print_FunctionMatrix(self, expr):
  88. from sympy.core.function import Lambda
  89. from sympy.abc import i, j
  90. lamda = expr.lamda
  91. if not isinstance(lamda, Lambda):
  92. lamda = Lambda((i, j), lamda(i, j))
  93. return '{}(lambda {}: {}, {})'.format(self._module_format(self._module + '.fromfunction'),
  94. ', '.join(self._print(arg) for arg in lamda.args[0]),
  95. self._print(lamda.args[1]), self._print(expr.shape))
  96. def _print_HadamardProduct(self, expr):
  97. func = self._module_format(self._module + '.multiply')
  98. return ''.join('{}({}, '.format(func, self._print(arg)) \
  99. for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
  100. ')' * (len(expr.args) - 1))
  101. def _print_KroneckerProduct(self, expr):
  102. func = self._module_format(self._module + '.kron')
  103. return ''.join('{}({}, '.format(func, self._print(arg)) \
  104. for arg in expr.args[:-1]) + "{}{}".format(self._print(expr.args[-1]),
  105. ')' * (len(expr.args) - 1))
  106. def _print_Adjoint(self, expr):
  107. return '{}({}({}))'.format(
  108. self._module_format(self._module + '.conjugate'),
  109. self._module_format(self._module + '.transpose'),
  110. self._print(expr.args[0]))
  111. def _print_DiagonalOf(self, expr):
  112. vect = '{}({})'.format(
  113. self._module_format(self._module + '.diag'),
  114. self._print(expr.arg))
  115. return '{}({}, (-1, 1))'.format(
  116. self._module_format(self._module + '.reshape'), vect)
  117. def _print_DiagMatrix(self, expr):
  118. return '{}({})'.format(self._module_format(self._module + '.diagflat'),
  119. self._print(expr.args[0]))
  120. def _print_DiagonalMatrix(self, expr):
  121. return '{}({}, {}({}, {}))'.format(self._module_format(self._module + '.multiply'),
  122. self._print(expr.arg), self._module_format(self._module + '.eye'),
  123. self._print(expr.shape[0]), self._print(expr.shape[1]))
  124. def _print_Piecewise(self, expr):
  125. "Piecewise function printer"
  126. from sympy.logic.boolalg import ITE, simplify_logic
  127. def print_cond(cond):
  128. """ Problem having an ITE in the cond. """
  129. if cond.has(ITE):
  130. return self._print(simplify_logic(cond))
  131. else:
  132. return self._print(cond)
  133. exprs = '[{}]'.format(','.join(self._print(arg.expr) for arg in expr.args))
  134. conds = '[{}]'.format(','.join(print_cond(arg.cond) for arg in expr.args))
  135. # If [default_value, True] is a (expr, cond) sequence in a Piecewise object
  136. # it will behave the same as passing the 'default' kwarg to select()
  137. # *as long as* it is the last element in expr.args.
  138. # If this is not the case, it may be triggered prematurely.
  139. return '{}({}, {}, default={})'.format(
  140. self._module_format(self._module + '.select'), conds, exprs,
  141. self._print(S.NaN))
  142. def _print_Relational(self, expr):
  143. "Relational printer for Equality and Unequality"
  144. op = {
  145. '==' :'equal',
  146. '!=' :'not_equal',
  147. '<' :'less',
  148. '<=' :'less_equal',
  149. '>' :'greater',
  150. '>=' :'greater_equal',
  151. }
  152. if expr.rel_op in op:
  153. lhs = self._print(expr.lhs)
  154. rhs = self._print(expr.rhs)
  155. return '{op}({lhs}, {rhs})'.format(op=self._module_format(self._module + '.'+op[expr.rel_op]),
  156. lhs=lhs, rhs=rhs)
  157. return super()._print_Relational(expr)
  158. def _print_And(self, expr):
  159. "Logical And printer"
  160. # We have to override LambdaPrinter because it uses Python 'and' keyword.
  161. # If LambdaPrinter didn't define it, we could use StrPrinter's
  162. # version of the function and add 'logical_and' to NUMPY_TRANSLATIONS.
  163. return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_and'), ','.join(self._print(i) for i in expr.args))
  164. def _print_Or(self, expr):
  165. "Logical Or printer"
  166. # We have to override LambdaPrinter because it uses Python 'or' keyword.
  167. # If LambdaPrinter didn't define it, we could use StrPrinter's
  168. # version of the function and add 'logical_or' to NUMPY_TRANSLATIONS.
  169. return '{}.reduce(({}))'.format(self._module_format(self._module + '.logical_or'), ','.join(self._print(i) for i in expr.args))
  170. def _print_Not(self, expr):
  171. "Logical Not printer"
  172. # We have to override LambdaPrinter because it uses Python 'not' keyword.
  173. # If LambdaPrinter didn't define it, we would still have to define our
  174. # own because StrPrinter doesn't define it.
  175. return '{}({})'.format(self._module_format(self._module + '.logical_not'), ','.join(self._print(i) for i in expr.args))
  176. def _print_Pow(self, expr, rational=False):
  177. # XXX Workaround for negative integer power error
  178. from sympy.core.power import Pow
  179. if expr.exp.is_integer and expr.exp.is_negative:
  180. expr = Pow(expr.base, expr.exp.evalf(), evaluate=False)
  181. return self._hprint_Pow(expr, rational=rational, sqrt=self._module + '.sqrt')
  182. def _print_Min(self, expr):
  183. return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amin'), ','.join(self._print(i) for i in expr.args))
  184. def _print_Max(self, expr):
  185. return '{}(({}), axis=0)'.format(self._module_format(self._module + '.amax'), ','.join(self._print(i) for i in expr.args))
  186. def _print_arg(self, expr):
  187. return "%s(%s)" % (self._module_format(self._module + '.angle'), self._print(expr.args[0]))
  188. def _print_im(self, expr):
  189. return "%s(%s)" % (self._module_format(self._module + '.imag'), self._print(expr.args[0]))
  190. def _print_Mod(self, expr):
  191. return "%s(%s)" % (self._module_format(self._module + '.mod'), ', '.join(
  192. map(lambda arg: self._print(arg), expr.args)))
  193. def _print_re(self, expr):
  194. return "%s(%s)" % (self._module_format(self._module + '.real'), self._print(expr.args[0]))
  195. def _print_sinc(self, expr):
  196. return "%s(%s)" % (self._module_format(self._module + '.sinc'), self._print(expr.args[0]/S.Pi))
  197. def _print_MatrixBase(self, expr):
  198. func = self.known_functions.get(expr.__class__.__name__, None)
  199. if func is None:
  200. func = self._module_format(self._module + '.array')
  201. return "%s(%s)" % (func, self._print(expr.tolist()))
  202. def _print_Identity(self, expr):
  203. shape = expr.shape
  204. if all(dim.is_Integer for dim in shape):
  205. return "%s(%s)" % (self._module_format(self._module + '.eye'), self._print(expr.shape[0]))
  206. else:
  207. raise NotImplementedError("Symbolic matrix dimensions are not yet supported for identity matrices")
  208. def _print_BlockMatrix(self, expr):
  209. return '{}({})'.format(self._module_format(self._module + '.block'),
  210. self._print(expr.args[0].tolist()))
  211. def _print_ArrayTensorProduct(self, expr):
  212. array_list = [j for i, arg in enumerate(expr.args) for j in
  213. (self._print(arg), "[%i, %i]" % (2*i, 2*i+1))]
  214. return "%s(%s)" % (self._module_format(self._module + '.einsum'), ", ".join(array_list))
  215. def _print_ArrayContraction(self, expr):
  216. from ..tensor.array.expressions.array_expressions import ArrayTensorProduct
  217. base = expr.expr
  218. contraction_indices = expr.contraction_indices
  219. if not contraction_indices:
  220. return self._print(base)
  221. if isinstance(base, ArrayTensorProduct):
  222. counter = 0
  223. d = {j: min(i) for i in contraction_indices for j in i}
  224. indices = []
  225. for rank_arg in base.subranks:
  226. lindices = []
  227. for i in range(rank_arg):
  228. if counter in d:
  229. lindices.append(d[counter])
  230. else:
  231. lindices.append(counter)
  232. counter += 1
  233. indices.append(lindices)
  234. elems = ["%s, %s" % (self._print(arg), ind) for arg, ind in zip(base.args, indices)]
  235. return "%s(%s)" % (
  236. self._module_format(self._module + '.einsum'),
  237. ", ".join(elems)
  238. )
  239. raise NotImplementedError()
  240. def _print_ArrayDiagonal(self, expr):
  241. diagonal_indices = list(expr.diagonal_indices)
  242. if len(diagonal_indices) > 1:
  243. # TODO: this should be handled in sympy.codegen.array_utils,
  244. # possibly by creating the possibility of unfolding the
  245. # ArrayDiagonal object into nested ones. Same reasoning for
  246. # the array contraction.
  247. raise NotImplementedError
  248. if len(diagonal_indices[0]) != 2:
  249. raise NotImplementedError
  250. return "%s(%s, 0, axis1=%s, axis2=%s)" % (
  251. self._module_format("numpy.diagonal"),
  252. self._print(expr.expr),
  253. diagonal_indices[0][0],
  254. diagonal_indices[0][1],
  255. )
  256. def _print_PermuteDims(self, expr):
  257. return "%s(%s, %s)" % (
  258. self._module_format("numpy.transpose"),
  259. self._print(expr.expr),
  260. self._print(expr.permutation.array_form),
  261. )
  262. def _print_ArrayAdd(self, expr):
  263. return self._expand_fold_binary_op(self._module + '.add', expr.args)
  264. def _print_NDimArray(self, expr):
  265. if len(expr.shape) == 1:
  266. return self._module + '.array(' + self._print(expr.args[0]) + ')'
  267. if len(expr.shape) == 2:
  268. return self._print(expr.tomatrix())
  269. # Should be possible to extend to more dimensions
  270. return CodePrinter._print_not_supported(self, expr)
  271. _print_lowergamma = CodePrinter._print_not_supported
  272. _print_uppergamma = CodePrinter._print_not_supported
  273. _print_fresnelc = CodePrinter._print_not_supported
  274. _print_fresnels = CodePrinter._print_not_supported
  275. for func in _numpy_known_functions:
  276. setattr(NumPyPrinter, f'_print_{func}', _print_known_func)
  277. for const in _numpy_known_constants:
  278. setattr(NumPyPrinter, f'_print_{const}', _print_known_const)
  279. _known_functions_scipy_special = {
  280. 'erf': 'erf',
  281. 'erfc': 'erfc',
  282. 'besselj': 'jv',
  283. 'bessely': 'yv',
  284. 'besseli': 'iv',
  285. 'besselk': 'kv',
  286. 'cosm1': 'cosm1',
  287. 'factorial': 'factorial',
  288. 'gamma': 'gamma',
  289. 'loggamma': 'gammaln',
  290. 'digamma': 'psi',
  291. 'RisingFactorial': 'poch',
  292. 'jacobi': 'eval_jacobi',
  293. 'gegenbauer': 'eval_gegenbauer',
  294. 'chebyshevt': 'eval_chebyt',
  295. 'chebyshevu': 'eval_chebyu',
  296. 'legendre': 'eval_legendre',
  297. 'hermite': 'eval_hermite',
  298. 'laguerre': 'eval_laguerre',
  299. 'assoc_laguerre': 'eval_genlaguerre',
  300. 'beta': 'beta',
  301. 'LambertW' : 'lambertw',
  302. }
  303. _known_constants_scipy_constants = {
  304. 'GoldenRatio': 'golden_ratio',
  305. 'Pi': 'pi',
  306. }
  307. _scipy_known_functions = {k : "scipy.special." + v for k, v in _known_functions_scipy_special.items()}
  308. _scipy_known_constants = {k : "scipy.constants." + v for k, v in _known_constants_scipy_constants.items()}
  309. class SciPyPrinter(NumPyPrinter):
  310. _kf = {**NumPyPrinter._kf, **_scipy_known_functions}
  311. _kc = {**NumPyPrinter._kc, **_scipy_known_constants}
  312. def __init__(self, settings=None):
  313. super().__init__(settings=settings)
  314. self.language = "Python with SciPy and NumPy"
  315. def _print_SparseRepMatrix(self, expr):
  316. i, j, data = [], [], []
  317. for (r, c), v in expr.todok().items():
  318. i.append(r)
  319. j.append(c)
  320. data.append(v)
  321. return "{name}(({data}, ({i}, {j})), shape={shape})".format(
  322. name=self._module_format('scipy.sparse.coo_matrix'),
  323. data=data, i=i, j=j, shape=expr.shape
  324. )
  325. _print_ImmutableSparseMatrix = _print_SparseRepMatrix
  326. # SciPy's lpmv has a different order of arguments from assoc_legendre
  327. def _print_assoc_legendre(self, expr):
  328. return "{0}({2}, {1}, {3})".format(
  329. self._module_format('scipy.special.lpmv'),
  330. self._print(expr.args[0]),
  331. self._print(expr.args[1]),
  332. self._print(expr.args[2]))
  333. def _print_lowergamma(self, expr):
  334. return "{0}({2})*{1}({2}, {3})".format(
  335. self._module_format('scipy.special.gamma'),
  336. self._module_format('scipy.special.gammainc'),
  337. self._print(expr.args[0]),
  338. self._print(expr.args[1]))
  339. def _print_uppergamma(self, expr):
  340. return "{0}({2})*{1}({2}, {3})".format(
  341. self._module_format('scipy.special.gamma'),
  342. self._module_format('scipy.special.gammaincc'),
  343. self._print(expr.args[0]),
  344. self._print(expr.args[1]))
  345. def _print_betainc(self, expr):
  346. betainc = self._module_format('scipy.special.betainc')
  347. beta = self._module_format('scipy.special.beta')
  348. args = [self._print(arg) for arg in expr.args]
  349. return f"({betainc}({args[0]}, {args[1]}, {args[3]}) - {betainc}({args[0]}, {args[1]}, {args[2]})) \
  350. * {beta}({args[0]}, {args[1]})"
  351. def _print_betainc_regularized(self, expr):
  352. return "{0}({1}, {2}, {4}) - {0}({1}, {2}, {3})".format(
  353. self._module_format('scipy.special.betainc'),
  354. self._print(expr.args[0]),
  355. self._print(expr.args[1]),
  356. self._print(expr.args[2]),
  357. self._print(expr.args[3]))
  358. def _print_fresnels(self, expr):
  359. return "{}({})[0]".format(
  360. self._module_format("scipy.special.fresnel"),
  361. self._print(expr.args[0]))
  362. def _print_fresnelc(self, expr):
  363. return "{}({})[1]".format(
  364. self._module_format("scipy.special.fresnel"),
  365. self._print(expr.args[0]))
  366. def _print_airyai(self, expr):
  367. return "{}({})[0]".format(
  368. self._module_format("scipy.special.airy"),
  369. self._print(expr.args[0]))
  370. def _print_airyaiprime(self, expr):
  371. return "{}({})[1]".format(
  372. self._module_format("scipy.special.airy"),
  373. self._print(expr.args[0]))
  374. def _print_airybi(self, expr):
  375. return "{}({})[2]".format(
  376. self._module_format("scipy.special.airy"),
  377. self._print(expr.args[0]))
  378. def _print_airybiprime(self, expr):
  379. return "{}({})[3]".format(
  380. self._module_format("scipy.special.airy"),
  381. self._print(expr.args[0]))
  382. def _print_Integral(self, e):
  383. integration_vars, limits = _unpack_integral_limits(e)
  384. if len(limits) == 1:
  385. # nicer (but not necessary) to prefer quad over nquad for 1D case
  386. module_str = self._module_format("scipy.integrate.quad")
  387. limit_str = "%s, %s" % tuple(map(self._print, limits[0]))
  388. else:
  389. module_str = self._module_format("scipy.integrate.nquad")
  390. limit_str = "({})".format(", ".join(
  391. "(%s, %s)" % tuple(map(self._print, l)) for l in limits))
  392. return "{}(lambda {}: {}, {})[0]".format(
  393. module_str,
  394. ", ".join(map(self._print, integration_vars)),
  395. self._print(e.args[0]),
  396. limit_str)
  397. for func in _scipy_known_functions:
  398. setattr(SciPyPrinter, f'_print_{func}', _print_known_func)
  399. for const in _scipy_known_constants:
  400. setattr(SciPyPrinter, f'_print_{const}', _print_known_const)
  401. _cupy_known_functions = {k : "cupy." + v for k, v in _known_functions_numpy.items()}
  402. _cupy_known_constants = {k : "cupy." + v for k, v in _known_constants_numpy.items()}
  403. class CuPyPrinter(NumPyPrinter):
  404. """
  405. CuPy printer which handles vectorized piecewise functions,
  406. logical operators, etc.
  407. """
  408. _module = 'cupy'
  409. _kf = _cupy_known_functions
  410. _kc = _cupy_known_constants
  411. def __init__(self, settings=None):
  412. super().__init__(settings=settings)
  413. for func in _cupy_known_functions:
  414. setattr(CuPyPrinter, f'_print_{func}', _print_known_func)
  415. for const in _cupy_known_constants:
  416. setattr(CuPyPrinter, f'_print_{const}', _print_known_const)