fourier.py 2.0 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091
  1. from sympy.core.sympify import _sympify
  2. from sympy.matrices.expressions import MatrixExpr
  3. from sympy.core.numbers import I
  4. from sympy.core.singleton import S
  5. from sympy.functions.elementary.exponential import exp
  6. from sympy.functions.elementary.miscellaneous import sqrt
  7. class DFT(MatrixExpr):
  8. r"""
  9. Returns a discrete Fourier transform matrix. The matrix is scaled
  10. with :math:`\frac{1}{\sqrt{n}}` so that it is unitary.
  11. Parameters
  12. ==========
  13. n : integer or Symbol
  14. Size of the transform.
  15. Examples
  16. ========
  17. >>> from sympy.abc import n
  18. >>> from sympy.matrices.expressions.fourier import DFT
  19. >>> DFT(3)
  20. DFT(3)
  21. >>> DFT(3).as_explicit()
  22. Matrix([
  23. [sqrt(3)/3, sqrt(3)/3, sqrt(3)/3],
  24. [sqrt(3)/3, sqrt(3)*exp(-2*I*pi/3)/3, sqrt(3)*exp(2*I*pi/3)/3],
  25. [sqrt(3)/3, sqrt(3)*exp(2*I*pi/3)/3, sqrt(3)*exp(-2*I*pi/3)/3]])
  26. >>> DFT(n).shape
  27. (n, n)
  28. References
  29. ==========
  30. .. [1] https://en.wikipedia.org/wiki/DFT_matrix
  31. """
  32. def __new__(cls, n):
  33. n = _sympify(n)
  34. cls._check_dim(n)
  35. obj = super().__new__(cls, n)
  36. return obj
  37. n = property(lambda self: self.args[0]) # type: ignore
  38. shape = property(lambda self: (self.n, self.n)) # type: ignore
  39. def _entry(self, i, j, **kwargs):
  40. w = exp(-2*S.Pi*I/self.n)
  41. return w**(i*j) / sqrt(self.n)
  42. def _eval_inverse(self):
  43. return IDFT(self.n)
  44. class IDFT(DFT):
  45. r"""
  46. Returns an inverse discrete Fourier transform matrix. The matrix is scaled
  47. with :math:`\frac{1}{\sqrt{n}}` so that it is unitary.
  48. Parameters
  49. ==========
  50. n : integer or Symbol
  51. Size of the transform
  52. Examples
  53. ========
  54. >>> from sympy.matrices.expressions.fourier import DFT, IDFT
  55. >>> IDFT(3)
  56. IDFT(3)
  57. >>> IDFT(4)*DFT(4)
  58. I
  59. See Also
  60. ========
  61. DFT
  62. """
  63. def _entry(self, i, j, **kwargs):
  64. w = exp(-2*S.Pi*I/self.n)
  65. return w**(-i*j) / sqrt(self.n)
  66. def _eval_inverse(self):
  67. return DFT(self.n)