123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110 |
- from sympy.core.function import Add, ArgumentIndexError, Function
- from sympy.core.power import Pow
- from sympy.core.singleton import S
- from sympy.core.sorting import default_sort_key
- from sympy.functions.elementary.exponential import exp, log
def _logaddexp(x1, x2, *, evaluate=True):
    """Build ``log(exp(x1) + exp(x2))`` as a SymPy expression.

    The ``evaluate`` flag is forwarded to both ``exp`` calls and to the
    ``Add`` that sums them. The outermost ``log`` is constructed without
    an explicit ``evaluate`` argument.
    """
    exp_x1 = exp(x1, evaluate=evaluate)
    exp_x2 = exp(x2, evaluate=evaluate)
    total = Add(exp_x1, exp_x2, evaluate=evaluate)
    return log(total)
# SymPy integer 2, shared as the base of the base-2 helpers below.
_two = S(2)
# Natural logarithm of 2; dividing a natural log by it yields a base-2 log.
_ln2 = log(_two)
def _lb(x, *, evaluate=True):
    """Build the base-2 logarithm of ``x`` as ``log(x)/log(2)``.

    ``evaluate`` is forwarded to the natural ``log`` call only; the
    division by ``_ln2`` is an ordinary SymPy operation.
    """
    natural_log = log(x, evaluate=evaluate)
    return natural_log/_ln2
def _exp2(x, *, evaluate=True):
    """Build ``2**x`` as a ``Pow``, forwarding ``evaluate`` to it."""
    power_of_two = Pow(_two, x, evaluate=evaluate)
    return power_of_two
def _logaddexp2(x1, x2, *, evaluate=True):
    """Build the base-2 logarithm of ``2**x1 + 2**x2``.

    ``evaluate`` is forwarded to both powers of two and to the ``Add``
    that sums them. The enclosing ``_lb`` call uses its own default.
    """
    pow_x1 = _exp2(x1, evaluate=evaluate)
    pow_x2 = _exp2(x2, evaluate=evaluate)
    total = Add(pow_x1, pow_x2, evaluate=evaluate)
    return _lb(total)
class logaddexp(Function):
    """ Logarithm of the sum of the exponentials of the two inputs.

    Helper class for use with e.g. numpy.logaddexp. The two arguments are
    stored sorted by ``default_sort_key``, so the order in which they are
    passed does not affect the resulting object.

    See Also
    ========

    https://numpy.org/doc/stable/reference/generated/numpy.logaddexp.html
    """
    nargs = 2

    def __new__(cls, *args):
        # Store the arguments in a canonical order.
        canonical = sorted(args, key=default_sort_key)
        return Function.__new__(cls, *canonical)

    def fdiff(self, argindex=1):
        """
        Returns the first derivative of this function.

        ``argindex`` selects the argument to differentiate with respect
        to (1 or 2); any other value raises ``ArgumentIndexError``.
        """
        if argindex not in (1, 2):
            raise ArgumentIndexError(self, argindex)
        first, second = self.args
        if argindex == 1:
            wrt, other = first, second
        else:
            wrt, other = second, first
        denominator = S.One + exp(other - wrt)
        return S.One/denominator

    def _eval_rewrite_as_log(self, x1, x2, **kwargs):
        """Rewrite in terms of ``log`` and ``exp``."""
        return _logaddexp(x1, x2)

    def _eval_evalf(self, *args, **kwargs):
        """Evaluate numerically via the ``log`` rewrite."""
        as_log = self.rewrite(log)
        return as_log.evalf(*args, **kwargs)

    def _eval_simplify(self, *args, **kwargs):
        """Simplify both arguments, then expand to the ``log`` form only
        if evaluating that form changes it; otherwise stay a ``logaddexp``.
        """
        a, b = (arg.simplify(**kwargs) for arg in self.args)
        evaluated = _logaddexp(a, b)
        unevaluated = _logaddexp(a, b, evaluate=False)
        if evaluated == unevaluated:
            # Evaluation changed nothing, so keep the compact node.
            return logaddexp(a, b)
        return evaluated
class logaddexp2(Function):
    """ Base-2 logarithm of the sum of the base-2 exponentials of the inputs.

    Helper class for use with e.g. numpy.logaddexp2. The two arguments are
    stored sorted by ``default_sort_key``, so the order in which they are
    passed does not affect the resulting object.

    See Also
    ========

    https://numpy.org/doc/stable/reference/generated/numpy.logaddexp2.html
    """
    nargs = 2

    def __new__(cls, *args):
        # Store the arguments in a canonical order.
        canonical = sorted(args, key=default_sort_key)
        return Function.__new__(cls, *canonical)

    def fdiff(self, argindex=1):
        """
        Returns the first derivative of this function.

        ``argindex`` selects the argument to differentiate with respect
        to (1 or 2); any other value raises ``ArgumentIndexError``.
        """
        if argindex not in (1, 2):
            raise ArgumentIndexError(self, argindex)
        first, second = self.args
        if argindex == 1:
            wrt, other = first, second
        else:
            wrt, other = second, first
        denominator = S.One + _exp2(other - wrt)
        return S.One/denominator

    def _eval_rewrite_as_log(self, x1, x2, **kwargs):
        """Rewrite in terms of ``log`` and powers of two."""
        return _logaddexp2(x1, x2)

    def _eval_evalf(self, *args, **kwargs):
        """Evaluate numerically via the ``log`` rewrite."""
        as_log = self.rewrite(log)
        return as_log.evalf(*args, **kwargs)

    def _eval_simplify(self, *args, **kwargs):
        """Simplify and factor both arguments, then expand to the ``log``
        form only if evaluating that form changes it; otherwise stay a
        ``logaddexp2``.
        """
        a, b = (arg.simplify(**kwargs).factor() for arg in self.args)
        evaluated = _logaddexp2(a, b)
        unevaluated = _logaddexp2(a, b, evaluate=False)
        if evaluated == unevaluated:
            # Evaluation changed nothing, so keep the compact node.
            return logaddexp2(a, b)
        return evaluated
|