transforms.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425
  1. """
  2. Discrete Fourier Transform, Number Theoretic Transform,
  3. Walsh Hadamard Transform, Mobius Transform
  4. """
  5. from sympy.core import S, Symbol, sympify
  6. from sympy.core.function import expand_mul
  7. from sympy.core.numbers import pi, I
  8. from sympy.functions.elementary.trigonometric import sin, cos
  9. from sympy.ntheory import isprime, primitive_root
  10. from sympy.utilities.iterables import ibin, iterable
  11. from sympy.utilities.misc import as_int
  12. #----------------------------------------------------------------------------#
  13. # #
  14. # Discrete Fourier Transform #
  15. # #
  16. #----------------------------------------------------------------------------#
  17. def _fourier_transform(seq, dps, inverse=False):
  18. """Utility function for the Discrete Fourier Transform"""
  19. if not iterable(seq):
  20. raise TypeError("Expected a sequence of numeric coefficients "
  21. "for Fourier Transform")
  22. a = [sympify(arg) for arg in seq]
  23. if any(x.has(Symbol) for x in a):
  24. raise ValueError("Expected non-symbolic coefficients")
  25. n = len(a)
  26. if n < 2:
  27. return a
  28. b = n.bit_length() - 1
  29. if n&(n - 1): # not a power of 2
  30. b += 1
  31. n = 2**b
  32. a += [S.Zero]*(n - len(a))
  33. for i in range(1, n):
  34. j = int(ibin(i, b, str=True)[::-1], 2)
  35. if i < j:
  36. a[i], a[j] = a[j], a[i]
  37. ang = -2*pi/n if inverse else 2*pi/n
  38. if dps is not None:
  39. ang = ang.evalf(dps + 2)
  40. w = [cos(ang*i) + I*sin(ang*i) for i in range(n // 2)]
  41. h = 2
  42. while h <= n:
  43. hf, ut = h // 2, n // h
  44. for i in range(0, n, h):
  45. for j in range(hf):
  46. u, v = a[i + j], expand_mul(a[i + j + hf]*w[ut * j])
  47. a[i + j], a[i + j + hf] = u + v, u - v
  48. h *= 2
  49. if inverse:
  50. a = [(x/n).evalf(dps) for x in a] if dps is not None \
  51. else [x/n for x in a]
  52. return a
  53. def fft(seq, dps=None):
  54. r"""
  55. Performs the Discrete Fourier Transform (**DFT**) in the complex domain.
  56. The sequence is automatically padded to the right with zeros, as the
  57. *radix-2 FFT* requires the number of sample points to be a power of 2.
  58. This method should be used with default arguments only for short sequences
  59. as the complexity of expressions increases with the size of the sequence.
  60. Parameters
  61. ==========
  62. seq : iterable
  63. The sequence on which **DFT** is to be applied.
  64. dps : Integer
  65. Specifies the number of decimal digits for precision.
  66. Examples
  67. ========
  68. >>> from sympy import fft, ifft
  69. >>> fft([1, 2, 3, 4])
  70. [10, -2 - 2*I, -2, -2 + 2*I]
  71. >>> ifft(_)
  72. [1, 2, 3, 4]
  73. >>> ifft([1, 2, 3, 4])
  74. [5/2, -1/2 + I/2, -1/2, -1/2 - I/2]
  75. >>> fft(_)
  76. [1, 2, 3, 4]
  77. >>> ifft([1, 7, 3, 4], dps=15)
  78. [3.75, -0.5 - 0.75*I, -1.75, -0.5 + 0.75*I]
  79. >>> fft(_)
  80. [1.0, 7.0, 3.0, 4.0]
  81. References
  82. ==========
  83. .. [1] https://en.wikipedia.org/wiki/Cooley%E2%80%93Tukey_FFT_algorithm
  84. .. [2] http://mathworld.wolfram.com/FastFourierTransform.html
  85. """
  86. return _fourier_transform(seq, dps=dps)
  87. def ifft(seq, dps=None):
  88. return _fourier_transform(seq, dps=dps, inverse=True)
  89. ifft.__doc__ = fft.__doc__
  90. #----------------------------------------------------------------------------#
  91. # #
  92. # Number Theoretic Transform #
  93. # #
  94. #----------------------------------------------------------------------------#
  95. def _number_theoretic_transform(seq, prime, inverse=False):
  96. """Utility function for the Number Theoretic Transform"""
  97. if not iterable(seq):
  98. raise TypeError("Expected a sequence of integer coefficients "
  99. "for Number Theoretic Transform")
  100. p = as_int(prime)
  101. if not isprime(p):
  102. raise ValueError("Expected prime modulus for "
  103. "Number Theoretic Transform")
  104. a = [as_int(x) % p for x in seq]
  105. n = len(a)
  106. if n < 1:
  107. return a
  108. b = n.bit_length() - 1
  109. if n&(n - 1):
  110. b += 1
  111. n = 2**b
  112. if (p - 1) % n:
  113. raise ValueError("Expected prime modulus of the form (m*2**k + 1)")
  114. a += [0]*(n - len(a))
  115. for i in range(1, n):
  116. j = int(ibin(i, b, str=True)[::-1], 2)
  117. if i < j:
  118. a[i], a[j] = a[j], a[i]
  119. pr = primitive_root(p)
  120. rt = pow(pr, (p - 1) // n, p)
  121. if inverse:
  122. rt = pow(rt, p - 2, p)
  123. w = [1]*(n // 2)
  124. for i in range(1, n // 2):
  125. w[i] = w[i - 1]*rt % p
  126. h = 2
  127. while h <= n:
  128. hf, ut = h // 2, n // h
  129. for i in range(0, n, h):
  130. for j in range(hf):
  131. u, v = a[i + j], a[i + j + hf]*w[ut * j]
  132. a[i + j], a[i + j + hf] = (u + v) % p, (u - v) % p
  133. h *= 2
  134. if inverse:
  135. rv = pow(n, p - 2, p)
  136. a = [x*rv % p for x in a]
  137. return a
  138. def ntt(seq, prime):
  139. r"""
  140. Performs the Number Theoretic Transform (**NTT**), which specializes the
  141. Discrete Fourier Transform (**DFT**) over quotient ring `Z/pZ` for prime
  142. `p` instead of complex numbers `C`.
  143. The sequence is automatically padded to the right with zeros, as the
  144. *radix-2 NTT* requires the number of sample points to be a power of 2.
  145. Parameters
  146. ==========
  147. seq : iterable
  148. The sequence on which **DFT** is to be applied.
  149. prime : Integer
  150. Prime modulus of the form `(m 2^k + 1)` to be used for performing
  151. **NTT** on the sequence.
  152. Examples
  153. ========
  154. >>> from sympy import ntt, intt
  155. >>> ntt([1, 2, 3, 4], prime=3*2**8 + 1)
  156. [10, 643, 767, 122]
  157. >>> intt(_, 3*2**8 + 1)
  158. [1, 2, 3, 4]
  159. >>> intt([1, 2, 3, 4], prime=3*2**8 + 1)
  160. [387, 415, 384, 353]
  161. >>> ntt(_, prime=3*2**8 + 1)
  162. [1, 2, 3, 4]
  163. References
  164. ==========
  165. .. [1] http://www.apfloat.org/ntt.html
  166. .. [2] http://mathworld.wolfram.com/NumberTheoreticTransform.html
  167. .. [3] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general%29
  168. """
  169. return _number_theoretic_transform(seq, prime=prime)
  170. def intt(seq, prime):
  171. return _number_theoretic_transform(seq, prime=prime, inverse=True)
  172. intt.__doc__ = ntt.__doc__
  173. #----------------------------------------------------------------------------#
  174. # #
  175. # Walsh Hadamard Transform #
  176. # #
  177. #----------------------------------------------------------------------------#
  178. def _walsh_hadamard_transform(seq, inverse=False):
  179. """Utility function for the Walsh Hadamard Transform"""
  180. if not iterable(seq):
  181. raise TypeError("Expected a sequence of coefficients "
  182. "for Walsh Hadamard Transform")
  183. a = [sympify(arg) for arg in seq]
  184. n = len(a)
  185. if n < 2:
  186. return a
  187. if n&(n - 1):
  188. n = 2**n.bit_length()
  189. a += [S.Zero]*(n - len(a))
  190. h = 2
  191. while h <= n:
  192. hf = h // 2
  193. for i in range(0, n, h):
  194. for j in range(hf):
  195. u, v = a[i + j], a[i + j + hf]
  196. a[i + j], a[i + j + hf] = u + v, u - v
  197. h *= 2
  198. if inverse:
  199. a = [x/n for x in a]
  200. return a
  201. def fwht(seq):
  202. r"""
  203. Performs the Walsh Hadamard Transform (**WHT**), and uses Hadamard
  204. ordering for the sequence.
  205. The sequence is automatically padded to the right with zeros, as the
  206. *radix-2 FWHT* requires the number of sample points to be a power of 2.
  207. Parameters
  208. ==========
  209. seq : iterable
  210. The sequence on which WHT is to be applied.
  211. Examples
  212. ========
  213. >>> from sympy import fwht, ifwht
  214. >>> fwht([4, 2, 2, 0, 0, 2, -2, 0])
  215. [8, 0, 8, 0, 8, 8, 0, 0]
  216. >>> ifwht(_)
  217. [4, 2, 2, 0, 0, 2, -2, 0]
  218. >>> ifwht([19, -1, 11, -9, -7, 13, -15, 5])
  219. [2, 0, 4, 0, 3, 10, 0, 0]
  220. >>> fwht(_)
  221. [19, -1, 11, -9, -7, 13, -15, 5]
  222. References
  223. ==========
  224. .. [1] https://en.wikipedia.org/wiki/Hadamard_transform
  225. .. [2] https://en.wikipedia.org/wiki/Fast_Walsh%E2%80%93Hadamard_transform
  226. """
  227. return _walsh_hadamard_transform(seq)
  228. def ifwht(seq):
  229. return _walsh_hadamard_transform(seq, inverse=True)
  230. ifwht.__doc__ = fwht.__doc__
  231. #----------------------------------------------------------------------------#
  232. # #
  233. # Mobius Transform for Subset Lattice #
  234. # #
  235. #----------------------------------------------------------------------------#
  236. def _mobius_transform(seq, sgn, subset):
  237. r"""Utility function for performing Mobius Transform using
  238. Yate's Dynamic Programming method"""
  239. if not iterable(seq):
  240. raise TypeError("Expected a sequence of coefficients")
  241. a = [sympify(arg) for arg in seq]
  242. n = len(a)
  243. if n < 2:
  244. return a
  245. if n&(n - 1):
  246. n = 2**n.bit_length()
  247. a += [S.Zero]*(n - len(a))
  248. if subset:
  249. i = 1
  250. while i < n:
  251. for j in range(n):
  252. if j & i:
  253. a[j] += sgn*a[j ^ i]
  254. i *= 2
  255. else:
  256. i = 1
  257. while i < n:
  258. for j in range(n):
  259. if j & i:
  260. continue
  261. a[j] += sgn*a[j ^ i]
  262. i *= 2
  263. return a
  264. def mobius_transform(seq, subset=True):
  265. r"""
  266. Performs the Mobius Transform for subset lattice with indices of
  267. sequence as bitmasks.
  268. The indices of each argument, considered as bit strings, correspond
  269. to subsets of a finite set.
  270. The sequence is automatically padded to the right with zeros, as the
  271. definition of subset/superset based on bitmasks (indices) requires
  272. the size of sequence to be a power of 2.
  273. Parameters
  274. ==========
  275. seq : iterable
  276. The sequence on which Mobius Transform is to be applied.
  277. subset : bool
  278. Specifies if Mobius Transform is applied by enumerating subsets
  279. or supersets of the given set.
  280. Examples
  281. ========
  282. >>> from sympy import symbols
  283. >>> from sympy import mobius_transform, inverse_mobius_transform
  284. >>> x, y, z = symbols('x y z')
  285. >>> mobius_transform([x, y, z])
  286. [x, x + y, x + z, x + y + z]
  287. >>> inverse_mobius_transform(_)
  288. [x, y, z, 0]
  289. >>> mobius_transform([x, y, z], subset=False)
  290. [x + y + z, y, z, 0]
  291. >>> inverse_mobius_transform(_, subset=False)
  292. [x, y, z, 0]
  293. >>> mobius_transform([1, 2, 3, 4])
  294. [1, 3, 4, 10]
  295. >>> inverse_mobius_transform(_)
  296. [1, 2, 3, 4]
  297. >>> mobius_transform([1, 2, 3, 4], subset=False)
  298. [10, 6, 7, 4]
  299. >>> inverse_mobius_transform(_, subset=False)
  300. [1, 2, 3, 4]
  301. References
  302. ==========
  303. .. [1] https://en.wikipedia.org/wiki/M%C3%B6bius_inversion_formula
  304. .. [2] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  305. .. [3] https://arxiv.org/pdf/1211.0189.pdf
  306. """
  307. return _mobius_transform(seq, sgn=+1, subset=subset)
  308. def inverse_mobius_transform(seq, subset=True):
  309. return _mobius_transform(seq, sgn=-1, subset=subset)
  310. inverse_mobius_transform.__doc__ = mobius_transform.__doc__