drv.py 12 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350
  1. from sympy.concrete.summations import (Sum, summation)
  2. from sympy.core.basic import Basic
  3. from sympy.core.cache import cacheit
  4. from sympy.core.function import Lambda
  5. from sympy.core.numbers import I
  6. from sympy.core.relational import (Eq, Ne)
  7. from sympy.core.singleton import S
  8. from sympy.core.symbol import (Dummy, symbols)
  9. from sympy.core.sympify import sympify
  10. from sympy.functions.combinatorial.factorials import factorial
  11. from sympy.functions.elementary.exponential import exp
  12. from sympy.functions.elementary.integers import floor
  13. from sympy.functions.elementary.piecewise import Piecewise
  14. from sympy.logic.boolalg import And
  15. from sympy.polys.polytools import poly
  16. from sympy.series.series import series
  17. from sympy.polys.polyerrors import PolynomialError
  18. from sympy.stats.crv import reduce_rational_inequalities_wrap
  19. from sympy.stats.rv import (NamedArgsMixin, SinglePSpace, SingleDomain,
  20. random_symbols, PSpace, ConditionalDomain, RandomDomain,
  21. ProductDomain, Distribution)
  22. from sympy.stats.symbolic_probability import Probability
  23. from sympy.sets.fancysets import Range, FiniteSet
  24. from sympy.sets.sets import Union
  25. from sympy.sets.contains import Contains
  26. from sympy.utilities import filldedent
  27. from sympy.core.sympify import _sympify
  28. class DiscreteDistribution(Distribution):
  29. def __call__(self, *args):
  30. return self.pdf(*args)
  31. class SingleDiscreteDistribution(DiscreteDistribution, NamedArgsMixin):
  32. """ Discrete distribution of a single variable.
  33. Serves as superclass for PoissonDistribution etc....
  34. Provides methods for pdf, cdf, and sampling
  35. See Also:
  36. sympy.stats.crv_types.*
  37. """
  38. set = S.Integers
  39. def __new__(cls, *args):
  40. args = list(map(sympify, args))
  41. return Basic.__new__(cls, *args)
  42. @staticmethod
  43. def check(*args):
  44. pass
  45. @cacheit
  46. def compute_cdf(self, **kwargs):
  47. """ Compute the CDF from the PDF.
  48. Returns a Lambda.
  49. """
  50. x = symbols('x', integer=True, cls=Dummy)
  51. z = symbols('z', real=True, cls=Dummy)
  52. left_bound = self.set.inf
  53. # CDF is integral of PDF from left bound to z
  54. pdf = self.pdf(x)
  55. cdf = summation(pdf, (x, left_bound, floor(z)), **kwargs)
  56. # CDF Ensure that CDF left of left_bound is zero
  57. cdf = Piecewise((cdf, z >= left_bound), (0, True))
  58. return Lambda(z, cdf)
  59. def _cdf(self, x):
  60. return None
  61. def cdf(self, x, **kwargs):
  62. """ Cumulative density function """
  63. if not kwargs:
  64. cdf = self._cdf(x)
  65. if cdf is not None:
  66. return cdf
  67. return self.compute_cdf(**kwargs)(x)
  68. @cacheit
  69. def compute_characteristic_function(self, **kwargs):
  70. """ Compute the characteristic function from the PDF.
  71. Returns a Lambda.
  72. """
  73. x, t = symbols('x, t', real=True, cls=Dummy)
  74. pdf = self.pdf(x)
  75. cf = summation(exp(I*t*x)*pdf, (x, self.set.inf, self.set.sup))
  76. return Lambda(t, cf)
  77. def _characteristic_function(self, t):
  78. return None
  79. def characteristic_function(self, t, **kwargs):
  80. """ Characteristic function """
  81. if not kwargs:
  82. cf = self._characteristic_function(t)
  83. if cf is not None:
  84. return cf
  85. return self.compute_characteristic_function(**kwargs)(t)
  86. @cacheit
  87. def compute_moment_generating_function(self, **kwargs):
  88. t = Dummy('t', real=True)
  89. x = Dummy('x', integer=True)
  90. pdf = self.pdf(x)
  91. mgf = summation(exp(t*x)*pdf, (x, self.set.inf, self.set.sup))
  92. return Lambda(t, mgf)
  93. def _moment_generating_function(self, t):
  94. return None
  95. def moment_generating_function(self, t, **kwargs):
  96. if not kwargs:
  97. mgf = self._moment_generating_function(t)
  98. if mgf is not None:
  99. return mgf
  100. return self.compute_moment_generating_function(**kwargs)(t)
  101. @cacheit
  102. def compute_quantile(self, **kwargs):
  103. """ Compute the Quantile from the PDF.
  104. Returns a Lambda.
  105. """
  106. x = Dummy('x', integer=True)
  107. p = Dummy('p', real=True)
  108. left_bound = self.set.inf
  109. pdf = self.pdf(x)
  110. cdf = summation(pdf, (x, left_bound, x), **kwargs)
  111. set = ((x, p <= cdf), )
  112. return Lambda(p, Piecewise(*set))
  113. def _quantile(self, x):
  114. return None
  115. def quantile(self, x, **kwargs):
  116. """ Cumulative density function """
  117. if not kwargs:
  118. quantile = self._quantile(x)
  119. if quantile is not None:
  120. return quantile
  121. return self.compute_quantile(**kwargs)(x)
  122. def expectation(self, expr, var, evaluate=True, **kwargs):
  123. """ Expectation of expression over distribution """
  124. # TODO: support discrete sets with non integer stepsizes
  125. if evaluate:
  126. try:
  127. p = poly(expr, var)
  128. t = Dummy('t', real=True)
  129. mgf = self.moment_generating_function(t)
  130. deg = p.degree()
  131. taylor = poly(series(mgf, t, 0, deg + 1).removeO(), t)
  132. result = 0
  133. for k in range(deg+1):
  134. result += p.coeff_monomial(var ** k) * taylor.coeff_monomial(t ** k) * factorial(k)
  135. return result
  136. except PolynomialError:
  137. return summation(expr * self.pdf(var),
  138. (var, self.set.inf, self.set.sup), **kwargs)
  139. else:
  140. return Sum(expr * self.pdf(var),
  141. (var, self.set.inf, self.set.sup), **kwargs)
  142. def __call__(self, *args):
  143. return self.pdf(*args)
  144. class DiscreteDomain(RandomDomain):
  145. """
  146. A domain with discrete support with step size one.
  147. Represented using symbols and Range.
  148. """
  149. is_Discrete = True
  150. class SingleDiscreteDomain(DiscreteDomain, SingleDomain):
  151. def as_boolean(self):
  152. return Contains(self.symbol, self.set)
  153. class ConditionalDiscreteDomain(DiscreteDomain, ConditionalDomain):
  154. """
  155. Domain with discrete support of step size one, that is restricted by
  156. some condition.
  157. """
  158. @property
  159. def set(self):
  160. rv = self.symbols
  161. if len(self.symbols) > 1:
  162. raise NotImplementedError(filldedent('''
  163. Multivariate conditional domains are not yet implemented.'''))
  164. rv = list(rv)[0]
  165. return reduce_rational_inequalities_wrap(self.condition,
  166. rv).intersect(self.fulldomain.set)
  167. class DiscretePSpace(PSpace):
  168. is_real = True
  169. is_Discrete = True
  170. @property
  171. def pdf(self):
  172. return self.density(*self.symbols)
  173. def where(self, condition):
  174. rvs = random_symbols(condition)
  175. assert all(r.symbol in self.symbols for r in rvs)
  176. if len(rvs) > 1:
  177. raise NotImplementedError(filldedent('''Multivariate discrete
  178. random variables are not yet supported.'''))
  179. conditional_domain = reduce_rational_inequalities_wrap(condition,
  180. rvs[0])
  181. conditional_domain = conditional_domain.intersect(self.domain.set)
  182. return SingleDiscreteDomain(rvs[0].symbol, conditional_domain)
  183. def probability(self, condition):
  184. complement = isinstance(condition, Ne)
  185. if complement:
  186. condition = Eq(condition.args[0], condition.args[1])
  187. try:
  188. _domain = self.where(condition).set
  189. if condition == False or _domain is S.EmptySet:
  190. return S.Zero
  191. if condition == True or _domain == self.domain.set:
  192. return S.One
  193. prob = self.eval_prob(_domain)
  194. except NotImplementedError:
  195. from sympy.stats.rv import density
  196. expr = condition.lhs - condition.rhs
  197. dens = density(expr)
  198. if not isinstance(dens, DiscreteDistribution):
  199. from sympy.stats.drv_types import DiscreteDistributionHandmade
  200. dens = DiscreteDistributionHandmade(dens)
  201. z = Dummy('z', real=True)
  202. space = SingleDiscretePSpace(z, dens)
  203. prob = space.probability(condition.__class__(space.value, 0))
  204. if prob is None:
  205. prob = Probability(condition)
  206. return prob if not complement else S.One - prob
  207. def eval_prob(self, _domain):
  208. sym = list(self.symbols)[0]
  209. if isinstance(_domain, Range):
  210. n = symbols('n', integer=True)
  211. inf, sup, step = (r for r in _domain.args)
  212. summand = ((self.pdf).replace(
  213. sym, n*step))
  214. rv = summation(summand,
  215. (n, inf/step, (sup)/step - 1)).doit()
  216. return rv
  217. elif isinstance(_domain, FiniteSet):
  218. pdf = Lambda(sym, self.pdf)
  219. rv = sum(pdf(x) for x in _domain)
  220. return rv
  221. elif isinstance(_domain, Union):
  222. rv = sum(self.eval_prob(x) for x in _domain.args)
  223. return rv
  224. def conditional_space(self, condition):
  225. # XXX: Converting from set to tuple. The order matters to Lambda
  226. # though so we should be starting with a set...
  227. density = Lambda(tuple(self.symbols), self.pdf/self.probability(condition))
  228. condition = condition.xreplace({rv: rv.symbol for rv in self.values})
  229. domain = ConditionalDiscreteDomain(self.domain, condition)
  230. return DiscretePSpace(domain, density)
  231. class ProductDiscreteDomain(ProductDomain, DiscreteDomain):
  232. def as_boolean(self):
  233. return And(*[domain.as_boolean for domain in self.domains])
  234. class SingleDiscretePSpace(DiscretePSpace, SinglePSpace):
  235. """ Discrete probability space over a single univariate variable """
  236. is_real = True
  237. @property
  238. def set(self):
  239. return self.distribution.set
  240. @property
  241. def domain(self):
  242. return SingleDiscreteDomain(self.symbol, self.set)
  243. def sample(self, size=(), library='scipy', seed=None):
  244. """
  245. Internal sample method.
  246. Returns dictionary mapping RandomSymbol to realization value.
  247. """
  248. return {self.value: self.distribution.sample(size, library=library, seed=seed)}
  249. def compute_expectation(self, expr, rvs=None, evaluate=True, **kwargs):
  250. rvs = rvs or (self.value,)
  251. if self.value not in rvs:
  252. return expr
  253. expr = _sympify(expr)
  254. expr = expr.xreplace({rv: rv.symbol for rv in rvs})
  255. x = self.value.symbol
  256. try:
  257. return self.distribution.expectation(expr, x, evaluate=evaluate,
  258. **kwargs)
  259. except NotImplementedError:
  260. return Sum(expr * self.pdf, (x, self.set.inf, self.set.sup),
  261. **kwargs)
  262. def compute_cdf(self, expr, **kwargs):
  263. if expr == self.value:
  264. x = Dummy("x", real=True)
  265. return Lambda(x, self.distribution.cdf(x, **kwargs))
  266. else:
  267. raise NotImplementedError()
  268. def compute_density(self, expr, **kwargs):
  269. if expr == self.value:
  270. return self.distribution
  271. raise NotImplementedError()
  272. def compute_characteristic_function(self, expr, **kwargs):
  273. if expr == self.value:
  274. t = Dummy("t", real=True)
  275. return Lambda(t, self.distribution.characteristic_function(t, **kwargs))
  276. else:
  277. raise NotImplementedError()
  278. def compute_moment_generating_function(self, expr, **kwargs):
  279. if expr == self.value:
  280. t = Dummy("t", real=True)
  281. return Lambda(t, self.distribution.moment_generating_function(t, **kwargs))
  282. else:
  283. raise NotImplementedError()
  284. def compute_quantile(self, expr, **kwargs):
  285. if expr == self.value:
  286. p = Dummy("p", real=True)
  287. return Lambda(p, self.distribution.quantile(p, **kwargs))
  288. else:
  289. raise NotImplementedError()