factorials.py 5.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187
  1. from ..libmp.backend import xrange
  2. from .functions import defun, defun_wrapped
  3. @defun
  4. def gammaprod(ctx, a, b, _infsign=False):
  5. a = [ctx.convert(x) for x in a]
  6. b = [ctx.convert(x) for x in b]
  7. poles_num = []
  8. poles_den = []
  9. regular_num = []
  10. regular_den = []
  11. for x in a: [regular_num, poles_num][ctx.isnpint(x)].append(x)
  12. for x in b: [regular_den, poles_den][ctx.isnpint(x)].append(x)
  13. # One more pole in numerator or denominator gives 0 or inf
  14. if len(poles_num) < len(poles_den): return ctx.zero
  15. if len(poles_num) > len(poles_den):
  16. # Get correct sign of infinity for x+h, h -> 0 from above
  17. # XXX: hack, this should be done properly
  18. if _infsign:
  19. a = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_num]
  20. b = [x and x*(1+ctx.eps) or x+ctx.eps for x in poles_den]
  21. return ctx.sign(ctx.gammaprod(a+regular_num,b+regular_den)) * ctx.inf
  22. else:
  23. return ctx.inf
  24. # All poles cancel
  25. # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
  26. p = ctx.one
  27. orig = ctx.prec
  28. try:
  29. ctx.prec = orig + 15
  30. while poles_num:
  31. i = poles_num.pop()
  32. j = poles_den.pop()
  33. p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
  34. for x in regular_num: p *= ctx.gamma(x)
  35. for x in regular_den: p /= ctx.gamma(x)
  36. finally:
  37. ctx.prec = orig
  38. return +p
  39. @defun
  40. def beta(ctx, x, y):
  41. x = ctx.convert(x)
  42. y = ctx.convert(y)
  43. if ctx.isinf(y):
  44. x, y = y, x
  45. if ctx.isinf(x):
  46. if x == ctx.inf and not ctx._im(y):
  47. if y == ctx.ninf:
  48. return ctx.nan
  49. if y > 0:
  50. return ctx.zero
  51. if ctx.isint(y):
  52. return ctx.nan
  53. if y < 0:
  54. return ctx.sign(ctx.gamma(y)) * ctx.inf
  55. return ctx.nan
  56. xy = ctx.fadd(x, y, prec=2*ctx.prec)
  57. return ctx.gammaprod([x, y], [xy])
  58. @defun
  59. def binomial(ctx, n, k):
  60. n1 = ctx.fadd(n, 1, prec=2*ctx.prec)
  61. k1 = ctx.fadd(k, 1, prec=2*ctx.prec)
  62. nk1 = ctx.fsub(n1, k, prec=2*ctx.prec)
  63. return ctx.gammaprod([n1], [k1, nk1])
  64. @defun
  65. def rf(ctx, x, n):
  66. xn = ctx.fadd(x, n, prec=2*ctx.prec)
  67. return ctx.gammaprod([xn], [x])
  68. @defun
  69. def ff(ctx, x, n):
  70. x1 = ctx.fadd(x, 1, prec=2*ctx.prec)
  71. xn1 = ctx.fadd(ctx.fsub(x, n, prec=2*ctx.prec), 1, prec=2*ctx.prec)
  72. return ctx.gammaprod([x1], [xn1])
  73. @defun_wrapped
  74. def fac2(ctx, x):
  75. if ctx.isinf(x):
  76. if x == ctx.inf:
  77. return x
  78. return ctx.nan
  79. return 2**(x/2)*(ctx.pi/2)**((ctx.cospi(x)-1)/4)*ctx.gamma(x/2+1)
  80. @defun_wrapped
  81. def barnesg(ctx, z):
  82. if ctx.isinf(z):
  83. if z == ctx.inf:
  84. return z
  85. return ctx.nan
  86. if ctx.isnan(z):
  87. return z
  88. if (not ctx._im(z)) and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
  89. return z*0
  90. # Account for size (would not be needed if computing log(G))
  91. if abs(z) > 5:
  92. ctx.dps += 2*ctx.log(abs(z),2)
  93. # Reflection formula
  94. if ctx.re(z) < -ctx.dps:
  95. w = 1-z
  96. pi2 = 2*ctx.pi
  97. u = ctx.expjpi(2*w)
  98. v = ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-u) - \
  99. ctx.j*ctx.polylog(2, u)/pi2
  100. v = ctx.barnesg(2-z)*ctx.exp(v)/pi2**w
  101. if ctx._is_real_type(z):
  102. v = ctx._re(v)
  103. return v
  104. # Estimate terms for asymptotic expansion
  105. # TODO: fixme, obviously
  106. N = ctx.dps // 2 + 5
  107. G = 1
  108. while abs(z) < N or ctx.re(z) < 1:
  109. G /= ctx.gamma(z)
  110. z += 1
  111. z -= 1
  112. s = ctx.mpf(1)/12
  113. s -= ctx.log(ctx.glaisher)
  114. s += z*ctx.log(2*ctx.pi)/2
  115. s += (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
  116. s -= 3*z**2/4
  117. z2k = z2 = z**2
  118. for k in xrange(1, N+1):
  119. t = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z2k)
  120. if abs(t) < ctx.eps:
  121. #print k, N # check how many terms were needed
  122. break
  123. z2k *= z2
  124. s += t
  125. #if k == N:
  126. # print "warning: series for barnesg failed to converge", ctx.dps
  127. return G*ctx.exp(s)
  128. @defun
  129. def superfac(ctx, z):
  130. return ctx.barnesg(z+2)
  131. @defun_wrapped
  132. def hyperfac(ctx, z):
  133. # XXX: estimate needed extra bits accurately
  134. if z == ctx.inf:
  135. return z
  136. if abs(z) > 5:
  137. extra = 4*int(ctx.log(abs(z),2))
  138. else:
  139. extra = 0
  140. ctx.prec += extra
  141. if not ctx._im(z) and ctx._re(z) < 0 and ctx.isint(ctx._re(z)):
  142. n = int(ctx.re(z))
  143. h = ctx.hyperfac(-n-1)
  144. if ((n+1)//2) & 1:
  145. h = -h
  146. if ctx._is_complex_type(z):
  147. return h + 0j
  148. return h
  149. zp1 = z+1
  150. # Wrong branch cut
  151. #v = ctx.gamma(zp1)**z
  152. #ctx.prec -= extra
  153. #return v / ctx.barnesg(zp1)
  154. v = ctx.exp(z*ctx.loggamma(zp1))
  155. ctx.prec -= extra
  156. return v / ctx.barnesg(zp1)
  157. '''
  158. @defun
  159. def psi0(ctx, z):
  160. """Shortcut for psi(0,z) (the digamma function)"""
  161. return ctx.psi(0, z)
  162. @defun
  163. def psi1(ctx, z):
  164. """Shortcut for psi(1,z) (the trigamma function)"""
  165. return ctx.psi(1, z)
  166. @defun
  167. def psi2(ctx, z):
  168. """Shortcut for psi(2,z) (the tetragamma function)"""
  169. return ctx.psi(2, z)
  170. @defun
  171. def psi3(ctx, z):
  172. """Shortcut for psi(3,z) (the pentagamma function)"""
  173. return ctx.psi(3, z)
  174. '''