combsimp.py 3.5 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114
  1. from sympy.core import Mul
  2. from sympy.core.function import count_ops
  3. from sympy.core.traversal import preorder_traversal, bottom_up
  4. from sympy.functions.combinatorial.factorials import binomial, factorial
  5. from sympy.functions import gamma
  6. from sympy.simplify.gammasimp import gammasimp, _gammasimp
  7. from sympy.utilities.timeutils import timethis
  8. @timethis('combsimp')
  9. def combsimp(expr):
  10. r"""
  11. Simplify combinatorial expressions.
  12. Explanation
  13. ===========
  14. This function takes as input an expression containing factorials,
  15. binomials, Pochhammer symbol and other "combinatorial" functions,
  16. and tries to minimize the number of those functions and reduce
  17. the size of their arguments.
  18. The algorithm works by rewriting all combinatorial functions as
  19. gamma functions and applying gammasimp() except simplification
  20. steps that may make an integer argument non-integer. See docstring
  21. of gammasimp for more information.
  22. Then it rewrites expression in terms of factorials and binomials by
  23. rewriting gammas as factorials and converting (a+b)!/a!b! into
  24. binomials.
  25. If expression has gamma functions or combinatorial functions
  26. with non-integer argument, it is automatically passed to gammasimp.
  27. Examples
  28. ========
  29. >>> from sympy.simplify import combsimp
  30. >>> from sympy import factorial, binomial, symbols
  31. >>> n, k = symbols('n k', integer = True)
  32. >>> combsimp(factorial(n)/factorial(n - 3))
  33. n*(n - 2)*(n - 1)
  34. >>> combsimp(binomial(n+1, k+1)/binomial(n, k))
  35. (n + 1)/(k + 1)
  36. """
  37. expr = expr.rewrite(gamma, piecewise=False)
  38. if any(isinstance(node, gamma) and not node.args[0].is_integer
  39. for node in preorder_traversal(expr)):
  40. return gammasimp(expr);
  41. expr = _gammasimp(expr, as_comb = True)
  42. expr = _gamma_as_comb(expr)
  43. return expr
  44. def _gamma_as_comb(expr):
  45. """
  46. Helper function for combsimp.
  47. Rewrites expression in terms of factorials and binomials
  48. """
  49. expr = expr.rewrite(factorial)
  50. def f(rv):
  51. if not rv.is_Mul:
  52. return rv
  53. rvd = rv.as_powers_dict()
  54. nd_fact_args = [[], []] # numerator, denominator
  55. for k in rvd:
  56. if isinstance(k, factorial) and rvd[k].is_Integer:
  57. if rvd[k].is_positive:
  58. nd_fact_args[0].extend([k.args[0]]*rvd[k])
  59. else:
  60. nd_fact_args[1].extend([k.args[0]]*-rvd[k])
  61. rvd[k] = 0
  62. if not nd_fact_args[0] or not nd_fact_args[1]:
  63. return rv
  64. hit = False
  65. for m in range(2):
  66. i = 0
  67. while i < len(nd_fact_args[m]):
  68. ai = nd_fact_args[m][i]
  69. for j in range(i + 1, len(nd_fact_args[m])):
  70. aj = nd_fact_args[m][j]
  71. sum = ai + aj
  72. if sum in nd_fact_args[1 - m]:
  73. hit = True
  74. nd_fact_args[1 - m].remove(sum)
  75. del nd_fact_args[m][j]
  76. del nd_fact_args[m][i]
  77. rvd[binomial(sum, ai if count_ops(ai) <
  78. count_ops(aj) else aj)] += (
  79. -1 if m == 0 else 1)
  80. break
  81. else:
  82. i += 1
  83. if hit:
  84. return Mul(*([k**rvd[k] for k in rvd] + [factorial(k)
  85. for k in nd_fact_args[0]]))/Mul(*[factorial(k)
  86. for k in nd_fact_args[1]])
  87. return rv
  88. return bottom_up(expr, f)