123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187 |
- from ..libmp.backend import xrange
- from .functions import defun, defun_wrapped
@defun
def gammaprod(ctx, a, b, _infsign=False):
    """
    Return prod(gamma(x) for x in a) / prod(gamma(x) for x in b),
    treating poles of the gamma function as limits.

    Arguments for which ``ctx.isnpint`` is true are handled as poles
    (presumably the non-positive integers; the helper is defined
    elsewhere). The outcome depends on how many poles sit on each side:

    * fewer poles in ``a`` than in ``b``: the result is ``ctx.zero``;
    * more poles in ``a`` than in ``b``: the result is ``ctx.inf``, or a
      signed infinity when ``_infsign`` is true;
    * equally many: the poles are paired off and each pair is replaced
      by the limit ``(-1)**(i+j) * gamma(1-j) / gamma(1-i)``.

    Parameters:
        a: iterable of numerator arguments (each passed to ``ctx.convert``).
        b: iterable of denominator arguments (each passed to ``ctx.convert``).
        _infsign: if true, estimate the sign of an infinite result by
            nudging every pole by a relative ``ctx.eps`` and taking the
            sign of the perturbed quotient.

    Returns:
        The quotient, evaluated with 15 extra bits of working precision
        (the caller's precision is restored before returning).
    """
    a = [ctx.convert(x) for x in a]
    b = [ctx.convert(x) for x in b]

    # Sort every argument into "pole" or "regular", per side.
    poles_num = []
    poles_den = []
    regular_num = []
    regular_den = []
    for x in a:
        if ctx.isnpint(x):
            poles_num.append(x)
        else:
            regular_num.append(x)
    for x in b:
        if ctx.isnpint(x):
            poles_den.append(x)
        else:
            regular_den.append(x)

    # One more pole in numerator or denominator gives 0 or inf
    if len(poles_num) < len(poles_den):
        return ctx.zero
    if len(poles_num) > len(poles_den):
        if not _infsign:
            return ctx.inf
        # Get correct sign of infinity for x+h, h -> 0 from above
        # XXX: hack, this should be done properly
        # A zero pole cannot be scaled away from zero, so it is shifted
        # by eps instead; every other pole is scaled by (1 + eps).
        a = [x*(1+ctx.eps) if x else x+ctx.eps for x in poles_num]
        b = [x*(1+ctx.eps) if x else x+ctx.eps for x in poles_den]
        return ctx.sign(ctx.gammaprod(a+regular_num, b+regular_den)) * ctx.inf

    # All poles cancel
    # lim G(i)/G(j) = (-1)**(i+j) * gamma(1-j) / gamma(1-i)
    p = ctx.one
    orig = ctx.prec
    try:
        ctx.prec = orig + 15
        # Both lists have equal length here; pair them from the back so
        # the factors are multiplied in last-to-first order.
        for i, j in zip(reversed(poles_num), reversed(poles_den)):
            p *= (-1)**(i+j) * ctx.gamma(1-j) / ctx.gamma(1-i)
        for x in regular_num:
            p *= ctx.gamma(x)
        for x in regular_den:
            p /= ctx.gamma(x)
    finally:
        # Always hand the caller's precision back, even on error.
        ctx.prec = orig
    # Unary plus presumably re-rounds to the restored precision.
    return +p
@defun
def beta(ctx, x, y):
    """
    Beta function, evaluated as gamma(x)*gamma(y)/gamma(x+y) through
    ``ctx.gammaprod``.

    The sum ``x+y`` is formed at twice the working precision before it
    is handed to ``gammaprod``.

    Infinite arguments are special-cased. The result is ``ctx.nan``
    unless the infinite argument is ``+inf`` and the other argument
    ``y`` has zero imaginary part, in which case:

    * ``y == -inf``           -> ``ctx.nan``
    * ``y > 0``               -> ``ctx.zero``
    * ``y`` an integer        -> ``ctx.nan``
    * ``y < 0`` (non-integer) -> infinity with the sign of ``gamma(y)``
    """
    x = ctx.convert(x)
    y = ctx.convert(y)

    # Make sure that an infinite argument, if there is one, lives in x.
    if ctx.isinf(y):
        x, y = y, x

    if not ctx.isinf(x):
        # Ordinary case: both arguments finite.
        total = ctx.fadd(x, y, prec=2*ctx.prec)
        return ctx.gammaprod([x, y], [total])

    plus_inf_with_real_partner = x == ctx.inf and not ctx._im(y)
    if plus_inf_with_real_partner:
        if y == ctx.ninf:
            return ctx.nan
        elif y > 0:
            return ctx.zero
        elif ctx.isint(y):
            return ctx.nan
        elif y < 0:
            return ctx.sign(ctx.gamma(y)) * ctx.inf
    # Every remaining combination involving an infinity is undefined.
    return ctx.nan
@defun
def binomial(ctx, n, k):
    """
    Binomial coefficient, evaluated as
    gamma(n+1) / (gamma(k+1) * gamma(n-k+1)) through ``ctx.gammaprod``.

    The three gamma arguments are formed at twice the working precision
    before being passed on.
    """
    wide = 2*ctx.prec
    n_plus_1 = ctx.fadd(n, 1, prec=wide)
    k_plus_1 = ctx.fadd(k, 1, prec=wide)
    # n - k + 1, obtained as (n + 1) - k
    n_minus_k_plus_1 = ctx.fsub(n_plus_1, k, prec=wide)
    return ctx.gammaprod([n_plus_1], [k_plus_1, n_minus_k_plus_1])
@defun
def rf(ctx, x, n):
    """
    Rising factorial, evaluated as gamma(x+n) / gamma(x) through
    ``ctx.gammaprod``.

    The sum ``x+n`` is formed at twice the working precision.
    """
    wide = 2*ctx.prec
    shifted = ctx.fadd(x, n, prec=wide)
    numerator = [shifted]
    denominator = [x]
    return ctx.gammaprod(numerator, denominator)
@defun
def ff(ctx, x, n):
    """
    Falling factorial, evaluated as gamma(x+1) / gamma(x-n+1) through
    ``ctx.gammaprod``.

    Both gamma arguments are formed at twice the working precision.
    """
    wide = 2*ctx.prec
    upper = ctx.fadd(x, 1, prec=wide)
    # x - n + 1, built as (x - n) + 1
    difference = ctx.fsub(x, n, prec=wide)
    lower = ctx.fadd(difference, 1, prec=wide)
    return ctx.gammaprod([upper], [lower])
@defun_wrapped
def fac2(ctx, x):
    """
    Double factorial, evaluated with the closed form

        2**(x/2) * (pi/2)**((cospi(x) - 1)/4) * gamma(x/2 + 1)

    ``+inf`` is returned unchanged; any other infinite argument gives
    ``ctx.nan``.
    """
    if ctx.isinf(x):
        return x if x == ctx.inf else ctx.nan
    half = x/2
    power_of_two = 2**half
    parity_factor = (ctx.pi/2)**((ctx.cospi(x)-1)/4)
    return power_of_two*parity_factor*ctx.gamma(half+1)
@defun_wrapped
def barnesg(ctx, z):
    """
    Barnes G-function.

    Special values:

    * ``+inf`` is returned unchanged, any other infinity gives ``ctx.nan``;
    * ``nan`` is returned unchanged;
    * a real integer ``z <= 0`` gives ``z*0``.

    Otherwise one of two strategies is used:

    * if ``re(z) < -ctx.dps``, a reflection formula expressed through
      ``barnesg(2-z)``, ``ln`` and ``polylog(2, .)``;
    * else ``z`` is stepped upward, dividing by ``gamma(z)`` at each
      step, until it is large enough for an asymptotic series in
      Bernoulli numbers, whose exponential is multiplied by the
      accumulated factor.

    NOTE: ``ctx.dps`` is raised for ``abs(z) > 5`` and never lowered
    here; presumably the ``defun_wrapped`` wrapper restores precision.
    Confirm against its definition.
    """
    if ctx.isinf(z):
        return z if z == ctx.inf else ctx.nan
    if ctx.isnan(z):
        return z
    is_real = not ctx._im(z)
    if is_real and ctx._re(z) <= 0 and ctx.isint(ctx._re(z)):
        return z*0

    # Account for size (would not be needed if computing log(G))
    if abs(z) > 5:
        ctx.dps += 2*ctx.log(abs(z),2)

    # Reflection formula
    if ctx.re(z) < -ctx.dps:
        w = 1-z
        two_pi = 2*ctx.pi
        phase = ctx.expjpi(2*w)
        exponent = (ctx.j*ctx.pi/12 - ctx.j*ctx.pi*w**2/2 + w*ctx.ln(1-phase) -
                    ctx.j*ctx.polylog(2, phase)/two_pi)
        reflected = ctx.barnesg(2-z)*ctx.exp(exponent)/two_pi**w
        if ctx._is_real_type(z):
            # Real input: keep only the real part of the result.
            reflected = ctx._re(reflected)
        return reflected

    # Estimate terms for asymptotic expansion
    # TODO: fixme, obviously
    nterms = ctx.dps // 2 + 5

    # Step z upward, collecting 1/gamma(z) factors, until abs(z) >= nterms
    # and re(z) >= 1; then step back by one for the series below.
    prefactor = 1
    while abs(z) < nterms or ctx.re(z) < 1:
        prefactor /= ctx.gamma(z)
        z += 1
    z -= 1

    # Leading part of the exponent (summed left to right).
    log_sum = (ctx.mpf(1)/12
               - ctx.log(ctx.glaisher)
               + z*ctx.log(2*ctx.pi)/2
               + (z**2/2-ctx.mpf(1)/12)*ctx.log(z)
               - 3*z**2/4)

    # Bernoulli-number tail; stop as soon as a term drops below eps.
    z_squared = z**2
    z_power = z_squared
    for k in xrange(1, nterms+1):
        term = ctx.bernoulli(2*k+2) / (4*k*(k+1)*z_power)
        if abs(term) < ctx.eps:
            break
        z_power *= z_squared
        log_sum += term
    # NOTE: if the loop runs to completion the series has not reached
    # eps; no warning is issued in that case.
    return prefactor*ctx.exp(log_sum)
@defun
def superfac(ctx, z):
    """
    Superfactorial, obtained from the Barnes G-function as
    ``barnesg(z + 2)``.
    """
    shifted = z+2
    return ctx.barnesg(shifted)
@defun_wrapped
def hyperfac(ctx, z):
    """
    Hyperfactorial, evaluated as
    ``exp(z*loggamma(z+1)) / barnesg(z+1)``.

    * ``+inf`` is returned unchanged.
    * A real negative integer ``n`` is mapped to ``hyperfac(-n-1)``,
      negated when ``(n+1)//2`` is odd; for complex-typed input the
      result is returned as ``h + 0j``.

    For ``abs(z) > 5`` the working precision is raised by
    ``4*int(log2(abs(z)))`` bits. On the generic path it is lowered
    again before the final division. The negative-integer path returns
    without lowering it; presumably the ``defun_wrapped`` wrapper
    restores precision. Confirm against its definition.
    """
    # XXX: estimate needed extra bits accurately
    if z == ctx.inf:
        return z
    extra = 4*int(ctx.log(abs(z),2)) if abs(z) > 5 else 0
    ctx.prec += extra

    negative_integer = (not ctx._im(z) and ctx._re(z) < 0
                        and ctx.isint(ctx._re(z)))
    if negative_integer:
        n = int(ctx.re(z))
        mirrored = ctx.hyperfac(-n-1)
        if ((n+1)//2) % 2:
            mirrored = -mirrored
        if ctx._is_complex_type(z):
            return mirrored + 0j
        return mirrored

    z_plus_1 = z+1
    # gamma(z+1)**z would pick the wrong branch cut, so the power is
    # formed through loggamma instead.
    power = ctx.exp(z*ctx.loggamma(z_plus_1))
    ctx.prec -= extra
    return power / ctx.barnesg(z_plus_1)
- '''
- @defun
- def psi0(ctx, z):
- """Shortcut for psi(0,z) (the digamma function)"""
- return ctx.psi(0, z)
- @defun
- def psi1(ctx, z):
- """Shortcut for psi(1,z) (the trigamma function)"""
- return ctx.psi(1, z)
- @defun
- def psi2(ctx, z):
- """Shortcut for psi(2,z) (the tetragamma function)"""
- return ctx.psi(2, z)
- @defun
- def psi3(ctx, z):
- """Shortcut for psi(3,z) (the pentagamma function)"""
- return ctx.psi(3, z)
- '''
|