expintegrals.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. from .functions import defun, defun_wrapped
  2. @defun_wrapped
  3. def _erf_complex(ctx, z):
  4. z2 = ctx.square_exp_arg(z, -1)
  5. #z2 = -z**2
  6. v = (2/ctx.sqrt(ctx.pi))*z * ctx.hyp1f1((1,2),(3,2), z2)
  7. if not ctx._re(z):
  8. v = ctx._im(v)*ctx.j
  9. return v
  10. @defun_wrapped
  11. def _erfc_complex(ctx, z):
  12. if ctx.re(z) > 2:
  13. z2 = ctx.square_exp_arg(z)
  14. nz2 = ctx.fneg(z2, exact=True)
  15. v = ctx.exp(nz2)/ctx.sqrt(ctx.pi) * ctx.hyperu((1,2),(1,2), z2)
  16. else:
  17. v = 1 - ctx._erf_complex(z)
  18. if not ctx._re(z):
  19. v = 1+ctx._im(v)*ctx.j
  20. return v
  21. @defun
  22. def erf(ctx, z):
  23. z = ctx.convert(z)
  24. if ctx._is_real_type(z):
  25. try:
  26. return ctx._erf(z)
  27. except NotImplementedError:
  28. pass
  29. if ctx._is_complex_type(z) and not z.imag:
  30. try:
  31. return type(z)(ctx._erf(z.real))
  32. except NotImplementedError:
  33. pass
  34. return ctx._erf_complex(z)
  35. @defun
  36. def erfc(ctx, z):
  37. z = ctx.convert(z)
  38. if ctx._is_real_type(z):
  39. try:
  40. return ctx._erfc(z)
  41. except NotImplementedError:
  42. pass
  43. if ctx._is_complex_type(z) and not z.imag:
  44. try:
  45. return type(z)(ctx._erfc(z.real))
  46. except NotImplementedError:
  47. pass
  48. return ctx._erfc_complex(z)
  49. @defun
  50. def square_exp_arg(ctx, z, mult=1, reciprocal=False):
  51. prec = ctx.prec*4+20
  52. if reciprocal:
  53. z2 = ctx.fmul(z, z, prec=prec)
  54. z2 = ctx.fdiv(ctx.one, z2, prec=prec)
  55. else:
  56. z2 = ctx.fmul(z, z, prec=prec)
  57. if mult != 1:
  58. z2 = ctx.fmul(z2, mult, exact=True)
  59. return z2
  60. @defun_wrapped
  61. def erfi(ctx, z):
  62. if not z:
  63. return z
  64. z2 = ctx.square_exp_arg(z)
  65. v = (2/ctx.sqrt(ctx.pi)*z) * ctx.hyp1f1((1,2), (3,2), z2)
  66. if not ctx._re(z):
  67. v = ctx._im(v)*ctx.j
  68. return v
  69. @defun_wrapped
  70. def erfinv(ctx, x):
  71. xre = ctx._re(x)
  72. if (xre != x) or (xre < -1) or (xre > 1):
  73. return ctx.bad_domain("erfinv(x) is defined only for -1 <= x <= 1")
  74. x = xre
  75. #if ctx.isnan(x): return x
  76. if not x: return x
  77. if x == 1: return ctx.inf
  78. if x == -1: return ctx.ninf
  79. if abs(x) < 0.9:
  80. a = 0.53728*x**3 + 0.813198*x
  81. else:
  82. # An asymptotic formula
  83. u = ctx.ln(2/ctx.pi/(abs(x)-1)**2)
  84. a = ctx.sign(x) * ctx.sqrt(u - ctx.ln(u))/ctx.sqrt(2)
  85. ctx.prec += 10
  86. return ctx.findroot(lambda t: ctx.erf(t)-x, a)
  87. @defun_wrapped
  88. def npdf(ctx, x, mu=0, sigma=1):
  89. sigma = ctx.convert(sigma)
  90. return ctx.exp(-(x-mu)**2/(2*sigma**2)) / (sigma*ctx.sqrt(2*ctx.pi))
  91. @defun_wrapped
  92. def ncdf(ctx, x, mu=0, sigma=1):
  93. a = (x-mu)/(sigma*ctx.sqrt(2))
  94. if a < 0:
  95. return ctx.erfc(-a)/2
  96. else:
  97. return (1+ctx.erf(a))/2
  98. @defun_wrapped
  99. def betainc(ctx, a, b, x1=0, x2=1, regularized=False):
  100. if x1 == x2:
  101. v = 0
  102. elif not x1:
  103. if x1 == 0 and x2 == 1:
  104. v = ctx.beta(a, b)
  105. else:
  106. v = x2**a * ctx.hyp2f1(a, 1-b, a+1, x2) / a
  107. else:
  108. m, d = ctx.nint_distance(a)
  109. if m <= 0:
  110. if d < -ctx.prec:
  111. h = +ctx.eps
  112. ctx.prec *= 2
  113. a += h
  114. elif d < -4:
  115. ctx.prec -= d
  116. s1 = x2**a * ctx.hyp2f1(a,1-b,a+1,x2)
  117. s2 = x1**a * ctx.hyp2f1(a,1-b,a+1,x1)
  118. v = (s1 - s2) / a
  119. if regularized:
  120. v /= ctx.beta(a,b)
  121. return v
  122. @defun
  123. def gammainc(ctx, z, a=0, b=None, regularized=False):
  124. regularized = bool(regularized)
  125. z = ctx.convert(z)
  126. if a is None:
  127. a = ctx.zero
  128. lower_modified = False
  129. else:
  130. a = ctx.convert(a)
  131. lower_modified = a != ctx.zero
  132. if b is None:
  133. b = ctx.inf
  134. upper_modified = False
  135. else:
  136. b = ctx.convert(b)
  137. upper_modified = b != ctx.inf
  138. # Complete gamma function
  139. if not (upper_modified or lower_modified):
  140. if regularized:
  141. if ctx.re(z) < 0:
  142. return ctx.inf
  143. elif ctx.re(z) > 0:
  144. return ctx.one
  145. else:
  146. return ctx.nan
  147. return ctx.gamma(z)
  148. if a == b:
  149. return ctx.zero
  150. # Standardize
  151. if ctx.re(a) > ctx.re(b):
  152. return -ctx.gammainc(z, b, a, regularized)
  153. # Generalized gamma
  154. if upper_modified and lower_modified:
  155. return +ctx._gamma3(z, a, b, regularized)
  156. # Upper gamma
  157. elif lower_modified:
  158. return ctx._upper_gamma(z, a, regularized)
  159. # Lower gamma
  160. elif upper_modified:
  161. return ctx._lower_gamma(z, b, regularized)
  162. @defun
  163. def _lower_gamma(ctx, z, b, regularized=False):
  164. # Pole
  165. if ctx.isnpint(z):
  166. return type(z)(ctx.inf)
  167. G = [z] * regularized
  168. negb = ctx.fneg(b, exact=True)
  169. def h(z):
  170. T1 = [ctx.exp(negb), b, z], [1, z, -1], [], G, [1], [1+z], b
  171. return (T1,)
  172. return ctx.hypercomb(h, [z])
  173. @defun
  174. def _upper_gamma(ctx, z, a, regularized=False):
  175. # Fast integer case, when available
  176. if ctx.isint(z):
  177. try:
  178. if regularized:
  179. # Gamma pole
  180. if ctx.isnpint(z):
  181. return type(z)(ctx.zero)
  182. orig = ctx.prec
  183. try:
  184. ctx.prec += 10
  185. return ctx._gamma_upper_int(z, a) / ctx.gamma(z)
  186. finally:
  187. ctx.prec = orig
  188. else:
  189. return ctx._gamma_upper_int(z, a)
  190. except NotImplementedError:
  191. pass
  192. # hypercomb is unable to detect the exact zeros, so handle them here
  193. if z == 2 and a == -1:
  194. return (z+a)*0
  195. if z == 3 and (a == -1-1j or a == -1+1j):
  196. return (z+a)*0
  197. nega = ctx.fneg(a, exact=True)
  198. G = [z] * regularized
  199. # Use 2F0 series when possible; fall back to lower gamma representation
  200. try:
  201. def h(z):
  202. r = z-1
  203. return [([ctx.exp(nega), a], [1, r], [], G, [1, -r], [], 1/nega)]
  204. return ctx.hypercomb(h, [z], force_series=True)
  205. except ctx.NoConvergence:
  206. def h(z):
  207. T1 = [], [1, z-1], [z], G, [], [], 0
  208. T2 = [-ctx.exp(nega), a, z], [1, z, -1], [], G, [1], [1+z], a
  209. return T1, T2
  210. return ctx.hypercomb(h, [z])
  211. @defun
  212. def _gamma3(ctx, z, a, b, regularized=False):
  213. pole = ctx.isnpint(z)
  214. if regularized and pole:
  215. return ctx.zero
  216. try:
  217. ctx.prec += 15
  218. # We don't know in advance whether it's better to write as a difference
  219. # of lower or upper gamma functions, so try both
  220. T1 = ctx.gammainc(z, a, regularized=regularized)
  221. T2 = ctx.gammainc(z, b, regularized=regularized)
  222. R = T1 - T2
  223. if ctx.mag(R) - max(ctx.mag(T1), ctx.mag(T2)) > -10:
  224. return R
  225. if not pole:
  226. T1 = ctx.gammainc(z, 0, b, regularized=regularized)
  227. T2 = ctx.gammainc(z, 0, a, regularized=regularized)
  228. R = T1 - T2
  229. # May be ok, but should probably at least print a warning
  230. # about possible cancellation
  231. if 1: #ctx.mag(R) - max(ctx.mag(T1), ctx.mag(T2)) > -10:
  232. return R
  233. finally:
  234. ctx.prec -= 15
  235. raise NotImplementedError
  236. @defun_wrapped
  237. def expint(ctx, n, z):
  238. if ctx.isint(n) and ctx._is_real_type(z):
  239. try:
  240. return ctx._expint_int(n, z)
  241. except NotImplementedError:
  242. pass
  243. if ctx.isnan(n) or ctx.isnan(z):
  244. return z*n
  245. if z == ctx.inf:
  246. return 1/z
  247. if z == 0:
  248. # integral from 1 to infinity of t^n
  249. if ctx.re(n) <= 1:
  250. # TODO: reasonable sign of infinity
  251. return type(z)(ctx.inf)
  252. else:
  253. return ctx.one/(n-1)
  254. if n == 0:
  255. return ctx.exp(-z)/z
  256. if n == -1:
  257. return ctx.exp(-z)*(z+1)/z**2
  258. return z**(n-1) * ctx.gammainc(1-n, z)
  259. @defun_wrapped
  260. def li(ctx, z, offset=False):
  261. if offset:
  262. if z == 2:
  263. return ctx.zero
  264. return ctx.ei(ctx.ln(z)) - ctx.ei(ctx.ln2)
  265. if not z:
  266. return z
  267. if z == 1:
  268. return ctx.ninf
  269. return ctx.ei(ctx.ln(z))
  270. @defun
  271. def ei(ctx, z):
  272. try:
  273. return ctx._ei(z)
  274. except NotImplementedError:
  275. return ctx._ei_generic(z)
  276. @defun_wrapped
  277. def _ei_generic(ctx, z):
  278. # Note: the following is currently untested because mp and fp
  279. # both use special-case ei code
  280. if z == ctx.inf:
  281. return z
  282. if z == ctx.ninf:
  283. return ctx.zero
  284. if ctx.mag(z) > 1:
  285. try:
  286. r = ctx.one/z
  287. v = ctx.exp(z)*ctx.hyper([1,1],[],r,
  288. maxterms=ctx.prec, force_series=True)/z
  289. im = ctx._im(z)
  290. if im > 0:
  291. v += ctx.pi*ctx.j
  292. if im < 0:
  293. v -= ctx.pi*ctx.j
  294. return v
  295. except ctx.NoConvergence:
  296. pass
  297. v = z*ctx.hyp2f2(1,1,2,2,z) + ctx.euler
  298. if ctx._im(z):
  299. v += 0.5*(ctx.log(z) - ctx.log(ctx.one/z))
  300. else:
  301. v += ctx.log(abs(z))
  302. return v
  303. @defun
  304. def e1(ctx, z):
  305. try:
  306. return ctx._e1(z)
  307. except NotImplementedError:
  308. return ctx.expint(1, z)
  309. @defun
  310. def ci(ctx, z):
  311. try:
  312. return ctx._ci(z)
  313. except NotImplementedError:
  314. return ctx._ci_generic(z)
  315. @defun_wrapped
  316. def _ci_generic(ctx, z):
  317. if ctx.isinf(z):
  318. if z == ctx.inf: return ctx.zero
  319. if z == ctx.ninf: return ctx.pi*1j
  320. jz = ctx.fmul(ctx.j,z,exact=True)
  321. njz = ctx.fneg(jz,exact=True)
  322. v = 0.5*(ctx.ei(jz) + ctx.ei(njz))
  323. zreal = ctx._re(z)
  324. zimag = ctx._im(z)
  325. if zreal == 0:
  326. if zimag > 0: v += ctx.pi*0.5j
  327. if zimag < 0: v -= ctx.pi*0.5j
  328. if zreal < 0:
  329. if zimag >= 0: v += ctx.pi*1j
  330. if zimag < 0: v -= ctx.pi*1j
  331. if ctx._is_real_type(z) and zreal > 0:
  332. v = ctx._re(v)
  333. return v
  334. @defun
  335. def si(ctx, z):
  336. try:
  337. return ctx._si(z)
  338. except NotImplementedError:
  339. return ctx._si_generic(z)
  340. @defun_wrapped
  341. def _si_generic(ctx, z):
  342. if ctx.isinf(z):
  343. if z == ctx.inf: return 0.5*ctx.pi
  344. if z == ctx.ninf: return -0.5*ctx.pi
  345. # Suffers from cancellation near 0
  346. if ctx.mag(z) >= -1:
  347. jz = ctx.fmul(ctx.j,z,exact=True)
  348. njz = ctx.fneg(jz,exact=True)
  349. v = (-0.5j)*(ctx.ei(jz) - ctx.ei(njz))
  350. zreal = ctx._re(z)
  351. if zreal > 0:
  352. v -= 0.5*ctx.pi
  353. if zreal < 0:
  354. v += 0.5*ctx.pi
  355. if ctx._is_real_type(z):
  356. v = ctx._re(v)
  357. return v
  358. else:
  359. return z*ctx.hyp1f2((1,2),(3,2),(3,2),-0.25*z*z)
  360. @defun_wrapped
  361. def chi(ctx, z):
  362. nz = ctx.fneg(z, exact=True)
  363. v = 0.5*(ctx.ei(z) + ctx.ei(nz))
  364. zreal = ctx._re(z)
  365. zimag = ctx._im(z)
  366. if zimag > 0:
  367. v += ctx.pi*0.5j
  368. elif zimag < 0:
  369. v -= ctx.pi*0.5j
  370. elif zreal < 0:
  371. v += ctx.pi*1j
  372. return v
  373. @defun_wrapped
  374. def shi(ctx, z):
  375. # Suffers from cancellation near 0
  376. if ctx.mag(z) >= -1:
  377. nz = ctx.fneg(z, exact=True)
  378. v = 0.5*(ctx.ei(z) - ctx.ei(nz))
  379. zimag = ctx._im(z)
  380. if zimag > 0: v -= 0.5j*ctx.pi
  381. if zimag < 0: v += 0.5j*ctx.pi
  382. return v
  383. else:
  384. return z * ctx.hyp1f2((1,2),(3,2),(3,2),0.25*z*z)
  385. @defun_wrapped
  386. def fresnels(ctx, z):
  387. if z == ctx.inf:
  388. return ctx.mpf(0.5)
  389. if z == ctx.ninf:
  390. return ctx.mpf(-0.5)
  391. return ctx.pi*z**3/6*ctx.hyp1f2((3,4),(3,2),(7,4),-ctx.pi**2*z**4/16)
  392. @defun_wrapped
  393. def fresnelc(ctx, z):
  394. if z == ctx.inf:
  395. return ctx.mpf(0.5)
  396. if z == ctx.ninf:
  397. return ctx.mpf(-0.5)
  398. return z*ctx.hyp1f2((1,4),(1,2),(5,4),-ctx.pi**2*z**4/16)