applyfunc.py 6.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204
  1. from sympy.core.expr import ExprBuilder
  2. from sympy.core.function import (Function, FunctionClass, Lambda)
  3. from sympy.core.symbol import Dummy
  4. from sympy.core.sympify import sympify, _sympify
  5. from sympy.matrices.expressions import MatrixExpr
  6. from sympy.matrices.matrices import MatrixBase
  7. class ElementwiseApplyFunction(MatrixExpr):
  8. r"""
  9. Apply function to a matrix elementwise without evaluating.
  10. Examples
  11. ========
  12. It can be created by calling ``.applyfunc(<function>)`` on a matrix
  13. expression:
  14. >>> from sympy import MatrixSymbol
  15. >>> from sympy.matrices.expressions.applyfunc import ElementwiseApplyFunction
  16. >>> from sympy import exp
  17. >>> X = MatrixSymbol("X", 3, 3)
  18. >>> X.applyfunc(exp)
  19. Lambda(_d, exp(_d)).(X)
  20. Otherwise using the class constructor:
  21. >>> from sympy import eye
  22. >>> expr = ElementwiseApplyFunction(exp, eye(3))
  23. >>> expr
  24. Lambda(_d, exp(_d)).(Matrix([
  25. [1, 0, 0],
  26. [0, 1, 0],
  27. [0, 0, 1]]))
  28. >>> expr.doit()
  29. Matrix([
  30. [E, 1, 1],
  31. [1, E, 1],
  32. [1, 1, E]])
  33. Notice the difference with the real mathematical functions:
  34. >>> exp(eye(3))
  35. Matrix([
  36. [E, 0, 0],
  37. [0, E, 0],
  38. [0, 0, E]])
  39. """
  40. def __new__(cls, function, expr):
  41. expr = _sympify(expr)
  42. if not expr.is_Matrix:
  43. raise ValueError("{} must be a matrix instance.".format(expr))
  44. if expr.shape == (1, 1):
  45. # Check if the function returns a matrix, in that case, just apply
  46. # the function instead of creating an ElementwiseApplyFunc object:
  47. ret = function(expr)
  48. if isinstance(ret, MatrixExpr):
  49. return ret
  50. if not isinstance(function, (FunctionClass, Lambda)):
  51. d = Dummy('d')
  52. function = Lambda(d, function(d))
  53. function = sympify(function)
  54. if not isinstance(function, (FunctionClass, Lambda)):
  55. raise ValueError(
  56. "{} should be compatible with SymPy function classes."
  57. .format(function))
  58. if 1 not in function.nargs:
  59. raise ValueError(
  60. '{} should be able to accept 1 arguments.'.format(function))
  61. if not isinstance(function, Lambda):
  62. d = Dummy('d')
  63. function = Lambda(d, function(d))
  64. obj = MatrixExpr.__new__(cls, function, expr)
  65. return obj
  66. @property
  67. def function(self):
  68. return self.args[0]
  69. @property
  70. def expr(self):
  71. return self.args[1]
  72. @property
  73. def shape(self):
  74. return self.expr.shape
  75. def doit(self, **kwargs):
  76. deep = kwargs.get("deep", True)
  77. expr = self.expr
  78. if deep:
  79. expr = expr.doit(**kwargs)
  80. function = self.function
  81. if isinstance(function, Lambda) and function.is_identity:
  82. # This is a Lambda containing the identity function.
  83. return expr
  84. if isinstance(expr, MatrixBase):
  85. return expr.applyfunc(self.function)
  86. elif isinstance(expr, ElementwiseApplyFunction):
  87. return ElementwiseApplyFunction(
  88. lambda x: self.function(expr.function(x)),
  89. expr.expr
  90. ).doit()
  91. else:
  92. return self
  93. def _entry(self, i, j, **kwargs):
  94. return self.function(self.expr._entry(i, j, **kwargs))
  95. def _get_function_fdiff(self):
  96. d = Dummy("d")
  97. function = self.function(d)
  98. fdiff = function.diff(d)
  99. if isinstance(fdiff, Function):
  100. fdiff = type(fdiff)
  101. else:
  102. fdiff = Lambda(d, fdiff)
  103. return fdiff
  104. def _eval_derivative(self, x):
  105. from sympy.matrices.expressions.hadamard import hadamard_product
  106. dexpr = self.expr.diff(x)
  107. fdiff = self._get_function_fdiff()
  108. return hadamard_product(
  109. dexpr,
  110. ElementwiseApplyFunction(fdiff, self.expr)
  111. )
  112. def _eval_derivative_matrix_lines(self, x):
  113. from sympy.matrices.expressions.special import Identity
  114. from sympy.tensor.array.expressions.array_expressions import ArrayContraction
  115. from sympy.tensor.array.expressions.array_expressions import ArrayDiagonal
  116. from sympy.tensor.array.expressions.array_expressions import ArrayTensorProduct
  117. fdiff = self._get_function_fdiff()
  118. lr = self.expr._eval_derivative_matrix_lines(x)
  119. ewdiff = ElementwiseApplyFunction(fdiff, self.expr)
  120. if 1 in x.shape:
  121. # Vector:
  122. iscolumn = self.shape[1] == 1
  123. for i in lr:
  124. if iscolumn:
  125. ptr1 = i.first_pointer
  126. ptr2 = Identity(self.shape[1])
  127. else:
  128. ptr1 = Identity(self.shape[0])
  129. ptr2 = i.second_pointer
  130. subexpr = ExprBuilder(
  131. ArrayDiagonal,
  132. [
  133. ExprBuilder(
  134. ArrayTensorProduct,
  135. [
  136. ewdiff,
  137. ptr1,
  138. ptr2,
  139. ]
  140. ),
  141. (0, 2) if iscolumn else (1, 4)
  142. ],
  143. validator=ArrayDiagonal._validate
  144. )
  145. i._lines = [subexpr]
  146. i._first_pointer_parent = subexpr.args[0].args
  147. i._first_pointer_index = 1
  148. i._second_pointer_parent = subexpr.args[0].args
  149. i._second_pointer_index = 2
  150. else:
  151. # Matrix case:
  152. for i in lr:
  153. ptr1 = i.first_pointer
  154. ptr2 = i.second_pointer
  155. newptr1 = Identity(ptr1.shape[1])
  156. newptr2 = Identity(ptr2.shape[1])
  157. subexpr = ExprBuilder(
  158. ArrayContraction,
  159. [
  160. ExprBuilder(
  161. ArrayTensorProduct,
  162. [ptr1, newptr1, ewdiff, ptr2, newptr2]
  163. ),
  164. (1, 2, 4),
  165. (5, 7, 8),
  166. ],
  167. validator=ArrayContraction._validate
  168. )
  169. i._first_pointer_parent = subexpr.args[0].args
  170. i._first_pointer_index = 1
  171. i._second_pointer_parent = subexpr.args[0].args
  172. i._second_pointer_index = 4
  173. i._lines = [subexpr]
  174. return lr
  175. def _eval_transpose(self):
  176. from sympy.matrices.expressions.transpose import Transpose
  177. return self.func(self.function, Transpose(self.expr).doit())