123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425 |
- from .functions import defun, defun_wrapped
@defun_wrapped
def _erf_complex(ctx, z):
    """Generic erf(z) evaluation through Kummer's confluent function.

    Uses erf(z) = 2*z/sqrt(pi) * 1F1(1/2; 3/2; -z**2).  When z has a zero
    real part, only the imaginary part of the result is kept, so the value
    returned is purely imaginary.
    """
    # -z**2, formed by square_exp_arg with extra guard precision
    neg_sq = ctx.square_exp_arg(z, -1)
    scale = 2/ctx.sqrt(ctx.pi)
    result = scale*z * ctx.hyp1f1((1, 2), (3, 2), neg_sq)
    if ctx._re(z):
        return result
    return ctx._im(result)*ctx.j
@defun_wrapped
def _erfc_complex(ctx, z):
    """Generic erfc(z) evaluation.

    For re(z) > 2 the value is computed as
    exp(-z**2)/sqrt(pi) * U(1/2, 1/2, z**2) (Tricomi U via ``hyperu``);
    otherwise as 1 - erf(z).  When z has a zero real part, the real part of
    the result is forced to exactly 1.
    """
    if ctx.re(z) > 2:
        sq = ctx.square_exp_arg(z)
        neg_sq = ctx.fneg(sq, exact=True)
        damping = ctx.exp(neg_sq)/ctx.sqrt(ctx.pi)
        result = damping * ctx.hyperu((1, 2), (1, 2), sq)
    else:
        result = 1 - ctx._erf_complex(z)
    if ctx._re(z):
        return result
    return 1 + ctx._im(result)*ctx.j
@defun
def erf(ctx, z):
    """Return the error function erf(z).

    Real arguments, and complex arguments whose imaginary part is zero, are
    first offered to the context's specialized ``_erf``.  If that helper
    raises NotImplementedError, or the argument is genuinely complex, the
    generic ``_erf_complex`` evaluation is used.
    """
    z = ctx.convert(z)
    is_real = ctx._is_real_type(z)
    if is_real:
        try:
            return ctx._erf(z)
        except NotImplementedError:
            pass  # no fast path for this context; fall through
    # A complex value on the real axis can still use the real routine; the
    # result is converted back to the argument's own type.
    on_real_axis = ctx._is_complex_type(z) and not z.imag
    if on_real_axis:
        try:
            return type(z)(ctx._erf(z.real))
        except NotImplementedError:
            pass
    return ctx._erf_complex(z)
@defun
def erfc(ctx, z):
    """Return the complementary error function erfc(z) = 1 - erf(z).

    Real arguments, and complex arguments whose imaginary part is zero, are
    first offered to the context's specialized ``_erfc``.  If that helper
    raises NotImplementedError, or the argument is genuinely complex, the
    generic ``_erfc_complex`` evaluation is used.
    """
    z = ctx.convert(z)
    is_real = ctx._is_real_type(z)
    if is_real:
        try:
            return ctx._erfc(z)
        except NotImplementedError:
            pass  # no fast path for this context; fall through
    # A complex value on the real axis can still use the real routine; the
    # result is converted back to the argument's own type.
    on_real_axis = ctx._is_complex_type(z) and not z.imag
    if on_real_axis:
        try:
            return type(z)(ctx._erfc(z.real))
        except NotImplementedError:
            pass
    return ctx._erfc_complex(z)
@defun
def square_exp_arg(ctx, z, mult=1, reciprocal=False):
    """Return z**2 (or 1/z**2 when *reciprocal*), then scaled by *mult*.

    The square (and the reciprocal, if requested) is computed at
    4*prec + 20 bits.  Presumably this is so that a later exp() of the result
    keeps full accuracy for large |z| (TODO confirm).  The scaling by *mult*
    is requested with ``exact=True`` and is skipped when mult == 1.
    """
    work_prec = 4*ctx.prec + 20
    sq = ctx.fmul(z, z, prec=work_prec)
    if reciprocal:
        sq = ctx.fdiv(ctx.one, sq, prec=work_prec)
    if mult != 1:
        sq = ctx.fmul(sq, mult, exact=True)
    return sq
@defun_wrapped
def erfi(ctx, z):
    """Imaginary error function erfi(z).

    Computed as 2*z/sqrt(pi) * 1F1(1/2; 3/2; z**2).  A zero argument is
    returned unchanged.  When z has a zero real part, only the imaginary part
    of the result is kept.
    """
    if not z:
        return z
    sq = ctx.square_exp_arg(z)
    prefactor = 2/ctx.sqrt(ctx.pi)*z
    result = prefactor * ctx.hyp1f1((1, 2), (3, 2), sq)
    if ctx._re(z):
        return result
    return ctx._im(result)*ctx.j
@defun_wrapped
def erfinv(ctx, x):
    """Inverse error function: the real t with erf(t) == x.

    Only real x with -1 <= x <= 1 is accepted; anything else is reported via
    ``ctx.bad_domain``.  The endpoints map to -inf/+inf and 0 maps to itself.
    Other values are found with ``findroot``, starting from a cubic guess for
    |x| < 0.9 or an asymptotic guess nearer the endpoints.
    """
    x_real = ctx._re(x)
    if (x_real != x) or (x_real < -1) or (x_real > 1):
        return ctx.bad_domain("erfinv(x) is defined only for -1 <= x <= 1")
    x = x_real
    # Exactly known values
    if not x:
        return x
    if x == 1:
        return ctx.inf
    if x == -1:
        return ctx.ninf
    # Starting point for the root finder
    if abs(x) < 0.9:
        guess = 0.53728*x**3 + 0.813198*x
    else:
        # An asymptotic formula, useful as x approaches +/-1
        u = ctx.ln(2/ctx.pi/(abs(x)-1)**2)
        guess = ctx.sign(x) * ctx.sqrt(u - ctx.ln(u))/ctx.sqrt(2)
    # Extra guard bits for the root finding.  Not restored here; presumably
    # the defun_wrapped wrapper resets ctx.prec (not visible in this file).
    ctx.prec += 10
    return ctx.findroot(lambda t: ctx.erf(t)-x, guess)
@defun_wrapped
def npdf(ctx, x, mu=0, sigma=1):
    """Normal probability density at *x* for mean *mu*, std deviation *sigma*.

    exp(-(x-mu)**2 / (2*sigma**2)) / (sigma*sqrt(2*pi))
    """
    sigma = ctx.convert(sigma)
    exponent = -(x-mu)**2/(2*sigma**2)
    normalizer = sigma*ctx.sqrt(2*ctx.pi)
    return ctx.exp(exponent) / normalizer
@defun_wrapped
def ncdf(ctx, x, mu=0, sigma=1):
    """Normal cumulative distribution function for mean *mu*, std *sigma*.

    With t = (x - mu)/(sigma*sqrt(2)), this returns erfc(-t)/2 for t < 0 and
    (1 + erf(t))/2 otherwise.  The erfc form presumably avoids cancellation
    in 1 + erf(t) when erf(t) is close to -1.
    """
    t = (x-mu)/(sigma*ctx.sqrt(2))
    return ctx.erfc(-t)/2 if t < 0 else (1+ctx.erf(t))/2
@defun_wrapped
def betainc(ctx, a, b, x1=0, x2=1, regularized=False):
    """Generalized incomplete beta integral from *x1* to *x2*.

    Computes int_{x1}^{x2} t**(a-1) * (1-t)**(b-1) dt by evaluating
    x**a * 2F1(a, 1-b; a+1; x) / a at both endpoints.  With *regularized*
    true the result is divided by the complete beta function B(a, b).

    Returns the context's zero (not the Python int 0) for an empty interval,
    so the result type matches the other branches.
    """
    if x1 == x2:
        # Empty interval.  Use ctx.zero so the return type is consistent
        # (gammainc does the same for a == b).
        v = ctx.zero
    elif not x1:
        if x2 == 1:
            # Complete beta function
            v = ctx.beta(a, b)
        else:
            v = x2**a * ctx.hyp2f1(a, 1-b, a+1, x2) / a
    else:
        # nint_distance returns the nearest integer m and (per the mpmath
        # docs) the log2 of the distance d between a and m.
        m, d = ctx.nint_distance(a)
        if m <= 0:
            # Near a non-positive integer the division by a and the 2F1 terms
            # become singular.
            if d < -ctx.prec:
                # Essentially on the singularity: nudge a by one ulp (unary +
                # rounds eps to the working precision) and double precision.
                h = +ctx.eps
                ctx.prec *= 2
                a += h
            elif d < -4:
                # Close to it: add as many bits as are lost to cancellation.
                ctx.prec -= d
        s1 = x2**a * ctx.hyp2f1(a, 1-b, a+1, x2)
        s2 = x1**a * ctx.hyp2f1(a, 1-b, a+1, x1)
        v = (s1 - s2) / a
    if regularized:
        v /= ctx.beta(a, b)
    return v
@defun
def gammainc(ctx, z, a=0, b=None, regularized=False):
    """Generalized incomplete gamma function int_a^b t**(z-1) * exp(-t) dt.

    *a* = None is treated as 0 and *b* = None as +inf.  With *regularized*
    true the result is divided by Gamma(z).  Dispatches to:

    * ``gamma`` when both endpoints are at their defaults (complete gamma);
    * ``_gamma3`` when both endpoints are modified;
    * ``_upper_gamma`` when only *a* is modified;
    * ``_lower_gamma`` when only *b* is modified.

    Endpoints with re(a) > re(b) are swapped and the result negated.
    """
    regularized = bool(regularized)
    z = ctx.convert(z)

    # Lower endpoint: "modified" means it differs from 0
    if a is None:
        a = ctx.zero
        has_lower = False
    else:
        a = ctx.convert(a)
        has_lower = a != ctx.zero

    # Upper endpoint: "modified" means it differs from +inf
    if b is None:
        b = ctx.inf
        has_upper = False
    else:
        b = ctx.convert(b)
        has_upper = b != ctx.inf

    if not has_lower and not has_upper:
        # Complete gamma function
        if not regularized:
            return ctx.gamma(z)
        re_z = ctx.re(z)
        if re_z < 0:
            return ctx.inf
        if re_z > 0:
            return ctx.one
        return ctx.nan

    if a == b:
        return ctx.zero
    # Standardize the orientation so that re(a) <= re(b)
    if ctx.re(a) > ctx.re(b):
        return -ctx.gammainc(z, b, a, regularized)

    if has_lower and has_upper:
        # Generalized gamma
        return +ctx._gamma3(z, a, b, regularized)
    if has_lower:
        # Upper gamma: integral from a to infinity
        return ctx._upper_gamma(z, a, regularized)
    # Only the upper endpoint is modified: lower gamma from 0 to b
    return ctx._lower_gamma(z, b, regularized)
@defun
def _lower_gamma(ctx, z, b, regularized=False):
    """Lower incomplete gamma function gamma(z, b) = int_0^b t**(z-1) exp(-t) dt.

    With *regularized* true the result is divided by Gamma(z).  It is evaluated
    as the single hypercomb term exp(-b) * b**z / z * 1F1(1; 1+z; b).
    """
    # Pole
    # For non-positive integer z, +inf (converted to type(z)) is returned,
    # whether or not regularized is set.
    if ctx.isnpint(z):
        return type(z)(ctx.inf)
    # [z] when regularized, [] otherwise; placed in the hypercomb gamma
    # denominator list to divide by Gamma(z).  gammainc passes a bool here.
    G = [z] * regularized
    negb = ctx.fneg(b, exact=True)
    def h(z):
        # hypercomb term layout: (bases, exponents, gamma numerators,
        # gamma denominators, 1F1 upper params, 1F1 lower params, argument):
        # exp(-b)**1 * b**z * z**(-1) / [Gamma(z)] * 1F1(1; 1+z; b)
        T1 = [ctx.exp(negb), b, z], [1, z, -1], [], G, [1], [1+z], b
        return (T1,)
    return ctx.hypercomb(h, [z])
@defun
def _upper_gamma(ctx, z, a, regularized=False):
    """Upper incomplete gamma function Gamma(z, a) = int_a^inf t**(z-1) exp(-t) dt.

    With *regularized* true the result is divided by Gamma(z).  The following
    strategies are tried in order:

    1. For integer z, the context's ``_gamma_upper_int`` fast path (skipped
       if it raises NotImplementedError).
    2. Exact zeros at two known roots that hypercomb cannot detect.
    3. The 2F0 representation, with ``force_series=True``.
    4. If step 3 raises ``ctx.NoConvergence``,
       Gamma(z, a) = Gamma(z) - gamma(z, a) via hypercomb.
    """
    # Fast integer case, when available
    if ctx.isint(z):
        try:
            if regularized:
                # Gamma pole
                # 1/Gamma(z) vanishes at non-positive integers, so the
                # regularized value is zero.
                if ctx.isnpint(z):
                    return type(z)(ctx.zero)
                orig = ctx.prec
                try:
                    # Guard bits for the division; restored exactly below.
                    ctx.prec += 10
                    return ctx._gamma_upper_int(z, a) / ctx.gamma(z)
                finally:
                    ctx.prec = orig
            else:
                return ctx._gamma_upper_int(z, a)
        except NotImplementedError:
            pass
    # hypercomb is unable to detect the exact zeros, so handle them here
    # Gamma(2, a) = (a + 1)*exp(-a) vanishes at a = -1, and
    # Gamma(3, a) = (a**2 + 2a + 2)*exp(-a) vanishes at a = -1 +/- i.
    # (z+a)*0 produces a zero of the appropriate real/complex type.
    if z == 2 and a == -1:
        return (z+a)*0
    if z == 3 and (a == -1-1j or a == -1+1j):
        return (z+a)*0
    nega = ctx.fneg(a, exact=True)
    # [z] when regularized, [] otherwise; placed in the hypercomb gamma
    # denominator list to divide by Gamma(z).  gammainc passes a bool here.
    G = [z] * regularized
    # Use 2F0 series when possible; fall back to lower gamma representation
    try:
        def h(z):
            r = z-1
            # exp(-a) * a**(z-1) / [Gamma(z)] * 2F0(1, 1-z; ; -1/a)
            return [([ctx.exp(nega), a], [1, r], [], G, [1, -r], [], 1/nega)]
        return ctx.hypercomb(h, [z], force_series=True)
    except ctx.NoConvergence:
        def h(z):
            # T1: gamma numerator [z] -> Gamma(z) (divided by Gamma(z) when
            # regularized); no hypergeometric parameters, argument 0.
            T1 = [], [1, z-1], [z], G, [], [], 0
            # T2: -exp(-a) * a**z / z / [Gamma(z)] * 1F1(1; 1+z; a), i.e.
            # minus the lower incomplete gamma gamma(z, a).
            T2 = [-ctx.exp(nega), a, z], [1, z, -1], [], G, [1], [1+z], a
            return T1, T2
        return ctx.hypercomb(h, [z])
@defun
def _gamma3(ctx, z, a, b, regularized=False):
    """Generalized incomplete gamma int_a^b t**(z-1) exp(-t) dt for finite a, b.

    Works at 15 extra bits.  It first tries Gamma(z, a) - Gamma(z, b), a
    difference of upper gammas, and accepts it if fewer than about 10 bits
    cancel.  Otherwise, unless z is a non-positive integer, it returns
    gamma(z, b) - gamma(z, a), a difference of lower gammas.  With
    *regularized* true all terms are divided by Gamma(z), and the result is
    0 at the gamma poles.

    The working precision is saved and restored exactly, matching the pattern
    used by ``_upper_gamma``.

    Raises NotImplementedError if z is a non-positive integer, *regularized*
    is false, and the upper-gamma difference suffers cancellation.
    """
    pole = ctx.isnpint(z)
    if regularized and pole:
        return ctx.zero
    orig_prec = ctx.prec
    try:
        ctx.prec += 15
        # We don't know in advance whether it's better to write as a difference
        # of lower or upper gamma functions, so try both
        T1 = ctx.gammainc(z, a, regularized=regularized)
        T2 = ctx.gammainc(z, b, regularized=regularized)
        R = T1 - T2
        # Accept unless roughly 10 or more bits were lost to cancellation
        if ctx.mag(R) - max(ctx.mag(T1), ctx.mag(T2)) > -10:
            return R
        if not pole:
            T1 = ctx.gammainc(z, 0, b, regularized=regularized)
            T2 = ctx.gammainc(z, 0, a, regularized=regularized)
            # NOTE: accepted without a cancellation check; it may be ok, but
            # a warning about possible cancellation would be appropriate.
            return T1 - T2
    finally:
        # Restore the exact saved precision rather than subtracting 15, so an
        # early failure or a callee that leaves ctx.prec altered cannot skew it.
        ctx.prec = orig_prec
    raise NotImplementedError
@defun_wrapped
def expint(ctx, n, z):
    """Generalized exponential integral E_n(z) = int_1^inf exp(-z*t) / t**n dt.

    Uses the context's ``_expint_int`` fast path for integer n and real z when
    available.  Otherwise it handles NaN, z = +inf, z = 0, and the closed
    forms for n = 0 and n = -1, then falls back to
    E_n(z) = z**(n-1) * Gamma(1-n, z).
    """
    if ctx.isint(n) and ctx._is_real_type(z):
        try:
            return ctx._expint_int(n, z)
        except NotImplementedError:
            pass
    if ctx.isnan(n) or ctx.isnan(z):
        return z*n  # propagate NaN
    if z == ctx.inf:
        return 1/z
    if z == 0:
        # E_n(0) is the integral from 1 to infinity of t**(-n), which
        # diverges when re(n) <= 1.
        # TODO: reasonable sign of infinity
        return type(z)(ctx.inf) if ctx.re(n) <= 1 else ctx.one/(n-1)
    if n == 0:
        return ctx.exp(-z)/z
    if n == -1:
        return ctx.exp(-z)*(z+1)/z**2
    return z**(n-1) * ctx.gammainc(1-n, z)
@defun_wrapped
def li(ctx, z, offset=False):
    """Logarithmic integral li(z) = Ei(ln z).

    With *offset* true, returns the offset form Li(z) = li(z) - li(2), which
    is exactly 0 at z = 2.  Without offset, li(0) = 0 and li(1) = -inf.
    """
    if not offset:
        if not z:
            return z
        if z == 1:
            return ctx.ninf
        return ctx.ei(ctx.ln(z))
    if z == 2:
        return ctx.zero
    return ctx.ei(ctx.ln(z)) - ctx.ei(ctx.ln2)
@defun
def ei(ctx, z):
    """Exponential integral Ei(z).

    Prefers the context's specialized ``_ei``; if it raises
    NotImplementedError, ``_ei_generic`` is used.
    """
    try:
        value = ctx._ei(z)
    except NotImplementedError:
        value = ctx._ei_generic(z)
    return value
@defun_wrapped
def _ei_generic(ctx, z):
    """Generic exponential integral Ei(z).

    For mag(z) > 1, first tries the asymptotic expansion
    exp(z)/z * 2F0(1, 1; ; 1/z), limited to prec terms, plus +/- pi*i by the
    sign of im(z).  If that raises ``ctx.NoConvergence``, or |z| is small,
    it uses Ei(z) = euler + log z + z * 2F2(1, 1; 2, 2; z).  For complex z
    the log term is 0.5*(log z - log(1/z)); for real z it is log|z|.
    """
    # Note: the following is currently untested because mp and fp
    # both use special-case ei code
    if z == ctx.inf:
        return z
    if z == ctx.ninf:
        return ctx.zero
    if ctx.mag(z) > 1:
        try:
            inv_z = ctx.one/z
            growth = ctx.exp(z)
            result = growth*ctx.hyper([1, 1], [], inv_z,
                                      maxterms=ctx.prec, force_series=True)/z
            im_z = ctx._im(z)
            if im_z > 0:
                result += ctx.pi*ctx.j
            elif im_z < 0:
                result -= ctx.pi*ctx.j
            return result
        except ctx.NoConvergence:
            pass  # fall back to the convergent series below
    result = z*ctx.hyp2f2(1, 1, 2, 2, z) + ctx.euler
    if ctx._im(z):
        result += 0.5*(ctx.log(z) - ctx.log(ctx.one/z))
    else:
        result += ctx.log(abs(z))
    return result
@defun
def e1(ctx, z):
    """Exponential integral E1(z).

    Prefers the context's specialized ``_e1``; if it raises
    NotImplementedError, the generic ``expint(1, z)`` is used.
    """
    try:
        value = ctx._e1(z)
    except NotImplementedError:
        value = ctx.expint(1, z)
    return value
@defun
def ci(ctx, z):
    """Cosine integral Ci(z).

    Prefers the context's specialized ``_ci``; if it raises
    NotImplementedError, the Ei-based ``_ci_generic`` is used.
    """
    try:
        value = ctx._ci(z)
    except NotImplementedError:
        value = ctx._ci_generic(z)
    return value
@defun_wrapped
def _ci_generic(ctx, z):
    """Generic cosine integral Ci(z).

    Computed as (Ei(i*z) + Ei(-i*z))/2, then shifted by a multiple of pi*i
    chosen by the signs of re(z) and im(z).  For real-typed z with
    re(z) > 0 only the real part is returned.  Ci(+inf) = 0 and
    Ci(-inf) = pi*i.
    """
    if ctx.isinf(z):
        if z == ctx.inf:
            return ctx.zero
        if z == ctx.ninf:
            return ctx.pi*1j
    iz = ctx.fmul(ctx.j, z, exact=True)
    neg_iz = ctx.fneg(iz, exact=True)
    result = 0.5*(ctx.ei(iz) + ctx.ei(neg_iz))
    re_z = ctx._re(z)
    im_z = ctx._im(z)
    if re_z == 0:
        if im_z > 0:
            result += ctx.pi*0.5j
        elif im_z < 0:
            result -= ctx.pi*0.5j
    elif re_z < 0:
        if im_z >= 0:
            result += ctx.pi*1j
        elif im_z < 0:
            result -= ctx.pi*1j
    if ctx._is_real_type(z) and re_z > 0:
        result = ctx._re(result)
    return result
@defun
def si(ctx, z):
    """Sine integral Si(z).

    Prefers the context's specialized ``_si``; if it raises
    NotImplementedError, ``_si_generic`` is used.
    """
    try:
        value = ctx._si(z)
    except NotImplementedError:
        value = ctx._si_generic(z)
    return value
@defun_wrapped
def _si_generic(ctx, z):
    """Generic sine integral Si(z).

    For mag(z) >= -1 it uses -i/2 * (Ei(i*z) - Ei(-i*z)) minus or plus pi/2
    by the sign of re(z), taking only the real part for real-typed z.
    Closer to 0, where that formula suffers cancellation, it uses the series
    z * 1F2(1/2; 3/2, 3/2; -z**2/4).  Si(+/-inf) = +/-pi/2.
    """
    if ctx.isinf(z):
        if z == ctx.inf:
            return 0.5*ctx.pi
        if z == ctx.ninf:
            return -0.5*ctx.pi
    if ctx.mag(z) >= -1:
        iz = ctx.fmul(ctx.j, z, exact=True)
        neg_iz = ctx.fneg(iz, exact=True)
        result = (-0.5j)*(ctx.ei(iz) - ctx.ei(neg_iz))
        re_z = ctx._re(z)
        if re_z > 0:
            result -= 0.5*ctx.pi
        elif re_z < 0:
            result += 0.5*ctx.pi
        return ctx._re(result) if ctx._is_real_type(z) else result
    # Small |z|: the Ei-based formula cancels, so use the power series
    return z*ctx.hyp1f2((1, 2), (3, 2), (3, 2), -0.25*z*z)
@defun_wrapped
def chi(ctx, z):
    """Hyperbolic cosine integral Chi(z).

    Computed as (Ei(z) + Ei(-z))/2.  It is then shifted by +pi*i/2 when
    im(z) > 0, by -pi*i/2 when im(z) < 0, or by +pi*i for negative real z.
    """
    neg_z = ctx.fneg(z, exact=True)
    result = 0.5*(ctx.ei(z) + ctx.ei(neg_z))
    re_z = ctx._re(z)
    im_z = ctx._im(z)
    if im_z > 0:
        result += ctx.pi*0.5j
    elif im_z < 0:
        result -= ctx.pi*0.5j
    elif re_z < 0:
        result += ctx.pi*1j
    return result
@defun_wrapped
def shi(ctx, z):
    """Hyperbolic sine integral Shi(z).

    For mag(z) >= -1 it uses (Ei(z) - Ei(-z))/2 minus or plus pi*i/2 by the
    sign of im(z).  Closer to 0, where that difference cancels, it uses the
    series z * 1F2(1/2; 3/2, 3/2; z**2/4).
    """
    if ctx.mag(z) >= -1:
        neg_z = ctx.fneg(z, exact=True)
        result = 0.5*(ctx.ei(z) - ctx.ei(neg_z))
        im_z = ctx._im(z)
        if im_z > 0:
            result -= 0.5j*ctx.pi
        elif im_z < 0:
            result += 0.5j*ctx.pi
        return result
    # Small |z|: avoid the cancellation with the power series
    return z * ctx.hyp1f2((1, 2), (3, 2), (3, 2), 0.25*z*z)
@defun_wrapped
def fresnels(ctx, z):
    """Fresnel sine integral S(z) = int_0^z sin(pi*t**2/2) dt.

    Evaluated as pi*z**3/6 * 1F2(3/4; 3/2, 7/4; -pi**2*z**4/16), with the
    limits S(+/-inf) = +/-1/2 returned directly.
    """
    if z == ctx.inf:
        return ctx.mpf(0.5)
    if z == ctx.ninf:
        return ctx.mpf(-0.5)
    prefactor = ctx.pi*z**3/6
    series_arg = -ctx.pi**2*z**4/16
    return prefactor*ctx.hyp1f2((3, 4), (3, 2), (7, 4), series_arg)
@defun_wrapped
def fresnelc(ctx, z):
    """Fresnel cosine integral C(z) = int_0^z cos(pi*t**2/2) dt.

    Evaluated as z * 1F2(1/4; 1/2, 5/4; -pi**2*z**4/16), with the limits
    C(+/-inf) = +/-1/2 returned directly.
    """
    if z == ctx.inf:
        return ctx.mpf(0.5)
    if z == ctx.ninf:
        return ctx.mpf(-0.5)
    series_arg = -ctx.pi**2*z**4/16
    return z*ctx.hyp1f2((1, 4), (1, 2), (5, 4), series_arg)
|