compound_rv.py 7.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223
  1. from sympy.concrete.summations import Sum
  2. from sympy.core.basic import Basic
  3. from sympy.core.function import Lambda
  4. from sympy.core.symbol import Dummy
  5. from sympy.integrals.integrals import Integral
  6. from sympy.stats.rv import (NamedArgsMixin, random_symbols, _symbol_converter,
  7. PSpace, RandomSymbol, is_random, Distribution)
  8. from sympy.stats.crv import ContinuousDistribution, SingleContinuousPSpace
  9. from sympy.stats.drv import DiscreteDistribution, SingleDiscretePSpace
  10. from sympy.stats.frv import SingleFiniteDistribution, SingleFinitePSpace
  11. from sympy.stats.crv_types import ContinuousDistributionHandmade
  12. from sympy.stats.drv_types import DiscreteDistributionHandmade
  13. from sympy.stats.frv_types import FiniteDistributionHandmade
  14. class CompoundPSpace(PSpace):
  15. """
  16. A temporary Probability Space for the Compound Distribution. After
  17. Marginalization, this returns the corresponding Probability Space of the
  18. parent distribution.
  19. """
  20. def __new__(cls, s, distribution):
  21. s = _symbol_converter(s)
  22. if isinstance(distribution, ContinuousDistribution):
  23. return SingleContinuousPSpace(s, distribution)
  24. if isinstance(distribution, DiscreteDistribution):
  25. return SingleDiscretePSpace(s, distribution)
  26. if isinstance(distribution, SingleFiniteDistribution):
  27. return SingleFinitePSpace(s, distribution)
  28. if not isinstance(distribution, CompoundDistribution):
  29. raise ValueError("%s should be an isinstance of "
  30. "CompoundDistribution"%(distribution))
  31. return Basic.__new__(cls, s, distribution)
  32. @property
  33. def value(self):
  34. return RandomSymbol(self.symbol, self)
  35. @property
  36. def symbol(self):
  37. return self.args[0]
  38. @property
  39. def is_Continuous(self):
  40. return self.distribution.is_Continuous
  41. @property
  42. def is_Finite(self):
  43. return self.distribution.is_Finite
  44. @property
  45. def is_Discrete(self):
  46. return self.distribution.is_Discrete
  47. @property
  48. def distribution(self):
  49. return self.args[1]
  50. @property
  51. def pdf(self):
  52. return self.distribution.pdf(self.symbol)
  53. @property
  54. def set(self):
  55. return self.distribution.set
  56. @property
  57. def domain(self):
  58. return self._get_newpspace().domain
  59. def _get_newpspace(self, evaluate=False):
  60. x = Dummy('x')
  61. parent_dist = self.distribution.args[0]
  62. func = Lambda(x, self.distribution.pdf(x, evaluate))
  63. new_pspace = self._transform_pspace(self.symbol, parent_dist, func)
  64. if new_pspace is not None:
  65. return new_pspace
  66. message = ("Compound Distribution for %s is not implemeted yet" % str(parent_dist))
  67. raise NotImplementedError(message)
  68. def _transform_pspace(self, sym, dist, pdf):
  69. """
  70. This function returns the new pspace of the distribution using handmade
  71. Distributions and their corresponding pspace.
  72. """
  73. pdf = Lambda(sym, pdf(sym))
  74. _set = dist.set
  75. if isinstance(dist, ContinuousDistribution):
  76. return SingleContinuousPSpace(sym, ContinuousDistributionHandmade(pdf, _set))
  77. elif isinstance(dist, DiscreteDistribution):
  78. return SingleDiscretePSpace(sym, DiscreteDistributionHandmade(pdf, _set))
  79. elif isinstance(dist, SingleFiniteDistribution):
  80. dens = {k: pdf(k) for k in _set}
  81. return SingleFinitePSpace(sym, FiniteDistributionHandmade(dens))
  82. def compute_density(self, expr, *, compound_evaluate=True, **kwargs):
  83. new_pspace = self._get_newpspace(compound_evaluate)
  84. expr = expr.subs({self.value: new_pspace.value})
  85. return new_pspace.compute_density(expr, **kwargs)
  86. def compute_cdf(self, expr, *, compound_evaluate=True, **kwargs):
  87. new_pspace = self._get_newpspace(compound_evaluate)
  88. expr = expr.subs({self.value: new_pspace.value})
  89. return new_pspace.compute_cdf(expr, **kwargs)
  90. def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
  91. new_pspace = self._get_newpspace(evaluate)
  92. expr = expr.subs({self.value: new_pspace.value})
  93. if rvs:
  94. rvs = rvs.subs({self.value: new_pspace.value})
  95. if isinstance(new_pspace, SingleFinitePSpace):
  96. return new_pspace.compute_expectation(expr, rvs, **kwargs)
  97. return new_pspace.compute_expectation(expr, rvs, evaluate, **kwargs)
  98. def probability(self, condition, *, compound_evaluate=True, **kwargs):
  99. new_pspace = self._get_newpspace(compound_evaluate)
  100. condition = condition.subs({self.value: new_pspace.value})
  101. return new_pspace.probability(condition)
  102. def conditional_space(self, condition, *, compound_evaluate=True, **kwargs):
  103. new_pspace = self._get_newpspace(compound_evaluate)
  104. condition = condition.subs({self.value: new_pspace.value})
  105. return new_pspace.conditional_space(condition)
  106. class CompoundDistribution(Distribution, NamedArgsMixin):
  107. """
  108. Class for Compound Distributions.
  109. Parameters
  110. ==========
  111. dist : Distribution
  112. Distribution must contain a random parameter
  113. Examples
  114. ========
  115. >>> from sympy.stats.compound_rv import CompoundDistribution
  116. >>> from sympy.stats.crv_types import NormalDistribution
  117. >>> from sympy.stats import Normal
  118. >>> from sympy.abc import x
  119. >>> X = Normal('X', 2, 4)
  120. >>> N = NormalDistribution(X, 4)
  121. >>> C = CompoundDistribution(N)
  122. >>> C.set
  123. Interval(-oo, oo)
  124. >>> C.pdf(x, evaluate=True).simplify()
  125. exp(-x**2/64 + x/16 - 1/16)/(8*sqrt(pi))
  126. References
  127. ==========
  128. .. [1] https://en.wikipedia.org/wiki/Compound_probability_distribution
  129. """
  130. def __new__(cls, dist):
  131. if not isinstance(dist, (ContinuousDistribution,
  132. SingleFiniteDistribution, DiscreteDistribution)):
  133. message = "Compound Distribution for %s is not implemeted yet" % str(dist)
  134. raise NotImplementedError(message)
  135. if not cls._compound_check(dist):
  136. return dist
  137. return Basic.__new__(cls, dist)
  138. @property
  139. def set(self):
  140. return self.args[0].set
  141. @property
  142. def is_Continuous(self):
  143. return isinstance(self.args[0], ContinuousDistribution)
  144. @property
  145. def is_Finite(self):
  146. return isinstance(self.args[0], SingleFiniteDistribution)
  147. @property
  148. def is_Discrete(self):
  149. return isinstance(self.args[0], DiscreteDistribution)
  150. def pdf(self, x, evaluate=False):
  151. dist = self.args[0]
  152. randoms = [rv for rv in dist.args if is_random(rv)]
  153. if isinstance(dist, SingleFiniteDistribution):
  154. y = Dummy('y', integer=True, negative=False)
  155. expr = dist.pmf(y)
  156. else:
  157. y = Dummy('y')
  158. expr = dist.pdf(y)
  159. for rv in randoms:
  160. expr = self._marginalise(expr, rv, evaluate)
  161. return Lambda(y, expr)(x)
  162. def _marginalise(self, expr, rv, evaluate):
  163. if isinstance(rv.pspace.distribution, SingleFiniteDistribution):
  164. rv_dens = rv.pspace.distribution.pmf(rv)
  165. else:
  166. rv_dens = rv.pspace.distribution.pdf(rv)
  167. rv_dom = rv.pspace.domain.set
  168. if rv.pspace.is_Discrete or rv.pspace.is_Finite:
  169. expr = Sum(expr*rv_dens, (rv, rv_dom._inf,
  170. rv_dom._sup))
  171. else:
  172. expr = Integral(expr*rv_dens, (rv, rv_dom._inf,
  173. rv_dom._sup))
  174. if evaluate:
  175. return expr.doit()
  176. return expr
  177. @classmethod
  178. def _compound_check(self, dist):
  179. """
  180. Checks if the given distribution contains random parameters.
  181. """
  182. randoms = []
  183. for arg in dist.args:
  184. randoms.extend(random_symbols(arg))
  185. if len(randoms) == 0:
  186. return False
  187. return True