rv.py 54 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
7177817791780178117821783178417851786178717881789179017911792179317941795
"""
Main Random Variables Module

Defines abstract random variable type.
Contains interfaces for probability space object (PSpace) as well as standard
operators, P, E, sample, density, where, quantile

See Also
========

sympy.stats.crv
sympy.stats.frv
sympy.stats.rv_interface
"""
  12. from functools import singledispatch
  13. from typing import Tuple as tTuple
  14. from sympy.core.add import Add
  15. from sympy.core.basic import Basic
  16. from sympy.core.containers import Tuple
  17. from sympy.core.expr import Expr
  18. from sympy.core.function import (Function, Lambda)
  19. from sympy.core.logic import fuzzy_and
  20. from sympy.core.mul import (Mul, prod)
  21. from sympy.core.relational import (Eq, Ne)
  22. from sympy.core.singleton import S
  23. from sympy.core.symbol import (Dummy, Symbol)
  24. from sympy.core.sympify import sympify
  25. from sympy.functions.special.delta_functions import DiracDelta
  26. from sympy.functions.special.tensor_functions import KroneckerDelta
  27. from sympy.logic.boolalg import (And, Or)
  28. from sympy.matrices.expressions.matexpr import MatrixSymbol
  29. from sympy.tensor.indexed import Indexed
  30. from sympy.utilities.lambdify import lambdify
  31. from sympy.core.relational import Relational
  32. from sympy.core.sympify import _sympify
  33. from sympy.sets.sets import FiniteSet, ProductSet, Intersection
  34. from sympy.solvers.solveset import solveset
  35. from sympy.external import import_module
  36. from sympy.utilities.misc import filldedent
  37. from sympy.utilities.decorator import doctest_depends_on
  38. from sympy.utilities.exceptions import sympy_deprecation_warning
  39. from sympy.utilities.iterables import iterable
  40. import warnings
# Placeholder symbol shared at module level. ``Density.doit`` uses it as the
# bound variable of the ``Lambda`` it returns for an expression that
# contains no random symbols.
x = Symbol('x')
  42. @singledispatch
  43. def is_random(x):
  44. return False
  45. @is_random.register(Basic)
  46. def _(x):
  47. atoms = x.free_symbols
  48. return any(is_random(i) for i in atoms)
  49. class RandomDomain(Basic):
  50. """
  51. Represents a set of variables and the values which they can take.
  52. See Also
  53. ========
  54. sympy.stats.crv.ContinuousDomain
  55. sympy.stats.frv.FiniteDomain
  56. """
  57. is_ProductDomain = False
  58. is_Finite = False
  59. is_Continuous = False
  60. is_Discrete = False
  61. def __new__(cls, symbols, *args):
  62. symbols = FiniteSet(*symbols)
  63. return Basic.__new__(cls, symbols, *args)
  64. @property
  65. def symbols(self):
  66. return self.args[0]
  67. @property
  68. def set(self):
  69. return self.args[1]
  70. def __contains__(self, other):
  71. raise NotImplementedError()
  72. def compute_expectation(self, expr):
  73. raise NotImplementedError()
  74. class SingleDomain(RandomDomain):
  75. """
  76. A single variable and its domain.
  77. See Also
  78. ========
  79. sympy.stats.crv.SingleContinuousDomain
  80. sympy.stats.frv.SingleFiniteDomain
  81. """
  82. def __new__(cls, symbol, set):
  83. assert symbol.is_Symbol
  84. return Basic.__new__(cls, symbol, set)
  85. @property
  86. def symbol(self):
  87. return self.args[0]
  88. @property
  89. def symbols(self):
  90. return FiniteSet(self.symbol)
  91. def __contains__(self, other):
  92. if len(other) != 1:
  93. return False
  94. sym, val = tuple(other)[0]
  95. return self.symbol == sym and val in self.set
  96. class MatrixDomain(RandomDomain):
  97. """
  98. A Random Matrix variable and its domain.
  99. """
  100. def __new__(cls, symbol, set):
  101. symbol, set = _symbol_converter(symbol), _sympify(set)
  102. return Basic.__new__(cls, symbol, set)
  103. @property
  104. def symbol(self):
  105. return self.args[0]
  106. @property
  107. def symbols(self):
  108. return FiniteSet(self.symbol)
  109. class ConditionalDomain(RandomDomain):
  110. """
  111. A RandomDomain with an attached condition.
  112. See Also
  113. ========
  114. sympy.stats.crv.ConditionalContinuousDomain
  115. sympy.stats.frv.ConditionalFiniteDomain
  116. """
  117. def __new__(cls, fulldomain, condition):
  118. condition = condition.xreplace({rs: rs.symbol
  119. for rs in random_symbols(condition)})
  120. return Basic.__new__(cls, fulldomain, condition)
  121. @property
  122. def symbols(self):
  123. return self.fulldomain.symbols
  124. @property
  125. def fulldomain(self):
  126. return self.args[0]
  127. @property
  128. def condition(self):
  129. return self.args[1]
  130. @property
  131. def set(self):
  132. raise NotImplementedError("Set of Conditional Domain not Implemented")
  133. def as_boolean(self):
  134. return And(self.fulldomain.as_boolean(), self.condition)
  135. class PSpace(Basic):
  136. """
  137. A Probability Space.
  138. Explanation
  139. ===========
  140. Probability Spaces encode processes that equal different values
  141. probabilistically. These underly Random Symbols which occur in SymPy
  142. expressions and contain the mechanics to evaluate statistical statements.
  143. See Also
  144. ========
  145. sympy.stats.crv.ContinuousPSpace
  146. sympy.stats.frv.FinitePSpace
  147. """
  148. is_Finite = None # type: bool
  149. is_Continuous = None # type: bool
  150. is_Discrete = None # type: bool
  151. is_real = None # type: bool
  152. @property
  153. def domain(self):
  154. return self.args[0]
  155. @property
  156. def density(self):
  157. return self.args[1]
  158. @property
  159. def values(self):
  160. return frozenset(RandomSymbol(sym, self) for sym in self.symbols)
  161. @property
  162. def symbols(self):
  163. return self.domain.symbols
  164. def where(self, condition):
  165. raise NotImplementedError()
  166. def compute_density(self, expr):
  167. raise NotImplementedError()
  168. def sample(self, size=(), library='scipy', seed=None):
  169. raise NotImplementedError()
  170. def probability(self, condition):
  171. raise NotImplementedError()
  172. def compute_expectation(self, expr):
  173. raise NotImplementedError()
  174. class SinglePSpace(PSpace):
  175. """
  176. Represents the probabilities of a set of random events that can be
  177. attributed to a single variable/symbol.
  178. """
  179. def __new__(cls, s, distribution):
  180. s = _symbol_converter(s)
  181. return Basic.__new__(cls, s, distribution)
  182. @property
  183. def value(self):
  184. return RandomSymbol(self.symbol, self)
  185. @property
  186. def symbol(self):
  187. return self.args[0]
  188. @property
  189. def distribution(self):
  190. return self.args[1]
  191. @property
  192. def pdf(self):
  193. return self.distribution.pdf(self.symbol)
  194. class RandomSymbol(Expr):
  195. """
  196. Random Symbols represent ProbabilitySpaces in SymPy Expressions.
  197. In principle they can take on any value that their symbol can take on
  198. within the associated PSpace with probability determined by the PSpace
  199. Density.
  200. Explanation
  201. ===========
  202. Random Symbols contain pspace and symbol properties.
  203. The pspace property points to the represented Probability Space
  204. The symbol is a standard SymPy Symbol that is used in that probability space
  205. for example in defining a density.
  206. You can form normal SymPy expressions using RandomSymbols and operate on
  207. those expressions with the Functions
  208. E - Expectation of a random expression
  209. P - Probability of a condition
  210. density - Probability Density of an expression
  211. given - A new random expression (with new random symbols) given a condition
  212. An object of the RandomSymbol type should almost never be created by the
  213. user. They tend to be created instead by the PSpace class's value method.
  214. Traditionally a user doesn't even do this but instead calls one of the
  215. convenience functions Normal, Exponential, Coin, Die, FiniteRV, etc....
  216. """
  217. def __new__(cls, symbol, pspace=None):
  218. from sympy.stats.joint_rv import JointRandomSymbol
  219. if pspace is None:
  220. # Allow single arg, representing pspace == PSpace()
  221. pspace = PSpace()
  222. symbol = _symbol_converter(symbol)
  223. if not isinstance(pspace, PSpace):
  224. raise TypeError("pspace variable should be of type PSpace")
  225. if cls == JointRandomSymbol and isinstance(pspace, SinglePSpace):
  226. cls = RandomSymbol
  227. return Basic.__new__(cls, symbol, pspace)
  228. is_finite = True
  229. is_symbol = True
  230. is_Atom = True
  231. _diff_wrt = True
  232. pspace = property(lambda self: self.args[1])
  233. symbol = property(lambda self: self.args[0])
  234. name = property(lambda self: self.symbol.name)
  235. def _eval_is_positive(self):
  236. return self.symbol.is_positive
  237. def _eval_is_integer(self):
  238. return self.symbol.is_integer
  239. def _eval_is_real(self):
  240. return self.symbol.is_real or self.pspace.is_real
  241. @property
  242. def is_commutative(self):
  243. return self.symbol.is_commutative
  244. @property
  245. def free_symbols(self):
  246. return {self}
  247. class RandomIndexedSymbol(RandomSymbol):
  248. def __new__(cls, idx_obj, pspace=None):
  249. if pspace is None:
  250. # Allow single arg, representing pspace == PSpace()
  251. pspace = PSpace()
  252. if not isinstance(idx_obj, (Indexed, Function)):
  253. raise TypeError("An Function or Indexed object is expected not %s"%(idx_obj))
  254. return Basic.__new__(cls, idx_obj, pspace)
  255. symbol = property(lambda self: self.args[0])
  256. name = property(lambda self: str(self.args[0]))
  257. @property
  258. def key(self):
  259. if isinstance(self.symbol, Indexed):
  260. return self.symbol.args[1]
  261. elif isinstance(self.symbol, Function):
  262. return self.symbol.args[0]
  263. @property
  264. def free_symbols(self):
  265. if self.key.free_symbols:
  266. free_syms = self.key.free_symbols
  267. free_syms.add(self)
  268. return free_syms
  269. return {self}
  270. @property
  271. def pspace(self):
  272. return self.args[1]
  273. class RandomMatrixSymbol(RandomSymbol, MatrixSymbol): # type: ignore
  274. def __new__(cls, symbol, n, m, pspace=None):
  275. n, m = _sympify(n), _sympify(m)
  276. symbol = _symbol_converter(symbol)
  277. if pspace is None:
  278. # Allow single arg, representing pspace == PSpace()
  279. pspace = PSpace()
  280. return Basic.__new__(cls, symbol, n, m, pspace)
  281. symbol = property(lambda self: self.args[0])
  282. pspace = property(lambda self: self.args[3])
  283. class ProductPSpace(PSpace):
  284. """
  285. Abstract class for representing probability spaces with multiple random
  286. variables.
  287. See Also
  288. ========
  289. sympy.stats.rv.IndependentProductPSpace
  290. sympy.stats.joint_rv.JointPSpace
  291. """
  292. pass
  293. class IndependentProductPSpace(ProductPSpace):
  294. """
  295. A probability space resulting from the merger of two independent probability
  296. spaces.
  297. Often created using the function, pspace.
  298. """
  299. def __new__(cls, *spaces):
  300. rs_space_dict = {}
  301. for space in spaces:
  302. for value in space.values:
  303. rs_space_dict[value] = space
  304. symbols = FiniteSet(*[val.symbol for val in rs_space_dict.keys()])
  305. # Overlapping symbols
  306. from sympy.stats.joint_rv import MarginalDistribution
  307. from sympy.stats.compound_rv import CompoundDistribution
  308. if len(symbols) < sum(len(space.symbols) for space in spaces if not
  309. isinstance(space.distribution, (
  310. CompoundDistribution, MarginalDistribution))):
  311. raise ValueError("Overlapping Random Variables")
  312. if all(space.is_Finite for space in spaces):
  313. from sympy.stats.frv import ProductFinitePSpace
  314. cls = ProductFinitePSpace
  315. obj = Basic.__new__(cls, *FiniteSet(*spaces))
  316. return obj
  317. @property
  318. def pdf(self):
  319. p = Mul(*[space.pdf for space in self.spaces])
  320. return p.subs({rv: rv.symbol for rv in self.values})
  321. @property
  322. def rs_space_dict(self):
  323. d = {}
  324. for space in self.spaces:
  325. for value in space.values:
  326. d[value] = space
  327. return d
  328. @property
  329. def symbols(self):
  330. return FiniteSet(*[val.symbol for val in self.rs_space_dict.keys()])
  331. @property
  332. def spaces(self):
  333. return FiniteSet(*self.args)
  334. @property
  335. def values(self):
  336. return sumsets(space.values for space in self.spaces)
  337. def compute_expectation(self, expr, rvs=None, evaluate=False, **kwargs):
  338. rvs = rvs or self.values
  339. rvs = frozenset(rvs)
  340. for space in self.spaces:
  341. expr = space.compute_expectation(expr, rvs & space.values, evaluate=False, **kwargs)
  342. if evaluate and hasattr(expr, 'doit'):
  343. return expr.doit(**kwargs)
  344. return expr
  345. @property
  346. def domain(self):
  347. return ProductDomain(*[space.domain for space in self.spaces])
  348. @property
  349. def density(self):
  350. raise NotImplementedError("Density not available for ProductSpaces")
  351. def sample(self, size=(), library='scipy', seed=None):
  352. return {k: v for space in self.spaces
  353. for k, v in space.sample(size=size, library=library, seed=seed).items()}
  354. def probability(self, condition, **kwargs):
  355. cond_inv = False
  356. if isinstance(condition, Ne):
  357. condition = Eq(condition.args[0], condition.args[1])
  358. cond_inv = True
  359. elif isinstance(condition, And): # they are independent
  360. return Mul(*[self.probability(arg) for arg in condition.args])
  361. elif isinstance(condition, Or): # they are independent
  362. return Add(*[self.probability(arg) for arg in condition.args])
  363. expr = condition.lhs - condition.rhs
  364. rvs = random_symbols(expr)
  365. dens = self.compute_density(expr)
  366. if any(pspace(rv).is_Continuous for rv in rvs):
  367. from sympy.stats.crv import SingleContinuousPSpace
  368. from sympy.stats.crv_types import ContinuousDistributionHandmade
  369. if expr in self.values:
  370. # Marginalize all other random symbols out of the density
  371. randomsymbols = tuple(set(self.values) - frozenset([expr]))
  372. symbols = tuple(rs.symbol for rs in randomsymbols)
  373. pdf = self.domain.integrate(self.pdf, symbols, **kwargs)
  374. return Lambda(expr.symbol, pdf)
  375. dens = ContinuousDistributionHandmade(dens)
  376. z = Dummy('z', real=True)
  377. space = SingleContinuousPSpace(z, dens)
  378. result = space.probability(condition.__class__(space.value, 0))
  379. else:
  380. from sympy.stats.drv import SingleDiscretePSpace
  381. from sympy.stats.drv_types import DiscreteDistributionHandmade
  382. dens = DiscreteDistributionHandmade(dens)
  383. z = Dummy('z', integer=True)
  384. space = SingleDiscretePSpace(z, dens)
  385. result = space.probability(condition.__class__(space.value, 0))
  386. return result if not cond_inv else S.One - result
  387. def compute_density(self, expr, **kwargs):
  388. rvs = random_symbols(expr)
  389. if any(pspace(rv).is_Continuous for rv in rvs):
  390. z = Dummy('z', real=True)
  391. expr = self.compute_expectation(DiracDelta(expr - z),
  392. **kwargs)
  393. else:
  394. z = Dummy('z', integer=True)
  395. expr = self.compute_expectation(KroneckerDelta(expr, z),
  396. **kwargs)
  397. return Lambda(z, expr)
  398. def compute_cdf(self, expr, **kwargs):
  399. raise ValueError("CDF not well defined on multivariate expressions")
  400. def conditional_space(self, condition, normalize=True, **kwargs):
  401. rvs = random_symbols(condition)
  402. condition = condition.xreplace({rv: rv.symbol for rv in self.values})
  403. pspaces = [pspace(rv) for rv in rvs]
  404. if any(ps.is_Continuous for ps in pspaces):
  405. from sympy.stats.crv import (ConditionalContinuousDomain,
  406. ContinuousPSpace)
  407. space = ContinuousPSpace
  408. domain = ConditionalContinuousDomain(self.domain, condition)
  409. elif any(ps.is_Discrete for ps in pspaces):
  410. from sympy.stats.drv import (ConditionalDiscreteDomain,
  411. DiscretePSpace)
  412. space = DiscretePSpace
  413. domain = ConditionalDiscreteDomain(self.domain, condition)
  414. elif all(ps.is_Finite for ps in pspaces):
  415. from sympy.stats.frv import FinitePSpace
  416. return FinitePSpace.conditional_space(self, condition)
  417. if normalize:
  418. replacement = {rv: Dummy(str(rv)) for rv in self.symbols}
  419. norm = domain.compute_expectation(self.pdf, **kwargs)
  420. pdf = self.pdf / norm.xreplace(replacement)
  421. # XXX: Converting symbols from set to tuple. The order matters to
  422. # Lambda though so we shouldn't be starting with a set here...
  423. density = Lambda(tuple(domain.symbols), pdf)
  424. return space(domain, density)
  425. class ProductDomain(RandomDomain):
  426. """
  427. A domain resulting from the merger of two independent domains.
  428. See Also
  429. ========
  430. sympy.stats.crv.ProductContinuousDomain
  431. sympy.stats.frv.ProductFiniteDomain
  432. """
  433. is_ProductDomain = True
  434. def __new__(cls, *domains):
  435. # Flatten any product of products
  436. domains2 = []
  437. for domain in domains:
  438. if not domain.is_ProductDomain:
  439. domains2.append(domain)
  440. else:
  441. domains2.extend(domain.domains)
  442. domains2 = FiniteSet(*domains2)
  443. if all(domain.is_Finite for domain in domains2):
  444. from sympy.stats.frv import ProductFiniteDomain
  445. cls = ProductFiniteDomain
  446. if all(domain.is_Continuous for domain in domains2):
  447. from sympy.stats.crv import ProductContinuousDomain
  448. cls = ProductContinuousDomain
  449. if all(domain.is_Discrete for domain in domains2):
  450. from sympy.stats.drv import ProductDiscreteDomain
  451. cls = ProductDiscreteDomain
  452. return Basic.__new__(cls, *domains2)
  453. @property
  454. def sym_domain_dict(self):
  455. return {symbol: domain for domain in self.domains
  456. for symbol in domain.symbols}
  457. @property
  458. def symbols(self):
  459. return FiniteSet(*[sym for domain in self.domains
  460. for sym in domain.symbols])
  461. @property
  462. def domains(self):
  463. return self.args
  464. @property
  465. def set(self):
  466. return ProductSet(*(domain.set for domain in self.domains))
  467. def __contains__(self, other):
  468. # Split event into each subdomain
  469. for domain in self.domains:
  470. # Collect the parts of this event which associate to this domain
  471. elem = frozenset([item for item in other
  472. if sympify(domain.symbols.contains(item[0]))
  473. is S.true])
  474. # Test this sub-event
  475. if elem not in domain:
  476. return False
  477. # All subevents passed
  478. return True
  479. def as_boolean(self):
  480. return And(*[domain.as_boolean() for domain in self.domains])
  481. def random_symbols(expr):
  482. """
  483. Returns all RandomSymbols within a SymPy Expression.
  484. """
  485. atoms = getattr(expr, 'atoms', None)
  486. if atoms is not None:
  487. comp = lambda rv: rv.symbol.name
  488. l = list(atoms(RandomSymbol))
  489. return sorted(l, key=comp)
  490. else:
  491. return []
  492. def pspace(expr):
  493. """
  494. Returns the underlying Probability Space of a random expression.
  495. For internal use.
  496. Examples
  497. ========
  498. >>> from sympy.stats import pspace, Normal
  499. >>> X = Normal('X', 0, 1)
  500. >>> pspace(2*X + 1) == X.pspace
  501. True
  502. """
  503. expr = sympify(expr)
  504. if isinstance(expr, RandomSymbol) and expr.pspace is not None:
  505. return expr.pspace
  506. if expr.has(RandomMatrixSymbol):
  507. rm = list(expr.atoms(RandomMatrixSymbol))[0]
  508. return rm.pspace
  509. rvs = random_symbols(expr)
  510. if not rvs:
  511. raise ValueError("Expression containing Random Variable expected, not %s" % (expr))
  512. # If only one space present
  513. if all(rv.pspace == rvs[0].pspace for rv in rvs):
  514. return rvs[0].pspace
  515. from sympy.stats.compound_rv import CompoundPSpace
  516. from sympy.stats.stochastic_process import StochasticPSpace
  517. for rv in rvs:
  518. if isinstance(rv.pspace, (CompoundPSpace, StochasticPSpace)):
  519. return rv.pspace
  520. # Otherwise make a product space
  521. return IndependentProductPSpace(*[rv.pspace for rv in rvs])
  522. def sumsets(sets):
  523. """
  524. Union of sets
  525. """
  526. return frozenset().union(*sets)
  527. def rs_swap(a, b):
  528. """
  529. Build a dictionary to swap RandomSymbols based on their underlying symbol.
  530. i.e.
  531. if ``X = ('x', pspace1)``
  532. and ``Y = ('x', pspace2)``
  533. then ``X`` and ``Y`` match and the key, value pair
  534. ``{X:Y}`` will appear in the result
  535. Inputs: collections a and b of random variables which share common symbols
  536. Output: dict mapping RVs in a to RVs in b
  537. """
  538. d = {}
  539. for rsa in a:
  540. d[rsa] = [rsb for rsb in b if rsa.symbol == rsb.symbol][0]
  541. return d
  542. def given(expr, condition=None, **kwargs):
  543. r""" Conditional Random Expression.
  544. Explanation
  545. ===========
  546. From a random expression and a condition on that expression creates a new
  547. probability space from the condition and returns the same expression on that
  548. conditional probability space.
  549. Examples
  550. ========
  551. >>> from sympy.stats import given, density, Die
  552. >>> X = Die('X', 6)
  553. >>> Y = given(X, X > 3)
  554. >>> density(Y).dict
  555. {4: 1/3, 5: 1/3, 6: 1/3}
  556. Following convention, if the condition is a random symbol then that symbol
  557. is considered fixed.
  558. >>> from sympy.stats import Normal
  559. >>> from sympy import pprint
  560. >>> from sympy.abc import z
  561. >>> X = Normal('X', 0, 1)
  562. >>> Y = Normal('Y', 0, 1)
  563. >>> pprint(density(X + Y, Y)(z), use_unicode=False)
  564. 2
  565. -(-Y + z)
  566. -----------
  567. ___ 2
  568. \/ 2 *e
  569. ------------------
  570. ____
  571. 2*\/ pi
  572. """
  573. if not is_random(condition) or pspace_independent(expr, condition):
  574. return expr
  575. if isinstance(condition, RandomSymbol):
  576. condition = Eq(condition, condition.symbol)
  577. condsymbols = random_symbols(condition)
  578. if (isinstance(condition, Eq) and len(condsymbols) == 1 and
  579. not isinstance(pspace(expr).domain, ConditionalDomain)):
  580. rv = tuple(condsymbols)[0]
  581. results = solveset(condition, rv)
  582. if isinstance(results, Intersection) and S.Reals in results.args:
  583. results = list(results.args[1])
  584. sums = 0
  585. for res in results:
  586. temp = expr.subs(rv, res)
  587. if temp == True:
  588. return True
  589. if temp != False:
  590. # XXX: This seems nonsensical but preserves existing behaviour
  591. # after the change that Relational is no longer a subclass of
  592. # Expr. Here expr is sometimes Relational and sometimes Expr
  593. # but we are trying to add them with +=. This needs to be
  594. # fixed somehow.
  595. if sums == 0 and isinstance(expr, Relational):
  596. sums = expr.subs(rv, res)
  597. else:
  598. sums += expr.subs(rv, res)
  599. if sums == 0:
  600. return False
  601. return sums
  602. # Get full probability space of both the expression and the condition
  603. fullspace = pspace(Tuple(expr, condition))
  604. # Build new space given the condition
  605. space = fullspace.conditional_space(condition, **kwargs)
  606. # Dictionary to swap out RandomSymbols in expr with new RandomSymbols
  607. # That point to the new conditional space
  608. swapdict = rs_swap(fullspace.values, space.values)
  609. # Swap random variables in the expression
  610. expr = expr.xreplace(swapdict)
  611. return expr
  612. def expectation(expr, condition=None, numsamples=None, evaluate=True, **kwargs):
  613. """
  614. Returns the expected value of a random expression.
  615. Parameters
  616. ==========
  617. expr : Expr containing RandomSymbols
  618. The expression of which you want to compute the expectation value
  619. given : Expr containing RandomSymbols
  620. A conditional expression. E(X, X>0) is expectation of X given X > 0
  621. numsamples : int
  622. Enables sampling and approximates the expectation with this many samples
  623. evalf : Bool (defaults to True)
  624. If sampling return a number rather than a complex expression
  625. evaluate : Bool (defaults to True)
  626. In case of continuous systems return unevaluated integral
  627. Examples
  628. ========
  629. >>> from sympy.stats import E, Die
  630. >>> X = Die('X', 6)
  631. >>> E(X)
  632. 7/2
  633. >>> E(2*X + 1)
  634. 8
  635. >>> E(X, X > 3) # Expectation of X given that it is above 3
  636. 5
  637. """
  638. if not is_random(expr): # expr isn't random?
  639. return expr
  640. kwargs['numsamples'] = numsamples
  641. from sympy.stats.symbolic_probability import Expectation
  642. if evaluate:
  643. return Expectation(expr, condition).doit(**kwargs)
  644. return Expectation(expr, condition)
  645. def probability(condition, given_condition=None, numsamples=None,
  646. evaluate=True, **kwargs):
  647. """
  648. Probability that a condition is true, optionally given a second condition.
  649. Parameters
  650. ==========
  651. condition : Combination of Relationals containing RandomSymbols
  652. The condition of which you want to compute the probability
  653. given_condition : Combination of Relationals containing RandomSymbols
  654. A conditional expression. P(X > 1, X > 0) is expectation of X > 1
  655. given X > 0
  656. numsamples : int
  657. Enables sampling and approximates the probability with this many samples
  658. evaluate : Bool (defaults to True)
  659. In case of continuous systems return unevaluated integral
  660. Examples
  661. ========
  662. >>> from sympy.stats import P, Die
  663. >>> from sympy import Eq
  664. >>> X, Y = Die('X', 6), Die('Y', 6)
  665. >>> P(X > 3)
  666. 1/2
  667. >>> P(Eq(X, 5), X > 2) # Probability that X == 5 given that X > 2
  668. 1/4
  669. >>> P(X > Y)
  670. 5/12
  671. """
  672. kwargs['numsamples'] = numsamples
  673. from sympy.stats.symbolic_probability import Probability
  674. if evaluate:
  675. return Probability(condition, given_condition).doit(**kwargs)
  676. ### TODO: Remove the user warnings in the future releases
  677. message = ("Since version 1.7, using `evaluate=False` returns `Probability` "
  678. "object. If you want unevaluated Integral/Sum use "
  679. "`P(condition, given_condition, evaluate=False).rewrite(Integral)`")
  680. warnings.warn(filldedent(message))
  681. return Probability(condition, given_condition)
  682. class Density(Basic):
  683. expr = property(lambda self: self.args[0])
  684. def __new__(cls, expr, condition = None):
  685. expr = _sympify(expr)
  686. if condition is None:
  687. obj = Basic.__new__(cls, expr)
  688. else:
  689. condition = _sympify(condition)
  690. obj = Basic.__new__(cls, expr, condition)
  691. return obj
  692. @property
  693. def condition(self):
  694. if len(self.args) > 1:
  695. return self.args[1]
  696. else:
  697. return None
  698. def doit(self, evaluate=True, **kwargs):
  699. from sympy.stats.random_matrix import RandomMatrixPSpace
  700. from sympy.stats.joint_rv import JointPSpace
  701. from sympy.stats.matrix_distributions import MatrixPSpace
  702. from sympy.stats.compound_rv import CompoundPSpace
  703. from sympy.stats.frv import SingleFiniteDistribution
  704. expr, condition = self.expr, self.condition
  705. if isinstance(expr, SingleFiniteDistribution):
  706. return expr.dict
  707. if condition is not None:
  708. # Recompute on new conditional expr
  709. expr = given(expr, condition, **kwargs)
  710. if not random_symbols(expr):
  711. return Lambda(x, DiracDelta(x - expr))
  712. if isinstance(expr, RandomSymbol):
  713. if isinstance(expr.pspace, (SinglePSpace, JointPSpace, MatrixPSpace)) and \
  714. hasattr(expr.pspace, 'distribution'):
  715. return expr.pspace.distribution
  716. elif isinstance(expr.pspace, RandomMatrixPSpace):
  717. return expr.pspace.model
  718. if isinstance(pspace(expr), CompoundPSpace):
  719. kwargs['compound_evaluate'] = evaluate
  720. result = pspace(expr).compute_density(expr, **kwargs)
  721. if evaluate and hasattr(result, 'doit'):
  722. return result.doit()
  723. else:
  724. return result
  725. def density(expr, condition=None, evaluate=True, numsamples=None, **kwargs):
  726. """
  727. Probability density of a random expression, optionally given a second
  728. condition.
  729. Explanation
  730. ===========
  731. This density will take on different forms for different types of
  732. probability spaces. Discrete variables produce Dicts. Continuous
  733. variables produce Lambdas.
  734. Parameters
  735. ==========
  736. expr : Expr containing RandomSymbols
  737. The expression of which you want to compute the density value
  738. condition : Relational containing RandomSymbols
  739. A conditional expression. density(X > 1, X > 0) is density of X > 1
  740. given X > 0
  741. numsamples : int
  742. Enables sampling and approximates the density with this many samples
  743. Examples
  744. ========
  745. >>> from sympy.stats import density, Die, Normal
  746. >>> from sympy import Symbol
  747. >>> x = Symbol('x')
  748. >>> D = Die('D', 6)
  749. >>> X = Normal(x, 0, 1)
  750. >>> density(D).dict
  751. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  752. >>> density(2*D).dict
  753. {2: 1/6, 4: 1/6, 6: 1/6, 8: 1/6, 10: 1/6, 12: 1/6}
  754. >>> density(X)(x)
  755. sqrt(2)*exp(-x**2/2)/(2*sqrt(pi))
  756. """
  757. if numsamples:
  758. return sampling_density(expr, condition, numsamples=numsamples,
  759. **kwargs)
  760. return Density(expr, condition).doit(evaluate=evaluate, **kwargs)
  761. def cdf(expr, condition=None, evaluate=True, **kwargs):
  762. """
  763. Cumulative Distribution Function of a random expression.
  764. optionally given a second condition.
  765. Explanation
  766. ===========
  767. This density will take on different forms for different types of
  768. probability spaces.
  769. Discrete variables produce Dicts.
  770. Continuous variables produce Lambdas.
  771. Examples
  772. ========
  773. >>> from sympy.stats import density, Die, Normal, cdf
  774. >>> D = Die('D', 6)
  775. >>> X = Normal('X', 0, 1)
  776. >>> density(D).dict
  777. {1: 1/6, 2: 1/6, 3: 1/6, 4: 1/6, 5: 1/6, 6: 1/6}
  778. >>> cdf(D)
  779. {1: 1/6, 2: 1/3, 3: 1/2, 4: 2/3, 5: 5/6, 6: 1}
  780. >>> cdf(3*D, D > 2)
  781. {9: 1/4, 12: 1/2, 15: 3/4, 18: 1}
  782. >>> cdf(X)
  783. Lambda(_z, erf(sqrt(2)*_z/2)/2 + 1/2)
  784. """
  785. if condition is not None: # If there is a condition
  786. # Recompute on new conditional expr
  787. return cdf(given(expr, condition, **kwargs), **kwargs)
  788. # Otherwise pass work off to the ProbabilitySpace
  789. result = pspace(expr).compute_cdf(expr, **kwargs)
  790. if evaluate and hasattr(result, 'doit'):
  791. return result.doit()
  792. else:
  793. return result
  794. def characteristic_function(expr, condition=None, evaluate=True, **kwargs):
  795. """
  796. Characteristic function of a random expression, optionally given a second condition.
  797. Returns a Lambda.
  798. Examples
  799. ========
  800. >>> from sympy.stats import Normal, DiscreteUniform, Poisson, characteristic_function
  801. >>> X = Normal('X', 0, 1)
  802. >>> characteristic_function(X)
  803. Lambda(_t, exp(-_t**2/2))
  804. >>> Y = DiscreteUniform('Y', [1, 2, 7])
  805. >>> characteristic_function(Y)
  806. Lambda(_t, exp(7*_t*I)/3 + exp(2*_t*I)/3 + exp(_t*I)/3)
  807. >>> Z = Poisson('Z', 2)
  808. >>> characteristic_function(Z)
  809. Lambda(_t, exp(2*exp(_t*I) - 2))
  810. """
  811. if condition is not None:
  812. return characteristic_function(given(expr, condition, **kwargs), **kwargs)
  813. result = pspace(expr).compute_characteristic_function(expr, **kwargs)
  814. if evaluate and hasattr(result, 'doit'):
  815. return result.doit()
  816. else:
  817. return result
  818. def moment_generating_function(expr, condition=None, evaluate=True, **kwargs):
  819. if condition is not None:
  820. return moment_generating_function(given(expr, condition, **kwargs), **kwargs)
  821. result = pspace(expr).compute_moment_generating_function(expr, **kwargs)
  822. if evaluate and hasattr(result, 'doit'):
  823. return result.doit()
  824. else:
  825. return result
  826. def where(condition, given_condition=None, **kwargs):
  827. """
  828. Returns the domain where a condition is True.
  829. Examples
  830. ========
  831. >>> from sympy.stats import where, Die, Normal
  832. >>> from sympy import And
  833. >>> D1, D2 = Die('a', 6), Die('b', 6)
  834. >>> a, b = D1.symbol, D2.symbol
  835. >>> X = Normal('x', 0, 1)
  836. >>> where(X**2<1)
  837. Domain: (-1 < x) & (x < 1)
  838. >>> where(X**2<1).set
  839. Interval.open(-1, 1)
  840. >>> where(And(D1<=D2, D2<3))
  841. Domain: (Eq(a, 1) & Eq(b, 1)) | (Eq(a, 1) & Eq(b, 2)) | (Eq(a, 2) & Eq(b, 2))
  842. """
  843. if given_condition is not None: # If there is a condition
  844. # Recompute on new conditional expr
  845. return where(given(condition, given_condition, **kwargs), **kwargs)
  846. # Otherwise pass work off to the ProbabilitySpace
  847. return pspace(condition).where(condition, **kwargs)
  848. @doctest_depends_on(modules=('scipy',))
  849. def sample(expr, condition=None, size=(), library='scipy',
  850. numsamples=1, seed=None, **kwargs):
  851. """
  852. A realization of the random expression.
  853. Parameters
  854. ==========
  855. expr : Expression of random variables
  856. Expression from which sample is extracted
  857. condition : Expr containing RandomSymbols
  858. A conditional expression
  859. size : int, tuple
  860. Represents size of each sample in numsamples
  861. library : str
  862. - 'scipy' : Sample using scipy
  863. - 'numpy' : Sample using numpy
  864. - 'pymc3' : Sample using PyMC3
  865. Choose any of the available options to sample from as string,
  866. by default is 'scipy'
  867. numsamples : int
  868. Number of samples, each with size as ``size``.
  869. .. deprecated:: 1.9
  870. The ``numsamples`` parameter is deprecated and is only provided for
  871. compatibility with v1.8. Use a list comprehension or an additional
  872. dimension in ``size`` instead. See
  873. :ref:`deprecated-sympy-stats-numsamples` for details.
  874. seed :
  875. An object to be used as seed by the given external library for sampling `expr`.
  876. Following is the list of possible types of object for the supported libraries,
  877. - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
  878. - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
  879. - 'pymc3': int
  880. Optional, by default None, in which case seed settings
  881. related to the given library will be used.
  882. No modifications to environment's global seed settings
  883. are done by this argument.
  884. Returns
  885. =======
  886. sample: float/list/numpy.ndarray
  887. one sample or a collection of samples of the random expression.
  888. - sample(X) returns float/numpy.float64/numpy.int64 object.
  889. - sample(X, size=int/tuple) returns numpy.ndarray object.
  890. Examples
  891. ========
  892. >>> from sympy.stats import Die, sample, Normal, Geometric
  893. >>> X, Y, Z = Die('X', 6), Die('Y', 6), Die('Z', 6) # Finite Random Variable
  894. >>> die_roll = sample(X + Y + Z)
  895. >>> die_roll # doctest: +SKIP
  896. 3
  897. >>> N = Normal('N', 3, 4) # Continuous Random Variable
  898. >>> samp = sample(N)
  899. >>> samp in N.pspace.domain.set
  900. True
  901. >>> samp = sample(N, N>0)
  902. >>> samp > 0
  903. True
  904. >>> samp_list = sample(N, size=4)
  905. >>> [sam in N.pspace.domain.set for sam in samp_list]
  906. [True, True, True, True]
  907. >>> sample(N, size = (2,3)) # doctest: +SKIP
  908. array([[5.42519758, 6.40207856, 4.94991743],
  909. [1.85819627, 6.83403519, 1.9412172 ]])
  910. >>> G = Geometric('G', 0.5) # Discrete Random Variable
  911. >>> samp_list = sample(G, size=3)
  912. >>> samp_list # doctest: +SKIP
  913. [1, 3, 2]
  914. >>> [sam in G.pspace.domain.set for sam in samp_list]
  915. [True, True, True]
  916. >>> MN = Normal("MN", [3, 4], [[2, 1], [1, 2]]) # Joint Random Variable
  917. >>> samp_list = sample(MN, size=4)
  918. >>> samp_list # doctest: +SKIP
  919. [array([2.85768055, 3.38954165]),
  920. array([4.11163337, 4.3176591 ]),
  921. array([0.79115232, 1.63232916]),
  922. array([4.01747268, 3.96716083])]
  923. >>> [tuple(sam) in MN.pspace.domain.set for sam in samp_list]
  924. [True, True, True, True]
  925. .. versionchanged:: 1.7.0
  926. sample used to return an iterator containing the samples instead of value.
  927. .. versionchanged:: 1.9.0
  928. sample returns values or array of values instead of an iterator and numsamples is deprecated.
  929. """
  930. iterator = sample_iter(expr, condition, size=size, library=library,
  931. numsamples=numsamples, seed=seed)
  932. if numsamples != 1:
  933. sympy_deprecation_warning(
  934. f"""
  935. The numsamples parameter to sympy.stats.sample() is deprecated.
  936. Either use a list comprehension, like
  937. [sample(...) for i in range({numsamples})]
  938. or add a dimension to size, like
  939. sample(..., size={(numsamples,) + size})
  940. """,
  941. deprecated_since_version="1.9",
  942. active_deprecations_target="deprecated-sympy-stats-numsamples",
  943. )
  944. return [next(iterator) for i in range(numsamples)]
  945. return next(iterator)
  946. def quantile(expr, evaluate=True, **kwargs):
  947. r"""
  948. Return the :math:`p^{th}` order quantile of a probability distribution.
  949. Explanation
  950. ===========
  951. Quantile is defined as the value at which the probability of the random
  952. variable is less than or equal to the given probability.
  953. ..math::
  954. Q(p) = inf{x \in (-\infty, \infty) such that p <= F(x)}
  955. Examples
  956. ========
  957. >>> from sympy.stats import quantile, Die, Exponential
  958. >>> from sympy import Symbol, pprint
  959. >>> p = Symbol("p")
  960. >>> l = Symbol("lambda", positive=True)
  961. >>> X = Exponential("x", l)
  962. >>> quantile(X)(p)
  963. -log(1 - p)/lambda
  964. >>> D = Die("d", 6)
  965. >>> pprint(quantile(D)(p), use_unicode=False)
  966. /nan for Or(p > 1, p < 0)
  967. |
  968. | 1 for p <= 1/6
  969. |
  970. | 2 for p <= 1/3
  971. |
  972. < 3 for p <= 1/2
  973. |
  974. | 4 for p <= 2/3
  975. |
  976. | 5 for p <= 5/6
  977. |
  978. \ 6 for p <= 1
  979. """
  980. result = pspace(expr).compute_quantile(expr, **kwargs)
  981. if evaluate and hasattr(result, 'doit'):
  982. return result.doit()
  983. else:
  984. return result
def sample_iter(expr, condition=None, size=(), library='scipy',
                    numsamples=S.Infinity, seed=None, **kwargs):
    """
    Returns an iterator of realizations from the expression given a condition.

    Parameters
    ==========

    expr: Expr
        Random expression to be realized
    condition: Expr, optional
        A conditional expression
    size : int, tuple
        Represents size of each sample in numsamples
    library : str
        Name of the sampling library. It must be importable; it is passed to
        the probability space's ``sample`` and, except for 'pymc3', used as
        the ``modules`` argument when lambdifying ``expr``.
    numsamples: integer, optional
        Length of the iterator (defaults to infinity)
    seed :
        An object to be used as seed by the given external library for sampling `expr`.
        Following is the list of possible types of object for the supported libraries,

        - 'scipy': int, numpy.random.RandomState, numpy.random.Generator
        - 'numpy': int, numpy.random.RandomState, numpy.random.Generator
        - 'pymc3': int

        Optional, by default None, in which case seed settings
        related to the given library will be used.
        No modifications to environment's global seed settings
        are done by this argument.

    Examples
    ========

    >>> from sympy.stats import Normal, sample_iter
    >>> X = Normal('X', 0, 1)
    >>> expr = X*X + 3
    >>> iterator = sample_iter(expr, numsamples=3) # doctest: +SKIP
    >>> list(iterator) # doctest: +SKIP
    [12, 4, 7]

    Returns
    =======

    sample_iter: iterator object
        iterator object containing the sample/samples of given expr

    Raises
    ======

    ValueError
        If ``library`` cannot be imported, or (while iterating) if the
        condition does not reduce to True or False for the sampled values.

    See Also
    ========

    sample
    sampling_P
    sampling_E

    """
    from sympy.stats.joint_rv import JointRandomSymbol
    if not import_module(library):
        raise ValueError("Failed to import %s" % library)

    # The probability space must cover the condition's random symbols too.
    if condition is not None:
        ps = pspace(Tuple(expr, condition))
    else:
        ps = pspace(expr)

    rvs = list(ps.values)
    # Swap JointRandomSymbol occurrences (the expression itself, or its direct
    # arguments) for plain RandomSymbols built from the same symbol and pspace.
    if isinstance(expr, JointRandomSymbol):
        expr = expr.subs({expr: RandomSymbol(expr.symbol, expr.pspace)})
    else:
        sub = {}
        for arg in expr.args:
            if isinstance(arg, JointRandomSymbol):
                sub[arg] = RandomSymbol(arg.symbol, arg.pspace)
        expr = expr.subs(sub)

    # Substitution-based fallbacks used when the lambdified functions raise.
    def fn_subs(*args):
        return expr.subs({rv: arg for rv, arg in zip(rvs, args)})

    def given_fn_subs(*args):
        if condition is not None:
            return condition.subs({rv: arg for rv, arg in zip(rvs, args)})
        return False

    if library == 'pymc3':
        # Currently unable to lambdify in pymc3
        # TODO : Remove 'pymc3' when lambdify accepts 'pymc3' as module
        fn = lambdify(rvs, expr, **kwargs)
    else:
        fn = lambdify(rvs, expr, modules=library, **kwargs)


    if condition is not None:
        given_fn = lambdify(rvs, condition, **kwargs)

    def return_generator_infinite():
        # Draws one realization per iteration and rejects it when the
        # condition is not met, so only accepted draws advance `count`.
        count = 0
        _size = (1,)+((size,) if isinstance(size, int) else size)
        while count < numsamples:
            d = ps.sample(size=_size, library=library, seed=seed)  # a dictionary that maps RVs to values
            args = [d[rv][0] for rv in rvs]

            if condition is not None:  # Check that these values satisfy the condition
                # TODO: Replace the try-except block with only given_fn(*args)
                # once lambdify works with unevaluated SymPy objects.
                try:
                    gd = given_fn(*args)
                except (NameError, TypeError):
                    gd = given_fn_subs(*args)
                if gd != True and gd != False:
                    raise ValueError(
                        "Conditions must not contain free symbols")
                if not gd:  # If the values don't satisfy then try again
                    continue

            # NOTE(review): unlike the finite generator below, there is no
            # fn_subs fallback around this call.
            yield fn(*args)
            count += 1

    def return_generator_finite():
        # Draws all `numsamples` realizations in one batch and redraws the
        # whole batch as soon as any single realization fails the condition.
        # NOTE(review): with a deterministic seed (e.g. an int) a rejected
        # batch would presumably be redrawn identically -- confirm against
        # ps.sample before relying on termination.
        faulty = True
        while faulty:
            d = ps.sample(size=(numsamples,) + ((size,) if isinstance(size, int) else size),
                          library=library, seed=seed) # a dictionary that maps RVs to values

            faulty = False
            count = 0
            while count < numsamples and not faulty:
                args = [d[rv][count] for rv in rvs]

                if condition is not None: # Check that these values satisfy the condition
                    # TODO: Replace the try-except block with only given_fn(*args)
                    # once lambdify works with unevaluated SymPy objects.
                    try:
                        gd = given_fn(*args)
                    except (NameError, TypeError):
                        gd = given_fn_subs(*args)
                    if gd != True and gd != False:
                        raise ValueError(
                            "Conditions must not contain free symbols")
                    if not gd:  # If the values don't satisfy then try again
                        faulty = True

                count += 1

        # The accepted batch is now replayed through the expression.
        count = 0
        while count < numsamples:
            args = [d[rv][count] for rv in rvs]
            # TODO: Replace the try-except block with only fn(*args)
            # once lambdify works with unevaluated SymPy objects.
            try:
                yield fn(*args)
            except (NameError, TypeError):
                yield fn_subs(*args)
            count += 1

    # Only the S.Infinity default selects the draw-one-at-a-time generator.
    if numsamples is S.Infinity:
        return return_generator_infinite()

    return return_generator_finite()
  1112. def sample_iter_lambdify(expr, condition=None, size=(),
  1113. numsamples=S.Infinity, seed=None, **kwargs):
  1114. return sample_iter(expr, condition=condition, size=size,
  1115. numsamples=numsamples, seed=seed, **kwargs)
  1116. def sample_iter_subs(expr, condition=None, size=(),
  1117. numsamples=S.Infinity, seed=None, **kwargs):
  1118. return sample_iter(expr, condition=condition, size=size,
  1119. numsamples=numsamples, seed=seed, **kwargs)
  1120. def sampling_P(condition, given_condition=None, library='scipy', numsamples=1,
  1121. evalf=True, seed=None, **kwargs):
  1122. """
  1123. Sampling version of P.
  1124. See Also
  1125. ========
  1126. P
  1127. sampling_E
  1128. sampling_density
  1129. """
  1130. count_true = 0
  1131. count_false = 0
  1132. samples = sample_iter(condition, given_condition, library=library,
  1133. numsamples=numsamples, seed=seed, **kwargs)
  1134. for sample in samples:
  1135. if sample:
  1136. count_true += 1
  1137. else:
  1138. count_false += 1
  1139. result = S(count_true) / numsamples
  1140. if evalf:
  1141. return result.evalf()
  1142. else:
  1143. return result
  1144. def sampling_E(expr, given_condition=None, library='scipy', numsamples=1,
  1145. evalf=True, seed=None, **kwargs):
  1146. """
  1147. Sampling version of E.
  1148. See Also
  1149. ========
  1150. P
  1151. sampling_P
  1152. sampling_density
  1153. """
  1154. samples = list(sample_iter(expr, given_condition, library=library,
  1155. numsamples=numsamples, seed=seed, **kwargs))
  1156. result = Add(*[samp for samp in samples]) / numsamples
  1157. if evalf:
  1158. return result.evalf()
  1159. else:
  1160. return result
  1161. def sampling_density(expr, given_condition=None, library='scipy',
  1162. numsamples=1, seed=None, **kwargs):
  1163. """
  1164. Sampling version of density.
  1165. See Also
  1166. ========
  1167. density
  1168. sampling_P
  1169. sampling_E
  1170. """
  1171. results = {}
  1172. for result in sample_iter(expr, given_condition, library=library,
  1173. numsamples=numsamples, seed=seed, **kwargs):
  1174. results[result] = results.get(result, 0) + 1
  1175. return results
  1176. def dependent(a, b):
  1177. """
  1178. Dependence of two random expressions.
  1179. Two expressions are independent if knowledge of one does not change
  1180. computations on the other.
  1181. Examples
  1182. ========
  1183. >>> from sympy.stats import Normal, dependent, given
  1184. >>> from sympy import Tuple, Eq
  1185. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1186. >>> dependent(X, Y)
  1187. False
  1188. >>> dependent(2*X + Y, -Y)
  1189. True
  1190. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1191. >>> dependent(X, Y)
  1192. True
  1193. See Also
  1194. ========
  1195. independent
  1196. """
  1197. if pspace_independent(a, b):
  1198. return False
  1199. z = Symbol('z', real=True)
  1200. # Dependent if density is unchanged when one is given information about
  1201. # the other
  1202. return (density(a, Eq(b, z)) != density(a) or
  1203. density(b, Eq(a, z)) != density(b))
  1204. def independent(a, b):
  1205. """
  1206. Independence of two random expressions.
  1207. Two expressions are independent if knowledge of one does not change
  1208. computations on the other.
  1209. Examples
  1210. ========
  1211. >>> from sympy.stats import Normal, independent, given
  1212. >>> from sympy import Tuple, Eq
  1213. >>> X, Y = Normal('X', 0, 1), Normal('Y', 0, 1)
  1214. >>> independent(X, Y)
  1215. True
  1216. >>> independent(2*X + Y, -Y)
  1217. False
  1218. >>> X, Y = given(Tuple(X, Y), Eq(X + Y, 3))
  1219. >>> independent(X, Y)
  1220. False
  1221. See Also
  1222. ========
  1223. dependent
  1224. """
  1225. return not dependent(a, b)
  1226. def pspace_independent(a, b):
  1227. """
  1228. Tests for independence between a and b by checking if their PSpaces have
  1229. overlapping symbols. This is a sufficient but not necessary condition for
  1230. independence and is intended to be used internally.
  1231. Notes
  1232. =====
  1233. pspace_independent(a, b) implies independent(a, b)
  1234. independent(a, b) does not imply pspace_independent(a, b)
  1235. """
  1236. a_symbols = set(pspace(b).symbols)
  1237. b_symbols = set(pspace(a).symbols)
  1238. if len(set(random_symbols(a)).intersection(random_symbols(b))) != 0:
  1239. return False
  1240. if len(a_symbols.intersection(b_symbols)) == 0:
  1241. return True
  1242. return None
  1243. def rv_subs(expr, symbols=None):
  1244. """
  1245. Given a random expression replace all random variables with their symbols.
  1246. If symbols keyword is given restrict the swap to only the symbols listed.
  1247. """
  1248. if symbols is None:
  1249. symbols = random_symbols(expr)
  1250. if not symbols:
  1251. return expr
  1252. swapdict = {rv: rv.symbol for rv in symbols}
  1253. return expr.subs(swapdict)
  1254. class NamedArgsMixin:
  1255. _argnames = () # type: tTuple[str, ...]
  1256. def __getattr__(self, attr):
  1257. try:
  1258. return self.args[self._argnames.index(attr)]
  1259. except ValueError:
  1260. raise AttributeError("'%s' object has no attribute '%s'" % (
  1261. type(self).__name__, attr))
  1262. class Distribution(Basic):
  1263. def sample(self, size=(), library='scipy', seed=None):
  1264. """ A random realization from the distribution """
  1265. module = import_module(library)
  1266. if library in {'scipy', 'numpy', 'pymc3'} and module is None:
  1267. raise ValueError("Failed to import %s" % library)
  1268. if library == 'scipy':
  1269. # scipy does not require map as it can handle using custom distributions.
  1270. # However, we will still use a map where we can.
  1271. # TODO: do this for drv.py and frv.py if necessary.
  1272. # TODO: add more distributions here if there are more
  1273. # See links below referring to sections beginning with "A common parametrization..."
  1274. # I will remove all these comments if everything is ok.
  1275. from sympy.stats.sampling.sample_scipy import do_sample_scipy
  1276. import numpy
  1277. if seed is None or isinstance(seed, int):
  1278. rand_state = numpy.random.default_rng(seed=seed)
  1279. else:
  1280. rand_state = seed
  1281. samps = do_sample_scipy(self, size, rand_state)
  1282. elif library == 'numpy':
  1283. from sympy.stats.sampling.sample_numpy import do_sample_numpy
  1284. import numpy
  1285. if seed is None or isinstance(seed, int):
  1286. rand_state = numpy.random.default_rng(seed=seed)
  1287. else:
  1288. rand_state = seed
  1289. _size = None if size == () else size
  1290. samps = do_sample_numpy(self, _size, rand_state)
  1291. elif library == 'pymc3':
  1292. from sympy.stats.sampling.sample_pymc3 import do_sample_pymc3
  1293. import logging
  1294. logging.getLogger("pymc3").setLevel(logging.ERROR)
  1295. import pymc3
  1296. with pymc3.Model():
  1297. if do_sample_pymc3(self):
  1298. samps = pymc3.sample(draws=prod(size), chains=1, compute_convergence_checks=False,
  1299. progressbar=False, random_seed=seed, return_inferencedata=False)[:]['X']
  1300. samps = samps.reshape(size)
  1301. else:
  1302. samps = None
  1303. else:
  1304. raise NotImplementedError("Sampling from %s is not supported yet."
  1305. % str(library))
  1306. if samps is not None:
  1307. return samps
  1308. raise NotImplementedError(
  1309. "Sampling for %s is not currently implemented from %s"
  1310. % (self, library))
  1311. def _value_check(condition, message):
  1312. """
  1313. Raise a ValueError with message if condition is False, else
  1314. return True if all conditions were True, else False.
  1315. Examples
  1316. ========
  1317. >>> from sympy.stats.rv import _value_check
  1318. >>> from sympy.abc import a, b, c
  1319. >>> from sympy import And, Dummy
  1320. >>> _value_check(2 < 3, '')
  1321. True
  1322. Here, the condition is not False, but it doesn't evaluate to True
  1323. so False is returned (but no error is raised). So checking if the
  1324. return value is True or False will tell you if all conditions were
  1325. evaluated.
  1326. >>> _value_check(a < b, '')
  1327. False
  1328. In this case the condition is False so an error is raised:
  1329. >>> r = Dummy(real=True)
  1330. >>> _value_check(r < r - 1, 'condition is not true')
  1331. Traceback (most recent call last):
  1332. ...
  1333. ValueError: condition is not true
  1334. If no condition of many conditions must be False, they can be
  1335. checked by passing them as an iterable:
  1336. >>> _value_check((a < 0, b < 0, c < 0), '')
  1337. False
  1338. The iterable can be a generator, too:
  1339. >>> _value_check((i < 0 for i in (a, b, c)), '')
  1340. False
  1341. The following are equivalent to the above but do not pass
  1342. an iterable:
  1343. >>> all(_value_check(i < 0, '') for i in (a, b, c))
  1344. False
  1345. >>> _value_check(And(a < 0, b < 0, c < 0), '')
  1346. False
  1347. """
  1348. if not iterable(condition):
  1349. condition = [condition]
  1350. truth = fuzzy_and(condition)
  1351. if truth == False:
  1352. raise ValueError(message)
  1353. return truth == True
  1354. def _symbol_converter(sym):
  1355. """
  1356. Casts the parameter to Symbol if it is 'str'
  1357. otherwise no operation is performed on it.
  1358. Parameters
  1359. ==========
  1360. sym
  1361. The parameter to be converted.
  1362. Returns
  1363. =======
  1364. Symbol
  1365. the parameter converted to Symbol.
  1366. Raises
  1367. ======
  1368. TypeError
  1369. If the parameter is not an instance of both str and
  1370. Symbol.
  1371. Examples
  1372. ========
  1373. >>> from sympy import Symbol
  1374. >>> from sympy.stats.rv import _symbol_converter
  1375. >>> s = _symbol_converter('s')
  1376. >>> isinstance(s, Symbol)
  1377. True
  1378. >>> _symbol_converter(1)
  1379. Traceback (most recent call last):
  1380. ...
  1381. TypeError: 1 is neither a Symbol nor a string
  1382. >>> r = Symbol('r')
  1383. >>> isinstance(r, Symbol)
  1384. True
  1385. """
  1386. if isinstance(sym, str):
  1387. sym = Symbol(sym)
  1388. if not isinstance(sym, Symbol):
  1389. raise TypeError("%s is neither a Symbol nor a string"%(sym))
  1390. return sym
  1391. def sample_stochastic_process(process):
  1392. """
  1393. This function is used to sample from stochastic process.
  1394. Parameters
  1395. ==========
  1396. process: StochasticProcess
  1397. Process used to extract the samples. It must be an instance of
  1398. StochasticProcess
  1399. Examples
  1400. ========
  1401. >>> from sympy.stats import sample_stochastic_process, DiscreteMarkovChain
  1402. >>> from sympy import Matrix
  1403. >>> T = Matrix([[0.5, 0.2, 0.3],[0.2, 0.5, 0.3],[0.2, 0.3, 0.5]])
  1404. >>> Y = DiscreteMarkovChain("Y", [0, 1, 2], T)
  1405. >>> next(sample_stochastic_process(Y)) in Y.state_space # doctest: +SKIP
  1406. True
  1407. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1408. 0
  1409. >>> next(sample_stochastic_process(Y)) # doctest: +SKIP
  1410. 2
  1411. Returns
  1412. =======
  1413. sample: iterator object
  1414. iterator object containing the sample of given process
  1415. """
  1416. from sympy.stats.stochastic_process_types import StochasticProcess
  1417. if not isinstance(process, StochasticProcess):
  1418. raise ValueError("Process must be an instance of Stochastic Process")
  1419. return process.sample()