joint_rv.py 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423
  1. """
  2. Joint Random Variables Module
  3. See Also
  4. ========
  5. sympy.stats.rv
  6. sympy.stats.frv
  7. sympy.stats.crv
  8. sympy.stats.drv
  9. """
  10. from sympy.core.basic import Basic
  11. from sympy.core.function import Lambda
  12. from sympy.core.mul import prod
  13. from sympy.core.singleton import S
  14. from sympy.core.symbol import (Dummy, Symbol)
  15. from sympy.core.sympify import sympify
  16. from sympy.sets.sets import ProductSet
  17. from sympy.tensor.indexed import Indexed
  18. from sympy.concrete.products import Product
  19. from sympy.concrete.summations import Sum, summation
  20. from sympy.core.containers import Tuple
  21. from sympy.integrals.integrals import Integral, integrate
  22. from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy
  23. from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace
  24. from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace
  25. from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution,
  26. ProductDomain, RandomSymbol, random_symbols,
  27. SingleDomain, _symbol_converter)
  28. from sympy.utilities.iterables import iterable
  29. from sympy.utilities.misc import filldedent
  30. from sympy.external import import_module
  31. # __all__ = ['marginal_distribution']
  32. class JointPSpace(ProductPSpace):
  33. """
  34. Represents a joint probability space. Represented using symbols for
  35. each component and a distribution.
  36. """
  37. def __new__(cls, sym, dist):
  38. if isinstance(dist, SingleContinuousDistribution):
  39. return SingleContinuousPSpace(sym, dist)
  40. if isinstance(dist, SingleDiscreteDistribution):
  41. return SingleDiscretePSpace(sym, dist)
  42. sym = _symbol_converter(sym)
  43. return Basic.__new__(cls, sym, dist)
  44. @property
  45. def set(self):
  46. return self.domain.set
  47. @property
  48. def symbol(self):
  49. return self.args[0]
  50. @property
  51. def distribution(self):
  52. return self.args[1]
  53. @property
  54. def value(self):
  55. return JointRandomSymbol(self.symbol, self)
  56. @property
  57. def component_count(self):
  58. _set = self.distribution.set
  59. if isinstance(_set, ProductSet):
  60. return S(len(_set.args))
  61. elif isinstance(_set, Product):
  62. return _set.limits[0][-1]
  63. return S.One
  64. @property
  65. def pdf(self):
  66. sym = [Indexed(self.symbol, i) for i in range(self.component_count)]
  67. return self.distribution(*sym)
  68. @property
  69. def domain(self):
  70. rvs = random_symbols(self.distribution)
  71. if not rvs:
  72. return SingleDomain(self.symbol, self.distribution.set)
  73. return ProductDomain(*[rv.pspace.domain for rv in rvs])
  74. def component_domain(self, index):
  75. return self.set.args[index]
  76. def marginal_distribution(self, *indices):
  77. count = self.component_count
  78. if count.atoms(Symbol):
  79. raise ValueError("Marginal distributions cannot be computed "
  80. "for symbolic dimensions. It is a work under progress.")
  81. orig = [Indexed(self.symbol, i) for i in range(count)]
  82. all_syms = [Symbol(str(i)) for i in orig]
  83. replace_dict = dict(zip(all_syms, orig))
  84. sym = tuple(Symbol(str(Indexed(self.symbol, i))) for i in indices)
  85. limits = list([i,] for i in all_syms if i not in sym)
  86. index = 0
  87. for i in range(count):
  88. if i not in indices:
  89. limits[index].append(self.distribution.set.args[i])
  90. limits[index] = tuple(limits[index])
  91. index += 1
  92. if self.distribution.is_Continuous:
  93. f = Lambda(sym, integrate(self.distribution(*all_syms), *limits))
  94. elif self.distribution.is_Discrete:
  95. f = Lambda(sym, summation(self.distribution(*all_syms), *limits))
  96. return f.xreplace(replace_dict)
  97. def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
  98. syms = tuple(self.value[i] for i in range(self.component_count))
  99. rvs = rvs or syms
  100. if not any(i in rvs for i in syms):
  101. return expr
  102. expr = expr*self.pdf
  103. for rv in rvs:
  104. if isinstance(rv, Indexed):
  105. expr = expr.xreplace({rv: Indexed(str(rv.base), rv.args[1])})
  106. elif isinstance(rv, RandomSymbol):
  107. expr = expr.xreplace({rv: rv.symbol})
  108. if self.value in random_symbols(expr):
  109. raise NotImplementedError(filldedent('''
  110. Expectations of expression with unindexed joint random symbols
  111. cannot be calculated yet.'''))
  112. limits = tuple((Indexed(str(rv.base),rv.args[1]),
  113. self.distribution.set.args[rv.args[1]]) for rv in syms)
  114. return Integral(expr, *limits)
  115. def where(self, condition):
  116. raise NotImplementedError()
  117. def compute_density(self, expr):
  118. raise NotImplementedError()
  119. def sample(self, size=(), library='scipy', seed=None):
  120. """
  121. Internal sample method
  122. Returns dictionary mapping RandomSymbol to realization value.
  123. """
  124. return {RandomSymbol(self.symbol, self): self.distribution.sample(size,
  125. library=library, seed=seed)}
  126. def probability(self, condition):
  127. raise NotImplementedError()
  128. class SampleJointScipy:
  129. """Returns the sample from scipy of the given distribution"""
  130. def __new__(cls, dist, size, seed=None):
  131. return cls._sample_scipy(dist, size, seed)
  132. @classmethod
  133. def _sample_scipy(cls, dist, size, seed):
  134. """Sample from SciPy."""
  135. import numpy
  136. if seed is None or isinstance(seed, int):
  137. rand_state = numpy.random.default_rng(seed=seed)
  138. else:
  139. rand_state = seed
  140. from scipy import stats as scipy_stats
  141. scipy_rv_map = {
  142. 'MultivariateNormalDistribution': lambda dist, size: scipy_stats.multivariate_normal.rvs(
  143. mean=matrix2numpy(dist.mu).flatten(),
  144. cov=matrix2numpy(dist.sigma), size=size, random_state=rand_state),
  145. 'MultivariateBetaDistribution': lambda dist, size: scipy_stats.dirichlet.rvs(
  146. alpha=list2numpy(dist.alpha, float).flatten(), size=size, random_state=rand_state),
  147. 'MultinomialDistribution': lambda dist, size: scipy_stats.multinomial.rvs(
  148. n=int(dist.n), p=list2numpy(dist.p, float).flatten(), size=size, random_state=rand_state)
  149. }
  150. sample_shape = {
  151. 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
  152. 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
  153. 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
  154. }
  155. dist_list = scipy_rv_map.keys()
  156. if dist.__class__.__name__ not in dist_list:
  157. return None
  158. samples = scipy_rv_map[dist.__class__.__name__](dist, size)
  159. return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
  160. class SampleJointNumpy:
  161. """Returns the sample from numpy of the given distribution"""
  162. def __new__(cls, dist, size, seed=None):
  163. return cls._sample_numpy(dist, size, seed)
  164. @classmethod
  165. def _sample_numpy(cls, dist, size, seed):
  166. """Sample from NumPy."""
  167. import numpy
  168. if seed is None or isinstance(seed, int):
  169. rand_state = numpy.random.default_rng(seed=seed)
  170. else:
  171. rand_state = seed
  172. numpy_rv_map = {
  173. 'MultivariateNormalDistribution': lambda dist, size: rand_state.multivariate_normal(
  174. mean=matrix2numpy(dist.mu, float).flatten(),
  175. cov=matrix2numpy(dist.sigma, float), size=size),
  176. 'MultivariateBetaDistribution': lambda dist, size: rand_state.dirichlet(
  177. alpha=list2numpy(dist.alpha, float).flatten(), size=size),
  178. 'MultinomialDistribution': lambda dist, size: rand_state.multinomial(
  179. n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(), size=size)
  180. }
  181. sample_shape = {
  182. 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
  183. 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
  184. 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
  185. }
  186. dist_list = numpy_rv_map.keys()
  187. if dist.__class__.__name__ not in dist_list:
  188. return None
  189. samples = numpy_rv_map[dist.__class__.__name__](dist, prod(size))
  190. return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
  191. class SampleJointPymc:
  192. """Returns the sample from pymc3 of the given distribution"""
  193. def __new__(cls, dist, size, seed=None):
  194. return cls._sample_pymc3(dist, size, seed)
  195. @classmethod
  196. def _sample_pymc3(cls, dist, size, seed):
  197. """Sample from PyMC3."""
  198. import pymc3
  199. pymc3_rv_map = {
  200. 'MultivariateNormalDistribution': lambda dist:
  201. pymc3.MvNormal('X', mu=matrix2numpy(dist.mu, float).flatten(),
  202. cov=matrix2numpy(dist.sigma, float), shape=(1, dist.mu.shape[0])),
  203. 'MultivariateBetaDistribution': lambda dist:
  204. pymc3.Dirichlet('X', a=list2numpy(dist.alpha, float).flatten()),
  205. 'MultinomialDistribution': lambda dist:
  206. pymc3.Multinomial('X', n=int(dist.n),
  207. p=list2numpy(dist.p, float).flatten(), shape=(1, len(dist.p)))
  208. }
  209. sample_shape = {
  210. 'MultivariateNormalDistribution': lambda dist: matrix2numpy(dist.mu).flatten().shape,
  211. 'MultivariateBetaDistribution': lambda dist: list2numpy(dist.alpha).flatten().shape,
  212. 'MultinomialDistribution': lambda dist: list2numpy(dist.p).flatten().shape
  213. }
  214. dist_list = pymc3_rv_map.keys()
  215. if dist.__class__.__name__ not in dist_list:
  216. return None
  217. import logging
  218. logging.getLogger("pymc3").setLevel(logging.ERROR)
  219. with pymc3.Model():
  220. pymc3_rv_map[dist.__class__.__name__](dist)
  221. samples = pymc3.sample(draws=prod(size), chains=1, progressbar=False, random_seed=seed, return_inferencedata=False, compute_convergence_checks=False)[:]['X']
  222. return samples.reshape(size + sample_shape[dist.__class__.__name__](dist))
# Maps each ``library`` name accepted by ``JointDistribution.sample`` to the
# helper class that draws samples with that backend. Each helper returns
# ``None`` when it does not support the given distribution.
_get_sample_class_jrv = {
    'scipy': SampleJointScipy,
    'pymc3': SampleJointPymc,
    'numpy': SampleJointNumpy
}
  228. class JointDistribution(Distribution, NamedArgsMixin):
  229. """
  230. Represented by the random variables part of the joint distribution.
  231. Contains methods for PDF, CDF, sampling, marginal densities, etc.
  232. """
  233. _argnames = ('pdf', )
  234. def __new__(cls, *args):
  235. args = list(map(sympify, args))
  236. for i in range(len(args)):
  237. if isinstance(args[i], list):
  238. args[i] = ImmutableMatrix(args[i])
  239. return Basic.__new__(cls, *args)
  240. @property
  241. def domain(self):
  242. return ProductDomain(self.symbols)
  243. @property
  244. def pdf(self):
  245. return self.density.args[1]
  246. def cdf(self, other):
  247. if not isinstance(other, dict):
  248. raise ValueError("%s should be of type dict, got %s"%(other, type(other)))
  249. rvs = other.keys()
  250. _set = self.domain.set.sets
  251. expr = self.pdf(tuple(i.args[0] for i in self.symbols))
  252. for i in range(len(other)):
  253. if rvs[i].is_Continuous:
  254. density = Integral(expr, (rvs[i], _set[i].inf,
  255. other[rvs[i]]))
  256. elif rvs[i].is_Discrete:
  257. density = Sum(expr, (rvs[i], _set[i].inf,
  258. other[rvs[i]]))
  259. return density
  260. def sample(self, size=(), library='scipy', seed=None):
  261. """ A random realization from the distribution """
  262. libraries = ['scipy', 'numpy', 'pymc3']
  263. if library not in libraries:
  264. raise NotImplementedError("Sampling from %s is not supported yet."
  265. % str(library))
  266. if not import_module(library):
  267. raise ValueError("Failed to import %s" % library)
  268. samps = _get_sample_class_jrv[library](self, size, seed=seed)
  269. if samps is not None:
  270. return samps
  271. raise NotImplementedError(
  272. "Sampling for %s is not currently implemented from %s"
  273. % (self.__class__.__name__, library)
  274. )
  275. def __call__(self, *args):
  276. return self.pdf(*args)
  277. class JointRandomSymbol(RandomSymbol):
  278. """
  279. Representation of random symbols with joint probability distributions
  280. to allow indexing."
  281. """
  282. def __getitem__(self, key):
  283. if isinstance(self.pspace, JointPSpace):
  284. if (self.pspace.component_count <= key) == True:
  285. raise ValueError("Index keys for %s can only up to %s." %
  286. (self.name, self.pspace.component_count - 1))
  287. return Indexed(self, key)
  288. class MarginalDistribution(Distribution):
  289. """
  290. Represents the marginal distribution of a joint probability space.
  291. Initialised using a probability distribution and random variables(or
  292. their indexed components) which should be a part of the resultant
  293. distribution.
  294. """
  295. def __new__(cls, dist, *rvs):
  296. if len(rvs) == 1 and iterable(rvs[0]):
  297. rvs = tuple(rvs[0])
  298. if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs):
  299. raise ValueError(filldedent('''Marginal distribution can be
  300. intitialised only in terms of random variables or indexed random
  301. variables'''))
  302. rvs = Tuple.fromiter(rv for rv in rvs)
  303. if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0:
  304. return dist
  305. return Basic.__new__(cls, dist, rvs)
  306. def check(self):
  307. pass
  308. @property
  309. def set(self):
  310. rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)]
  311. return ProductSet(*[rv.pspace.set for rv in rvs])
  312. @property
  313. def symbols(self):
  314. rvs = self.args[1]
  315. return {rv.pspace.symbol for rv in rvs}
  316. def pdf(self, *x):
  317. expr, rvs = self.args[0], self.args[1]
  318. marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
  319. if isinstance(expr, JointDistribution):
  320. count = len(expr.domain.args)
  321. x = Dummy('x', real=True)
  322. syms = tuple(Indexed(x, i) for i in count)
  323. expr = expr.pdf(syms)
  324. else:
  325. syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs)
  326. return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)
  327. def compute_pdf(self, expr, rvs):
  328. for rv in rvs:
  329. lpdf = 1
  330. if isinstance(rv, RandomSymbol):
  331. lpdf = rv.pspace.pdf
  332. expr = self.marginalise_out(expr*lpdf, rv)
  333. return expr
  334. def marginalise_out(self, expr, rv):
  335. from sympy.concrete.summations import Sum
  336. if isinstance(rv, RandomSymbol):
  337. dom = rv.pspace.set
  338. elif isinstance(rv, Indexed):
  339. dom = rv.base.component_domain(
  340. rv.pspace.component_domain(rv.args[1]))
  341. expr = expr.xreplace({rv: rv.pspace.symbol})
  342. if rv.pspace.is_Continuous:
  343. #TODO: Modify to support integration
  344. #for all kinds of sets.
  345. expr = Integral(expr, (rv.pspace.symbol, dom))
  346. elif rv.pspace.is_Discrete:
  347. #incorporate this into `Sum`/`summation`
  348. if dom in (S.Integers, S.Naturals, S.Naturals0):
  349. dom = (dom.inf, dom.sup)
  350. expr = Sum(expr, (rv.pspace.symbol, dom))
  351. return expr
  352. def __call__(self, *args):
  353. return self.pdf(*args)