gammasimp.py 18 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503
  1. from sympy.core import Function, S, Mul, Pow, Add
  2. from sympy.core.sorting import ordered, default_sort_key
  3. from sympy.core.function import expand_func
  4. from sympy.core.symbol import Dummy
  5. from sympy.functions import gamma, sqrt, sin
  6. from sympy.polys import factor, cancel
  7. from sympy.utilities.iterables import sift, uniq
  8. def gammasimp(expr):
  9. r"""
  10. Simplify expressions with gamma functions.
  11. Explanation
  12. ===========
  13. This function takes as input an expression containing gamma
  14. functions or functions that can be rewritten in terms of gamma
  15. functions and tries to minimize the number of those functions and
  16. reduce the size of their arguments.
  17. The algorithm works by rewriting all gamma functions as expressions
  18. involving rising factorials (Pochhammer symbols) and applies
  19. recurrence relations and other transformations applicable to rising
  20. factorials, to reduce their arguments, possibly letting the resulting
  21. rising factorial to cancel. Rising factorials with the second argument
  22. being an integer are expanded into polynomial forms and finally all
  23. other rising factorial are rewritten in terms of gamma functions.
  24. Then the following two steps are performed.
  25. 1. Reduce the number of gammas by applying the reflection theorem
  26. gamma(x)*gamma(1-x) == pi/sin(pi*x).
  27. 2. Reduce the number of gammas by applying the multiplication theorem
  28. gamma(x)*gamma(x+1/n)*...*gamma(x+(n-1)/n) == C*gamma(n*x).
  29. It then reduces the number of prefactors by absorbing them into gammas
  30. where possible and expands gammas with rational argument.
  31. All transformation rules can be found (or were derived from) here:
  32. .. [1] http://functions.wolfram.com/GammaBetaErf/Pochhammer/17/01/02/
  33. .. [2] http://functions.wolfram.com/GammaBetaErf/Pochhammer/27/01/0005/
  34. Examples
  35. ========
  36. >>> from sympy.simplify import gammasimp
  37. >>> from sympy import gamma, Symbol
  38. >>> from sympy.abc import x
  39. >>> n = Symbol('n', integer = True)
  40. >>> gammasimp(gamma(x)/gamma(x - 3))
  41. (x - 3)*(x - 2)*(x - 1)
  42. >>> gammasimp(gamma(n + 3))
  43. gamma(n + 3)
  44. """
  45. expr = expr.rewrite(gamma)
  46. # compute_ST will be looking for Functions and we don't want
  47. # it looking for non-gamma functions: issue 22606
  48. # so we mask free, non-gamma functions
  49. f = expr.atoms(Function)
  50. # take out gammas
  51. gammas = {i for i in f if isinstance(i, gamma)}
  52. if not gammas:
  53. return expr # avoid side effects like factoring
  54. f -= gammas
  55. # keep only those without bound symbols
  56. f = f & expr.as_dummy().atoms(Function)
  57. if f:
  58. dum, fun, simp = zip(*[
  59. (Dummy(), fi, fi.func(*[
  60. _gammasimp(a, as_comb=False) for a in fi.args]))
  61. for fi in ordered(f)])
  62. d = expr.xreplace(dict(zip(fun, dum)))
  63. return _gammasimp(d, as_comb=False).xreplace(dict(zip(dum, simp)))
  64. return _gammasimp(expr, as_comb=False)
def _gammasimp(expr, as_comb):
    """
    Helper function for gammasimp and combsimp.

    Explanation
    ===========

    Simplifies expressions written in terms of gamma function. If
    as_comb is True, it tries to preserve integer arguments. See
    docstring of gammasimp for more information. This was part of
    combsimp() in combsimp.py.

    Outline of what is done below:

    1. Every ``gamma(n)`` is rewritten as ``_rf(1, n - 1)`` (so that
       ``_rf.eval`` can expand or split it), and the remaining ``_rf(a, b)``
       objects are turned back into ``gamma(b + 1)`` when ``as_comb`` is
       True, or ``gamma(a + b)/gamma(a)`` otherwise.
    2. The expression is factored and passed through the nested
       ``rule_gamma``. Unless ``as_comb`` is True, this applies the
       reflection, duplication and multiplication theorems to the collected
       gamma arguments. In both modes it then absorbs factors that differ
       from a gamma argument (or argument minus 1) only by a symbol-free
       ratio.
    3. If ``rule_gamma`` changed anything the result is factored again, and
       finally gammas with a Rational argument are expanded with
       ``expand_func``.
    """
    # gamma(n) -> _rf(1, n - 1); _rf.eval expands integer second arguments
    # and splits off integer offsets as soon as the object is built.
    expr = expr.replace(gamma,
        lambda n: _rf(1, (n - 1).expand()))

    # Convert whatever _rf objects remain back into gamma functions.
    if as_comb:
        # combsimp mode: _rf(a, b) -> gamma(b + 1); the first argument is
        # dropped
        expr = expr.replace(_rf,
            lambda a, b: gamma(b + 1))
    else:
        # gammasimp mode: _rf(a, b) -> gamma(a + b)/gamma(a)
        expr = expr.replace(_rf,
            lambda a, b: gamma(a + b)/gamma(a))

    def rule_gamma(expr, level=0):
        """ Simplify products of gamma functions further.

        ``level`` selects the stage to start at; the stages fall through,
        each one incrementing ``level`` unless it returns early:

        * 0 -- apply rule_gamma to every argument of ``expr``;
        * 1 -- for a Mul, handle commutative and non-commutative factors
          separately;
        * 2 -- simplify Add factors of the numerator/denominator of the
          gamma-containing part;
        * 3 -- repeat the level-4 pass until the result stops changing;
        * 4 -- collect gamma arguments and other factors, apply the gamma
          identities and absorb factors into the gammas.
        """

        if expr.is_Atom:
            return expr

        def gamma_rat(x):
            # helper to simplify ratios of gammas: redo the
            # gamma -> _rf -> gamma round trip for each gamma in x and keep
            # the result only if it contains fewer gamma functions
            was = x.count(gamma)
            xx = x.replace(gamma, lambda n: _rf(1, (n - 1).expand()
                ).replace(_rf, lambda a, b: gamma(a + b)/gamma(a)))
            if xx.count(gamma) < was:
                x = xx
            return x

        def gamma_factor(x):
            # return True if there is a gamma factor in shallow args
            # (looking through Add, Mul, and Pow with an integer exponent
            # or a positive base)
            if isinstance(x, gamma):
                return True
            if x.is_Add or x.is_Mul:
                return any(gamma_factor(xi) for xi in x.args)
            if x.is_Pow and (x.exp.is_integer or x.base.is_positive):
                return gamma_factor(x.base)
            return False

        # recursion step
        if level == 0:
            expr = expr.func(*[rule_gamma(x, level + 1) for x in expr.args])
            level += 1

        # every later stage only deals with products
        if not expr.is_Mul:
            return expr

        # non-commutative step: simplify only the commutative part and
        # multiply the non-commutative factors back on unchanged
        if level == 1:
            args, nc = expr.args_cnc()
            if not args:
                return expr
            if nc:
                return rule_gamma(Mul._from_args(args), level + 1)*Mul._from_args(nc)
            level += 1

        # pure gamma handling, not factor absorption
        if level == 2:
            # T: factors containing a gamma, F: everything else (kept aside)
            T, F = sift(expr.args, gamma_factor, binary=True)
            gamma_ind = Mul(*F)
            d = Mul(*T)

            nd, dd = d.as_numer_denom()
            # Pass 0 works on the Add factors of the numerator nd: the terms
            # of each such factor are divided by dd, simplified with
            # gamma_rat and rule_gamma(..., 3), recombined, and split again
            # into a new factor and a new dd. Pass 1 (only when nd still has
            # a gamma factor) repeats this with the roles of nd and dd
            # swapped; the swap at the end of pass 1 restores them.
            for ipass in range(2):
                args = list(ordered(Mul.make_args(nd)))
                for i, ni in enumerate(args):
                    if ni.is_Add:
                        ni, dd = Add(*[
                            rule_gamma(gamma_rat(a/dd), level + 1) for a in ni.args]
                            ).as_numer_denom()
                        args[i] = ni
                        # nothing left to cancel against once dd is gamma-free
                        if not dd.has(gamma):
                            break
                nd = Mul(*args)
                if ipass == 0 and not gamma_factor(nd):
                    break
                nd, dd = dd, nd  # now process in reversed order
            expr = gamma_ind*nd/dd
            if not (expr.is_Mul and (gamma_factor(dd) or gamma_factor(nd))):
                return expr
            level += 1

        # iteration until constant
        if level == 3:
            while True:
                was = expr
                expr = rule_gamma(expr, 4)
                if expr == was:
                    return expr

        # Only reached with level == 4 (the level-3 loop above always
        # returns), i.e. from the rule_gamma(expr, 4) call, with expr a Mul.
        # Collect gamma arguments and the remaining factors of the
        # numerator and denominator.
        numer_gammas = []
        denom_gammas = []
        numer_others = []
        denom_others = []
        def explicate(p):
            # Classify one numerator/denominator factor p:
            #   (None, [])          for p == 1,
            #   (True, [arg]*e)     for gamma(arg)**e with Integer e,
            #   (False, [base]*e)   for any other base with Integer e,
            #   (False, [p])        otherwise.
            # NOTE(review): [x]*e is empty for e <= 0; presumably e > 0 here
            # because as_numer_denom has already separated the factor --
            # confirm.
            if p is S.One:
                return None, []
            b, e = p.as_base_exp()
            if e.is_Integer:
                if isinstance(b, gamma):
                    return True, [b.args[0]]*e
                else:
                    return False, [b]*e
            else:
                return False, [p]

        newargs = list(ordered(expr.args))
        while newargs:
            n, d = newargs.pop().as_numer_denom()
            isg, l = explicate(n)
            if isg:
                numer_gammas.extend(l)
            elif isg is False:
                numer_others.extend(l)
            isg, l = explicate(d)
            if isg:
                denom_gammas.extend(l)
            elif isg is False:
                denom_others.extend(l)

        # =========== level 2 work: pure gamma manipulation =========

        if not as_comb:
            # Try to reduce the number of gamma factors by applying the
            # reflection formula gamma(x)*gamma(1-x) = pi/sin(pi*x)
            # (for denominator gammas, numer/denom below are swapped so the
            # pi and sin factors land on the opposite side)
            for gammas, numer, denom in [(
                numer_gammas, numer_others, denom_others),
                    (denom_gammas, denom_others, numer_others)]:
                new = []
                while gammas:
                    g1 = gammas.pop()
                    if g1.is_integer:
                        new.append(g1)
                        continue
                    for i, g2 in enumerate(gammas):
                        # g2 == 1 - g1 + n for an Integer n
                        n = g1 + g2 - 1
                        if not n.is_Integer:
                            continue
                        numer.append(S.Pi)
                        denom.append(sin(S.Pi*g1))
                        gammas.pop(i)
                        # shift gamma(1 - g1 + n) back to gamma(1 - g1)
                        if n > 0:
                            for k in range(n):
                                numer.append(1 - g1 + k)
                        elif n < 0:
                            for k in range(-n):
                                denom.append(-g1 - k)
                        break
                    else:
                        new.append(g1)
                # /!\ updating IN PLACE
                gammas[:] = new

            # Try to reduce the number of gammas by using the duplication
            # theorem to cancel an upper and lower: gamma(2*s)/gamma(s) =
            # 2**(2*s + 1)/(4*sqrt(pi))*gamma(s + 1/2). Although this could
            # be done with higher argument ratios like gamma(3*x)/gamma(x),
            # this would not reduce the number of gammas as in this case.
            for ng, dg, no, do in [(numer_gammas, denom_gammas, numer_others,
                                    denom_others),
                                   (denom_gammas, numer_gammas, denom_others,
                                    numer_others)]:

                while True:
                    # find x in ng, y in dg with x == 2*y + n, n an Integer;
                    # stop when no such pair is left
                    for x in ng:
                        for y in dg:
                            n = x - 2*y
                            if n.is_Integer:
                                break
                        else:
                            continue
                        break
                    else:
                        break
                    ng.remove(x)
                    dg.remove(y)
                    # shift gamma(2*y + n) back to gamma(2*y)
                    if n > 0:
                        for k in range(n):
                            no.append(2*y + k)
                    elif n < 0:
                        for k in range(-n):
                            do.append(2*y - 1 - k)
                    ng.append(y + S.Half)
                    no.append(2**(2*y - 1))
                    do.append(sqrt(S.Pi))

            # Try to reduce the number of gamma factors by applying the
            # multiplication theorem (used when n gammas with args differing
            # by 1/n mod 1 are encountered).
            #
            # run of 2 with args differing by 1/2
            #
            # >>> gammasimp(gamma(x)*gamma(x+S.Half))
            # 2*sqrt(2)*2**(-2*x - 1/2)*sqrt(pi)*gamma(2*x)
            #
            # run of 3 args differing by 1/3 (mod 1)
            #
            # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(2)/3))
            # 6*3**(-3*x - 1/2)*pi*gamma(3*x)
            # >>> gammasimp(gamma(x)*gamma(x+S(1)/3)*gamma(x+S(5)/3))
            # 2*3**(-3*x - 1/2)*pi*(3*x + 2)*gamma(3*x)
            #
            def _run(coeffs):
                # find runs in coeffs such that the difference in terms (mod 1)
                # of t1, t2, ..., tn is 1/n
                # Returns (n, first coeff, other coeffs) and removes the run
                # from coeffs IN PLACE, or returns None when no run exists.
                # NOTE(review): relies on .p/.q, i.e. assumes the
                # coefficients from as_coeff_Add are Rationals -- confirm
                # what happens for Float coefficients.
                u = list(uniq(coeffs))
                for i in range(len(u)):
                    dj = ([((u[j] - u[i]) % 1, j) for j in range(i + 1, len(u))])
                    for one, j in dj:
                        if one.p == 1 and one.q != 1:
                            n = one.q
                            got = [i]
                            get = list(range(1, n))
                            for d, j in dj:
                                m = n*d
                                if m.is_Integer and m in get:
                                    get.remove(m)
                                    got.append(j)
                                    if not get:
                                        break
                            else:
                                # not all offsets k/n were found
                                continue
                            for i, j in enumerate(got):
                                c = u[j]
                                coeffs.remove(c)
                                got[i] = c
                            return one.q, got[0], got[1:]

            def _mult_thm(gammas, numer, denom):
                # pull off and analyze the leading coefficient from each gamma arg
                # looking for runs in those Rationals
                # (denom is not used; numer receives all new factors)

                # expr -> coeff + resid -> rats[resid] = coeff
                rats = {}
                for g in gammas:
                    c, resid = g.as_coeff_Add()
                    rats.setdefault(resid, []).append(c)

                # look for runs in Rationals for each resid
                keys = sorted(rats, key=default_sort_key)
                for resid in keys:
                    coeffs = list(sorted(rats[resid]))
                    new = []
                    while True:
                        run = _run(coeffs)
                        if run is None:
                            break

                        # process the sequence that was found:
                        # 1) convert all the gamma functions to have the right
                        #    argument (could be off by an integer)
                        # 2) append the factors corresponding to the theorem
                        # 3) append the new gamma function

                        n, ui, other = run

                        # (1)
                        for u in other:
                            con = resid + u - 1
                            for k in range(int(u - ui)):
                                numer.append(con - k)

                        con = n*(resid + ui)  # for (2) and (3)

                        # (2)
                        numer.append((2*S.Pi)**(S(n - 1)/2)*
                                     n**(S.Half - con))
                        # (3)
                        new.append(con)

                    # restore resid to coeffs
                    rats[resid] = [resid + c for c in coeffs] + new

                # rebuild the gamma arguments
                g = []
                for resid in keys:
                    g += rats[resid]
                # /!\ updating IN PLACE
                gammas[:] = g

            for l, numer, denom in [(numer_gammas, numer_others, denom_others),
                                    (denom_gammas, denom_others, numer_others)]:
                _mult_thm(l, numer, denom)

        # =========== level >= 2 work: factor absorption =========
        # (always true here, since this code only runs at level 4)

        if level >= 2:
            # Try to absorb factors into the gammas: x*gamma(x) -> gamma(x + 1)
            # and gamma(x)/(x - 1) -> gamma(x - 1)
            # This code (in particular repeated calls to find_fuzzy) can be very
            # slow.
            def find_fuzzy(l, x):
                # Return the first y in l whose invariants match those of x
                # and for which cancel(x/y) has no free symbols (while x or
                # y has some); return None if there is no such y.
                if not l:
                    return
                S1, T1 = compute_ST(x)
                for y in l:
                    S2, T2 = inv[y]
                    if T1 != T2 or (not S1.intersection(S2) and
                                    (S1 != set() or S2 != set())):
                        continue
                    # XXX we want some simplification (e.g. cancel or
                    # simplify) but no matter what it's slow.
                    a = len(cancel(x/y).free_symbols)
                    b = len(x.free_symbols)
                    c = len(y.free_symbols)
                    # TODO is there a better heuristic?
                    if a == 0 and (b > 0 or c > 0):
                        return y

            # We thus try to avoid expensive calls by building the following
            # "invariants": For every factor or gamma function argument
            #   - the set of free symbols S
            #   - the set of functional components T
            # We will only try to absorb if T1==T2 and (S1 intersect S2 != emptyset
            # or S1 == S2 == emptyset)
            inv = {}

            def compute_ST(expr):
                if expr in inv:
                    return inv[expr]
                # T also includes the exponents of all powers in expr
                return (expr.free_symbols, expr.atoms(Function).union(
                        {e.exp for e in expr.atoms(Pow)}))

            def update_ST(expr):
                inv[expr] = compute_ST(expr)
            # NOTE: this loop rebinds expr; it is not used again below
            for expr in numer_gammas + denom_gammas + numer_others + denom_others:
                update_ST(expr)

            for gammas, numer, denom in [(
                numer_gammas, numer_others, denom_others),
                    (denom_gammas, denom_others, numer_others)]:
                new = []
                while gammas:
                    g = gammas.pop()
                    cont = True
                    while cont:
                        cont = False
                        # same-side factor y == r*g (r symbol-free):
                        # y*gamma(g) -> r*gamma(g + 1)
                        y = find_fuzzy(numer, g)
                        if y is not None:
                            numer.remove(y)
                            if y != g:
                                numer.append(y/g)
                                update_ST(y/g)
                            g += 1
                            cont = True
                        # opposite-side factor y == r*(g - 1):
                        # gamma(g)/y -> gamma(g - 1)/r
                        y = find_fuzzy(denom, g - 1)
                        if y is not None:
                            denom.remove(y)
                            if y != g - 1:
                                numer.append((g - 1)/y)
                                update_ST((g - 1)/y)
                            g -= 1
                            cont = True
                    new.append(g)
                # /!\ updating IN PLACE
                gammas[:] = new

        # =========== rebuild expr ==================================

        return Mul(*[gamma(g) for g in numer_gammas]) \
            / Mul(*[gamma(g) for g in denom_gammas]) \
            * Mul(*numer_others) / Mul(*denom_others)

    was = factor(expr)
    # (for some reason we cannot use Basic.replace in this case)
    expr = rule_gamma(was)
    if expr != was:
        expr = factor(expr)

    # expand gamma functions whose argument is a Rational number
    expr = expr.replace(gamma,
        lambda n: expand_func(gamma(n)) if n.is_Rational else gamma(n))

    return expr
  404. class _rf(Function):
  405. @classmethod
  406. def eval(cls, a, b):
  407. if b.is_Integer:
  408. if not b:
  409. return S.One
  410. n, result = int(b), S.One
  411. if n > 0:
  412. for i in range(n):
  413. result *= a + i
  414. return result
  415. elif n < 0:
  416. for i in range(1, -n + 1):
  417. result *= a - i
  418. return 1/result
  419. else:
  420. if b.is_Add:
  421. c, _b = b.as_coeff_Add()
  422. if c.is_Integer:
  423. if c > 0:
  424. return _rf(a, _b)*_rf(a + _b, c)
  425. elif c < 0:
  426. return _rf(a, _b)/_rf(a + _b + c, -c)
  427. if a.is_Add:
  428. c, _a = a.as_coeff_Add()
  429. if c.is_Integer:
  430. if c > 0:
  431. return _rf(_a, b)*_rf(_a + b, c)/_rf(_a, c)
  432. elif c < 0:
  433. return _rf(_a, b)*_rf(_a + c, -c)/_rf(_a + b + c, -c)