123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423 |
- """
- Joint Random Variables Module
- See Also
- ========
- sympy.stats.rv
- sympy.stats.frv
- sympy.stats.crv
- sympy.stats.drv
- """
- from sympy.core.basic import Basic
- from sympy.core.function import Lambda
- from sympy.core.mul import prod
- from sympy.core.singleton import S
- from sympy.core.symbol import (Dummy, Symbol)
- from sympy.core.sympify import sympify
- from sympy.sets.sets import ProductSet
- from sympy.tensor.indexed import Indexed
- from sympy.concrete.products import Product
- from sympy.concrete.summations import Sum, summation
- from sympy.core.containers import Tuple
- from sympy.integrals.integrals import Integral, integrate
- from sympy.matrices import ImmutableMatrix, matrix2numpy, list2numpy
- from sympy.stats.crv import SingleContinuousDistribution, SingleContinuousPSpace
- from sympy.stats.drv import SingleDiscreteDistribution, SingleDiscretePSpace
- from sympy.stats.rv import (ProductPSpace, NamedArgsMixin, Distribution,
- ProductDomain, RandomSymbol, random_symbols,
- SingleDomain, _symbol_converter)
- from sympy.utilities.iterables import iterable
- from sympy.utilities.misc import filldedent
- from sympy.external import import_module
- # __all__ = ['marginal_distribution']
class JointPSpace(ProductPSpace):
    """
    Probability space of a joint random variable.

    The space is stored as ``(symbol, distribution)``: a single symbol
    stands for the whole random vector and its components are addressed
    by indexing that symbol.
    """

    def __new__(cls, sym, dist):
        # Univariate distributions are delegated to their dedicated spaces.
        if isinstance(dist, SingleContinuousDistribution):
            return SingleContinuousPSpace(sym, dist)
        if isinstance(dist, SingleDiscreteDistribution):
            return SingleDiscretePSpace(sym, dist)
        return Basic.__new__(cls, _symbol_converter(sym), dist)

    @property
    def set(self):
        """Set of the space's domain."""
        return self.domain.set

    @property
    def symbol(self):
        """First stored argument: the symbol naming the joint variable."""
        return self.args[0]

    @property
    def distribution(self):
        """Second stored argument: the underlying distribution."""
        return self.args[1]

    @property
    def value(self):
        """The joint random symbol bound to this space."""
        return JointRandomSymbol(self.symbol, self)

    @property
    def component_count(self):
        """Number of components, derived from the distribution's set."""
        dist_set = self.distribution.set
        if isinstance(dist_set, ProductSet):
            return S(len(dist_set.args))
        if isinstance(dist_set, Product):
            # Upper bound of the first limit of the symbolic product.
            return dist_set.limits[0][-1]
        return S.One

    @property
    def pdf(self):
        """Distribution evaluated at the indexed components of the symbol."""
        components = []
        for position in range(self.component_count):
            components.append(Indexed(self.symbol, position))
        return self.distribution(*components)

    @property
    def domain(self):
        """Domain of the space.

        A product of the domains of the random symbols found in the
        distribution, or a single domain over the distribution's set when
        there are none.
        """
        nested = random_symbols(self.distribution)
        if nested:
            return ProductDomain(*[rv.pspace.domain for rv in nested])
        return SingleDomain(self.symbol, self.distribution.set)

    def component_domain(self, index):
        """Return the ``index``-th argument of the space's set."""
        return self.set.args[index]

    def marginal_distribution(self, *indices):
        """Return the marginal density over the components in ``indices``.

        Every other component is integrated (continuous distribution) or
        summed (discrete distribution) over its set.

        Raises ``ValueError`` when the component count is symbolic.
        """
        dist = self.distribution
        total = self.component_count
        if total.atoms(Symbol):
            raise ValueError("Marginal distributions cannot be computed "
                             "for symbolic dimensions. It is a work under progress.")
        indexed = [Indexed(self.symbol, k) for k in range(total)]
        # Plain symbols stand in for the indexed components while
        # integrating/summing; they are mapped back at the end.
        plain = [Symbol(str(item)) for item in indexed]
        restore = dict(zip(plain, indexed))
        kept = tuple(Symbol(str(Indexed(self.symbol, k))) for k in indices)
        limits = [[name] for name in plain if name not in kept]
        slot = 0
        for k in range(total):
            if k in indices:
                continue
            limits[slot].append(dist.set.args[k])
            limits[slot] = tuple(limits[slot])
            slot += 1
        if dist.is_Continuous:
            marginal = Lambda(kept, integrate(dist(*plain), *limits))
        elif dist.is_Discrete:
            marginal = Lambda(kept, summation(dist(*plain), *limits))
        return marginal.xreplace(restore)

    def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
        """Return an unevaluated ``Integral`` for the expectation of ``expr``.

        ``rvs`` defaults to every component of the joint variable.  When
        none of the components is in ``rvs``, ``expr`` is returned as is.

        Raises ``NotImplementedError`` if the unindexed joint random symbol
        still appears in the expression after substitution.
        """
        components = tuple(self.value[k] for k in range(self.component_count))
        if not rvs:
            rvs = components
        if all(comp not in rvs for comp in components):
            return expr
        expr = expr*self.pdf
        for rv in rvs:
            if isinstance(rv, Indexed):
                replacement = Indexed(str(rv.base), rv.args[1])
            elif isinstance(rv, RandomSymbol):
                replacement = rv.symbol
            else:
                continue
            expr = expr.xreplace({rv: replacement})
        if self.value in random_symbols(expr):
            raise NotImplementedError(filldedent('''
            Expectations of expression with unindexed joint random symbols
            cannot be calculated yet.'''))
        component_sets = self.distribution.set.args
        limits = tuple(
            (Indexed(str(comp.base), comp.args[1]), component_sets[comp.args[1]])
            for comp in components)
        return Integral(expr, *limits)

    def where(self, condition):
        raise NotImplementedError()

    def compute_density(self, expr):
        raise NotImplementedError()

    def sample(self, size=(), library='scipy', seed=None):
        """
        Internal sample method

        Returns dictionary mapping RandomSymbol to realization value.
        """
        key = RandomSymbol(self.symbol, self)
        return {key: self.distribution.sample(size, library=library, seed=seed)}

    def probability(self, condition):
        raise NotImplementedError()
class SampleJointScipy:
    """Sampler for joint distributions backed by SciPy.

    Calling the class does not build an instance: ``__new__`` returns the
    drawn samples directly, or ``None`` when the distribution's class name
    is not one of those handled here.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_scipy(dist, size, seed)

    @classmethod
    def _sample_scipy(cls, dist, size, seed):
        """Draw samples of ``dist`` with ``scipy.stats``.

        ``seed`` of ``None`` or an ``int`` is turned into a fresh NumPy
        generator; any other value is handed to SciPy as the random state
        unchanged.  The result is reshaped to ``size`` followed by the
        flattened shape of the distribution's parameter vector.
        """
        import numpy
        if seed is None or isinstance(seed, int):
            rand_state = numpy.random.default_rng(seed=seed)
        else:
            rand_state = seed
        from scipy import stats as scipy_stats

        def draw_normal(dist, size):
            return scipy_stats.multivariate_normal.rvs(
                mean=matrix2numpy(dist.mu).flatten(),
                cov=matrix2numpy(dist.sigma), size=size,
                random_state=rand_state)

        def draw_dirichlet(dist, size):
            return scipy_stats.dirichlet.rvs(
                alpha=list2numpy(dist.alpha, float).flatten(), size=size,
                random_state=rand_state)

        def draw_multinomial(dist, size):
            return scipy_stats.multinomial.rvs(
                n=int(dist.n), p=list2numpy(dist.p, float).flatten(),
                size=size, random_state=rand_state)

        def shape_normal(dist):
            return matrix2numpy(dist.mu).flatten().shape

        def shape_dirichlet(dist):
            return list2numpy(dist.alpha).flatten().shape

        def shape_multinomial(dist):
            return list2numpy(dist.p).flatten().shape

        # Dispatch on the distribution's class name: (sampler, event shape).
        handlers = {
            'MultivariateNormalDistribution': (draw_normal, shape_normal),
            'MultivariateBetaDistribution': (draw_dirichlet, shape_dirichlet),
            'MultinomialDistribution': (draw_multinomial, shape_multinomial),
        }

        dist_name = dist.__class__.__name__
        if dist_name not in handlers:
            return None

        draw, event_shape = handlers[dist_name]
        samples = draw(dist, size)
        return samples.reshape(size + event_shape(dist))
class SampleJointNumpy:
    """Sampler for joint distributions backed by NumPy.

    Calling the class does not build an instance: ``__new__`` returns the
    drawn samples directly, or ``None`` when the distribution's class name
    is not one of those handled here.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_numpy(dist, size, seed)

    @classmethod
    def _sample_numpy(cls, dist, size, seed):
        """Draw samples of ``dist`` with a NumPy random generator.

        ``seed`` of ``None`` or an ``int`` is turned into a fresh generator;
        any other value is used as the generator itself.  ``prod(size)``
        flat draws are requested and then reshaped to ``size`` followed by
        the flattened shape of the distribution's parameter vector.
        """
        import numpy
        if seed is None or isinstance(seed, int):
            rand_state = numpy.random.default_rng(seed=seed)
        else:
            rand_state = seed

        def draw_normal(dist, size):
            return rand_state.multivariate_normal(
                mean=matrix2numpy(dist.mu, float).flatten(),
                cov=matrix2numpy(dist.sigma, float), size=size)

        def draw_dirichlet(dist, size):
            return rand_state.dirichlet(
                alpha=list2numpy(dist.alpha, float).flatten(), size=size)

        def draw_multinomial(dist, size):
            return rand_state.multinomial(
                n=int(dist.n), pvals=list2numpy(dist.p, float).flatten(),
                size=size)

        def shape_normal(dist):
            return matrix2numpy(dist.mu).flatten().shape

        def shape_dirichlet(dist):
            return list2numpy(dist.alpha).flatten().shape

        def shape_multinomial(dist):
            return list2numpy(dist.p).flatten().shape

        # Dispatch on the distribution's class name: (sampler, event shape).
        handlers = {
            'MultivariateNormalDistribution': (draw_normal, shape_normal),
            'MultivariateBetaDistribution': (draw_dirichlet, shape_dirichlet),
            'MultinomialDistribution': (draw_multinomial, shape_multinomial),
        }

        dist_name = dist.__class__.__name__
        if dist_name not in handlers:
            return None

        draw, event_shape = handlers[dist_name]
        samples = draw(dist, prod(size))
        return samples.reshape(size + event_shape(dist))
class SampleJointPymc:
    """Sampler for joint distributions backed by PyMC3.

    Calling the class does not build an instance: ``__new__`` returns the
    drawn samples directly, or ``None`` when the distribution's class name
    is not one of those handled here.
    """

    def __new__(cls, dist, size, seed=None):
        return cls._sample_pymc3(dist, size, seed)

    @classmethod
    def _sample_pymc3(cls, dist, size, seed):
        """Draw samples of ``dist`` by running a single PyMC3 chain.

        A model containing one variable named ``'X'`` is built, sampled
        for ``prod(size)`` draws, and the trace of ``'X'`` is reshaped to
        ``size`` followed by the flattened shape of the distribution's
        parameter vector.
        """
        import pymc3

        def declare_normal(dist):
            return pymc3.MvNormal(
                'X', mu=matrix2numpy(dist.mu, float).flatten(),
                cov=matrix2numpy(dist.sigma, float),
                shape=(1, dist.mu.shape[0]))

        def declare_dirichlet(dist):
            return pymc3.Dirichlet(
                'X', a=list2numpy(dist.alpha, float).flatten())

        def declare_multinomial(dist):
            return pymc3.Multinomial(
                'X', n=int(dist.n),
                p=list2numpy(dist.p, float).flatten(),
                shape=(1, len(dist.p)))

        def shape_normal(dist):
            return matrix2numpy(dist.mu).flatten().shape

        def shape_dirichlet(dist):
            return list2numpy(dist.alpha).flatten().shape

        def shape_multinomial(dist):
            return list2numpy(dist.p).flatten().shape

        # Dispatch on the distribution's class name:
        # (model-variable factory, event shape).
        handlers = {
            'MultivariateNormalDistribution': (declare_normal, shape_normal),
            'MultivariateBetaDistribution': (declare_dirichlet, shape_dirichlet),
            'MultinomialDistribution': (declare_multinomial, shape_multinomial),
        }

        dist_name = dist.__class__.__name__
        if dist_name not in handlers:
            return None

        declare, event_shape = handlers[dist_name]

        import logging
        # Only errors from the "pymc3" logger are let through.
        logging.getLogger("pymc3").setLevel(logging.ERROR)
        with pymc3.Model():
            declare(dist)
            trace = pymc3.sample(
                draws=prod(size), chains=1, progressbar=False,
                random_seed=seed, return_inferencedata=False,
                compute_convergence_checks=False)
            samples = trace[:]['X']
        return samples.reshape(size + event_shape(dist))
# Maps the ``library`` name accepted by ``JointDistribution.sample`` to the
# sampler class that implements it.
_get_sample_class_jrv = {
    "scipy": SampleJointScipy,
    "pymc3": SampleJointPymc,
    "numpy": SampleJointNumpy,
}
class JointDistribution(Distribution, NamedArgsMixin):
    """
    Base class for distributions of joint random variables.

    Provides the density (``pdf`` / call syntax), a cumulative distribution
    (``cdf``) and sampling through an external numeric library (``sample``).
    """

    _argnames = ('pdf', )

    def __new__(cls, *args):
        # Sympify every argument first; list arguments are then stored as
        # immutable matrices.
        converted = [sympify(arg) for arg in args]
        converted = [ImmutableMatrix(arg) if isinstance(arg, list) else arg
                     for arg in converted]
        return Basic.__new__(cls, *converted)

    @property
    def domain(self):
        # NOTE(review): ``self.symbols`` is not defined in this class;
        # presumably it is supplied by a subclass -- confirm.
        return ProductDomain(self.symbols)

    @property
    def pdf(self):
        # NOTE(review): ``self.density`` is not defined in this class;
        # presumably it is supplied by a subclass -- confirm.
        return self.density.args[1]

    def cdf(self, other):
        """Return the unevaluated cumulative distribution.

        Parameters
        ==========

        other : dict
            Maps each variable to the upper bound of its integration or
            summation.  The lower bound is the infimum of the matching
            component of the domain's set.

        Returns
        =======

        The density wrapped in one ``Integral`` (continuous variable) or
        ``Sum`` (discrete variable) per entry of ``other``.  For an empty
        dict the density itself is returned.

        Raises
        ======

        ValueError
            If ``other`` is not a dict.
        """
        if not isinstance(other, dict):
            raise ValueError("%s should be of type dict, got %s"%(other, type(other)))
        # A dict view cannot be subscripted, so materialise the keys.
        rvs = list(other)
        _set = self.domain.set.sets
        # Accumulate so that every variable is integrated or summed, not
        # only the last one.
        density = self.pdf(tuple(i.args[0] for i in self.symbols))
        # NOTE(review): components are matched by position, so ``other`` is
        # assumed to list the variables in the same order as the domain's
        # sets -- confirm against callers.
        for i, rv in enumerate(rvs):
            if rv.is_Continuous:
                density = Integral(density, (rv, _set[i].inf, other[rv]))
            elif rv.is_Discrete:
                density = Sum(density, (rv, _set[i].inf, other[rv]))
        return density

    def sample(self, size=(), library='scipy', seed=None):
        """Return a random realization from the distribution.

        Parameters
        ==========

        size : tuple
            Leading shape of the returned samples.
        library : str
            One of ``'scipy'``, ``'numpy'`` or ``'pymc3'``.
        seed
            Forwarded unchanged to the library-specific sampler.

        Raises
        ======

        NotImplementedError
            If ``library`` is not supported, or the chosen library has no
            sampler for this distribution.
        ValueError
            If ``library`` cannot be imported.
        """
        libraries = ['scipy', 'numpy', 'pymc3']
        if library not in libraries:
            raise NotImplementedError("Sampling from %s is not supported yet."
                                        % str(library))
        if not import_module(library):
            raise ValueError("Failed to import %s" % library)

        samps = _get_sample_class_jrv[library](self, size, seed=seed)
        # The sampler classes return None for distributions they do not
        # handle.
        if samps is not None:
            return samps
        raise NotImplementedError(
            "Sampling for %s is not currently implemented from %s"
            % (self.__class__.__name__, library)
        )

    def __call__(self, *args):
        return self.pdf(*args)
class JointRandomSymbol(RandomSymbol):
    """
    Random symbol with a joint probability distribution; indexing it
    yields its individual components.
    """

    def __getitem__(self, key):
        space = self.pspace
        if not isinstance(space, JointPSpace):
            # Indexing is only defined for joint probability spaces.
            return None
        # The comparison may stay symbolic, so only a definite True rejects
        # the key.
        if (space.component_count <= key) == True:
            raise ValueError("Index keys for %s can only up to %s." %
                (self.name, space.component_count - 1))
        return Indexed(self, key)
class MarginalDistribution(Distribution):
    """
    Represents the marginal distribution of a joint probability space.

    Initialised using a probability distribution and the random variables
    (or their indexed components) that should remain in the resulting
    distribution.
    """

    def __new__(cls, dist, *rvs):
        # Accept both ``MarginalDistribution(dist, a, b)`` and
        # ``MarginalDistribution(dist, [a, b])``.
        if len(rvs) == 1 and iterable(rvs[0]):
            rvs = tuple(rvs[0])
        if not all(isinstance(rv, (Indexed, RandomSymbol)) for rv in rvs):
            raise ValueError(filldedent('''
                Marginal distribution can be
                intitialised only in terms of random variables or indexed random
                variables'''))
        rvs = Tuple.fromiter(rv for rv in rvs)
        # Nothing to marginalise: hand the expression back unchanged.
        if not isinstance(dist, JointDistribution) and len(random_symbols(dist)) == 0:
            return dist
        return Basic.__new__(cls, dist, rvs)

    def check(self):
        pass

    @property
    def set(self):
        """Product of the sets of the retained (non-indexed) random symbols."""
        rvs = [i for i in self.args[1] if isinstance(i, RandomSymbol)]
        return ProductSet(*[rv.pspace.set for rv in rvs])

    @property
    def symbols(self):
        """Symbols of the probability spaces of the retained variables."""
        rvs = self.args[1]
        return {rv.pspace.symbol for rv in rvs}

    def pdf(self, *x):
        """Evaluate the marginal density at ``x``.

        Every random symbol of the stored expression that is not among the
        retained variables is marginalised out.
        """
        expr, rvs = self.args[0], self.args[1]
        marginalise_out = [i for i in random_symbols(expr) if i not in rvs]
        if isinstance(expr, JointDistribution):
            count = len(expr.domain.args)
            # A separate name keeps the call arguments ``x`` intact, and
            # ``range`` is needed because ``count`` is a plain int.
            base = Dummy('x', real=True)
            syms = tuple(Indexed(base, i) for i in range(count))
            expr = expr.pdf(syms)
        else:
            syms = tuple(rv.pspace.symbol if isinstance(rv, RandomSymbol) else rv.args[0] for rv in rvs)
        return Lambda(syms, self.compute_pdf(expr, marginalise_out))(*x)

    def compute_pdf(self, expr, rvs):
        """Marginalise each of ``rvs`` out of ``expr`` in turn."""
        for rv in rvs:
            lpdf = 1
            if isinstance(rv, RandomSymbol):
                lpdf = rv.pspace.pdf
            expr = self.marginalise_out(expr*lpdf, rv)
        return expr

    def marginalise_out(self, expr, rv):
        """Integrate or sum ``rv`` out of ``expr`` over its domain.

        Returns an unevaluated ``Integral`` for a continuous space and an
        unevaluated ``Sum`` for a discrete one.
        """
        if isinstance(rv, RandomSymbol):
            dom = rv.pspace.set
        elif isinstance(rv, Indexed):
            # NOTE(review): ``component_domain`` is applied twice here and
            # ``rv.base`` is assumed to provide it -- confirm intent.
            dom = rv.base.component_domain(
                rv.pspace.component_domain(rv.args[1]))
        expr = expr.xreplace({rv: rv.pspace.symbol})
        if rv.pspace.is_Continuous:
            #TODO: Modify to support integration
            #for all kinds of sets.
            expr = Integral(expr, (rv.pspace.symbol, dom))
        elif rv.pspace.is_Discrete:
            #incorporate this into `Sum`/`summation`
            if dom in (S.Integers, S.Naturals, S.Naturals0):
                dom = (dom.inf, dom.sup)
            expr = Sum(expr, (rv.pspace.symbol, dom))
        return expr

    def __call__(self, *args):
        return self.pdf(*args)
|