numpy_nodes.py 3.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110
  1. from sympy.core.function import Add, ArgumentIndexError, Function
  2. from sympy.core.power import Pow
  3. from sympy.core.singleton import S
  4. from sympy.core.sorting import default_sort_key
  5. from sympy.functions.elementary.exponential import exp, log
  6. def _logaddexp(x1, x2, *, evaluate=True):
  7. return log(Add(exp(x1, evaluate=evaluate), exp(x2, evaluate=evaluate), evaluate=evaluate))
  8. _two = S.One*2
  9. _ln2 = log(_two)
  10. def _lb(x, *, evaluate=True):
  11. return log(x, evaluate=evaluate)/_ln2
  12. def _exp2(x, *, evaluate=True):
  13. return Pow(_two, x, evaluate=evaluate)
  14. def _logaddexp2(x1, x2, *, evaluate=True):
  15. return _lb(Add(_exp2(x1, evaluate=evaluate),
  16. _exp2(x2, evaluate=evaluate), evaluate=evaluate))
  17. class logaddexp(Function):
  18. """ Logarithm of the sum of exponentiations of the inputs.
  19. Helper class for use with e.g. numpy.logaddexp
  20. See Also
  21. ========
  22. https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html
  23. """
  24. nargs = 2
  25. def __new__(cls, *args):
  26. return Function.__new__(cls, *sorted(args, key=default_sort_key))
  27. def fdiff(self, argindex=1):
  28. """
  29. Returns the first derivative of this function.
  30. """
  31. if argindex == 1:
  32. wrt, other = self.args
  33. elif argindex == 2:
  34. other, wrt = self.args
  35. else:
  36. raise ArgumentIndexError(self, argindex)
  37. return S.One/(S.One + exp(other-wrt))
  38. def _eval_rewrite_as_log(self, x1, x2, **kwargs):
  39. return _logaddexp(x1, x2)
  40. def _eval_evalf(self, *args, **kwargs):
  41. return self.rewrite(log).evalf(*args, **kwargs)
  42. def _eval_simplify(self, *args, **kwargs):
  43. a, b = map(lambda x: x.simplify(**kwargs), self.args)
  44. candidate = _logaddexp(a, b)
  45. if candidate != _logaddexp(a, b, evaluate=False):
  46. return candidate
  47. else:
  48. return logaddexp(a, b)
  49. class logaddexp2(Function):
  50. """ Logarithm of the sum of exponentiations of the inputs in base-2.
  51. Helper class for use with e.g. numpy.logaddexp2
  52. See Also
  53. ========
  54. https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html
  55. """
  56. nargs = 2
  57. def __new__(cls, *args):
  58. return Function.__new__(cls, *sorted(args, key=default_sort_key))
  59. def fdiff(self, argindex=1):
  60. """
  61. Returns the first derivative of this function.
  62. """
  63. if argindex == 1:
  64. wrt, other = self.args
  65. elif argindex == 2:
  66. other, wrt = self.args
  67. else:
  68. raise ArgumentIndexError(self, argindex)
  69. return S.One/(S.One + _exp2(other-wrt))
  70. def _eval_rewrite_as_log(self, x1, x2, **kwargs):
  71. return _logaddexp2(x1, x2)
  72. def _eval_evalf(self, *args, **kwargs):
  73. return self.rewrite(log).evalf(*args, **kwargs)
  74. def _eval_simplify(self, *args, **kwargs):
  75. a, b = map(lambda x: x.simplify(**kwargs).factor(), self.args)
  76. candidate = _logaddexp2(a, b)
  77. if candidate != _logaddexp2(a, b, evaluate=False):
  78. return candidate
  79. else:
  80. return logaddexp2(a, b)