convolutions.py 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488
  1. """
  2. Convolution (using **FFT**, **NTT**, **FWHT**), Subset Convolution,
  3. Covering Product, Intersecting Product
  4. """
  5. from sympy.core import S, sympify
  6. from sympy.core.function import expand_mul
  7. from sympy.discrete.transforms import (
  8. fft, ifft, ntt, intt, fwht, ifwht,
  9. mobius_transform, inverse_mobius_transform)
  10. from sympy.utilities.iterables import iterable
  11. from sympy.utilities.misc import as_int
  12. def convolution(a, b, cycle=0, dps=None, prime=None, dyadic=None, subset=None):
  13. """
  14. Performs convolution by determining the type of desired
  15. convolution using hints.
  16. Exactly one of ``dps``, ``prime``, ``dyadic``, ``subset`` arguments
  17. should be specified explicitly for identifying the type of convolution,
  18. and the argument ``cycle`` can be specified optionally.
  19. For the default arguments, linear convolution is performed using **FFT**.
  20. Parameters
  21. ==========
  22. a, b : iterables
  23. The sequences for which convolution is performed.
  24. cycle : Integer
  25. Specifies the length for doing cyclic convolution.
  26. dps : Integer
  27. Specifies the number of decimal digits for precision for
  28. performing **FFT** on the sequence.
  29. prime : Integer
  30. Prime modulus of the form `(m 2^k + 1)` to be used for
  31. performing **NTT** on the sequence.
  32. dyadic : bool
  33. Identifies the convolution type as dyadic (*bitwise-XOR*)
  34. convolution, which is performed using **FWHT**.
  35. subset : bool
  36. Identifies the convolution type as subset convolution.
  37. Examples
  38. ========
  39. >>> from sympy import convolution, symbols, S, I
  40. >>> u, v, w, x, y, z = symbols('u v w x y z')
  41. >>> convolution([1 + 2*I, 4 + 3*I], [S(5)/4, 6], dps=3)
  42. [1.25 + 2.5*I, 11.0 + 15.8*I, 24.0 + 18.0*I]
  43. >>> convolution([1, 2, 3], [4, 5, 6], cycle=3)
  44. [31, 31, 28]
  45. >>> convolution([111, 777], [888, 444], prime=19*2**10 + 1)
  46. [1283, 19351, 14219]
  47. >>> convolution([111, 777], [888, 444], prime=19*2**10 + 1, cycle=2)
  48. [15502, 19351]
  49. >>> convolution([u, v], [x, y, z], dyadic=True)
  50. [u*x + v*y, u*y + v*x, u*z, v*z]
  51. >>> convolution([u, v], [x, y, z], dyadic=True, cycle=2)
  52. [u*x + u*z + v*y, u*y + v*x + v*z]
  53. >>> convolution([u, v, w], [x, y, z], subset=True)
  54. [u*x, u*y + v*x, u*z + w*x, v*z + w*y]
  55. >>> convolution([u, v, w], [x, y, z], subset=True, cycle=3)
  56. [u*x + v*z + w*y, u*y + v*x, u*z + w*x]
  57. """
  58. c = as_int(cycle)
  59. if c < 0:
  60. raise ValueError("The length for cyclic convolution "
  61. "must be non-negative")
  62. dyadic = True if dyadic else None
  63. subset = True if subset else None
  64. if sum(x is not None for x in (prime, dps, dyadic, subset)) > 1:
  65. raise TypeError("Ambiguity in determining the type of convolution")
  66. if prime is not None:
  67. ls = convolution_ntt(a, b, prime=prime)
  68. return ls if not c else [sum(ls[i::c]) % prime for i in range(c)]
  69. if dyadic:
  70. ls = convolution_fwht(a, b)
  71. elif subset:
  72. ls = convolution_subset(a, b)
  73. else:
  74. ls = convolution_fft(a, b, dps=dps)
  75. return ls if not c else [sum(ls[i::c]) for i in range(c)]
  76. #----------------------------------------------------------------------------#
  77. # #
  78. # Convolution for Complex domain #
  79. # #
  80. #----------------------------------------------------------------------------#
  81. def convolution_fft(a, b, dps=None):
  82. """
  83. Performs linear convolution using Fast Fourier Transform.
  84. Parameters
  85. ==========
  86. a, b : iterables
  87. The sequences for which convolution is performed.
  88. dps : Integer
  89. Specifies the number of decimal digits for precision.
  90. Examples
  91. ========
  92. >>> from sympy import S, I
  93. >>> from sympy.discrete.convolutions import convolution_fft
  94. >>> convolution_fft([2, 3], [4, 5])
  95. [8, 22, 15]
  96. >>> convolution_fft([2, 5], [6, 7, 3])
  97. [12, 44, 41, 15]
  98. >>> convolution_fft([1 + 2*I, 4 + 3*I], [S(5)/4, 6])
  99. [5/4 + 5*I/2, 11 + 63*I/4, 24 + 18*I]
  100. References
  101. ==========
  102. .. [1] https://en.wikipedia.org/wiki/Convolution_theorem
  103. .. [2] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general%29
  104. """
  105. a, b = a[:], b[:]
  106. n = m = len(a) + len(b) - 1 # convolution size
  107. if n > 0 and n&(n - 1): # not a power of 2
  108. n = 2**n.bit_length()
  109. # padding with zeros
  110. a += [S.Zero]*(n - len(a))
  111. b += [S.Zero]*(n - len(b))
  112. a, b = fft(a, dps), fft(b, dps)
  113. a = [expand_mul(x*y) for x, y in zip(a, b)]
  114. a = ifft(a, dps)[:m]
  115. return a
  116. #----------------------------------------------------------------------------#
  117. # #
  118. # Convolution for GF(p) #
  119. # #
  120. #----------------------------------------------------------------------------#
  121. def convolution_ntt(a, b, prime):
  122. """
  123. Performs linear convolution using Number Theoretic Transform.
  124. Parameters
  125. ==========
  126. a, b : iterables
  127. The sequences for which convolution is performed.
  128. prime : Integer
  129. Prime modulus of the form `(m 2^k + 1)` to be used for performing
  130. **NTT** on the sequence.
  131. Examples
  132. ========
  133. >>> from sympy.discrete.convolutions import convolution_ntt
  134. >>> convolution_ntt([2, 3], [4, 5], prime=19*2**10 + 1)
  135. [8, 22, 15]
  136. >>> convolution_ntt([2, 5], [6, 7, 3], prime=19*2**10 + 1)
  137. [12, 44, 41, 15]
  138. >>> convolution_ntt([333, 555], [222, 666], prime=19*2**10 + 1)
  139. [15555, 14219, 19404]
  140. References
  141. ==========
  142. .. [1] https://en.wikipedia.org/wiki/Convolution_theorem
  143. .. [2] https://en.wikipedia.org/wiki/Discrete_Fourier_transform_(general%29
  144. """
  145. a, b, p = a[:], b[:], as_int(prime)
  146. n = m = len(a) + len(b) - 1 # convolution size
  147. if n > 0 and n&(n - 1): # not a power of 2
  148. n = 2**n.bit_length()
  149. # padding with zeros
  150. a += [0]*(n - len(a))
  151. b += [0]*(n - len(b))
  152. a, b = ntt(a, p), ntt(b, p)
  153. a = [x*y % p for x, y in zip(a, b)]
  154. a = intt(a, p)[:m]
  155. return a
  156. #----------------------------------------------------------------------------#
  157. # #
  158. # Convolution for 2**n-group #
  159. # #
  160. #----------------------------------------------------------------------------#
  161. def convolution_fwht(a, b):
  162. """
  163. Performs dyadic (*bitwise-XOR*) convolution using Fast Walsh Hadamard
  164. Transform.
  165. The convolution is automatically padded to the right with zeros, as the
  166. *radix-2 FWHT* requires the number of sample points to be a power of 2.
  167. Parameters
  168. ==========
  169. a, b : iterables
  170. The sequences for which convolution is performed.
  171. Examples
  172. ========
  173. >>> from sympy import symbols, S, I
  174. >>> from sympy.discrete.convolutions import convolution_fwht
  175. >>> u, v, x, y = symbols('u v x y')
  176. >>> convolution_fwht([u, v], [x, y])
  177. [u*x + v*y, u*y + v*x]
  178. >>> convolution_fwht([2, 3], [4, 5])
  179. [23, 22]
  180. >>> convolution_fwht([2, 5 + 4*I, 7], [6*I, 7, 3 + 4*I])
  181. [56 + 68*I, -10 + 30*I, 6 + 50*I, 48 + 32*I]
  182. >>> convolution_fwht([S(33)/7, S(55)/6, S(7)/4], [S(2)/3, 5])
  183. [2057/42, 1870/63, 7/6, 35/4]
  184. References
  185. ==========
  186. .. [1] https://www.radioeng.cz/fulltexts/2002/02_03_40_42.pdf
  187. .. [2] https://en.wikipedia.org/wiki/Hadamard_transform
  188. """
  189. if not a or not b:
  190. return []
  191. a, b = a[:], b[:]
  192. n = max(len(a), len(b))
  193. if n&(n - 1): # not a power of 2
  194. n = 2**n.bit_length()
  195. # padding with zeros
  196. a += [S.Zero]*(n - len(a))
  197. b += [S.Zero]*(n - len(b))
  198. a, b = fwht(a), fwht(b)
  199. a = [expand_mul(x*y) for x, y in zip(a, b)]
  200. a = ifwht(a)
  201. return a
  202. #----------------------------------------------------------------------------#
  203. # #
  204. # Subset Convolution #
  205. # #
  206. #----------------------------------------------------------------------------#
  207. def convolution_subset(a, b):
  208. """
  209. Performs Subset Convolution of given sequences.
  210. The indices of each argument, considered as bit strings, correspond to
  211. subsets of a finite set.
  212. The sequence is automatically padded to the right with zeros, as the
  213. definition of subset based on bitmasks (indices) requires the size of
  214. sequence to be a power of 2.
  215. Parameters
  216. ==========
  217. a, b : iterables
  218. The sequences for which convolution is performed.
  219. Examples
  220. ========
  221. >>> from sympy import symbols, S
  222. >>> from sympy.discrete.convolutions import convolution_subset
  223. >>> u, v, x, y, z = symbols('u v x y z')
  224. >>> convolution_subset([u, v], [x, y])
  225. [u*x, u*y + v*x]
  226. >>> convolution_subset([u, v, x], [y, z])
  227. [u*y, u*z + v*y, x*y, x*z]
  228. >>> convolution_subset([1, S(2)/3], [3, 4])
  229. [3, 6]
  230. >>> convolution_subset([1, 3, S(5)/7], [7])
  231. [7, 21, 5, 0]
  232. References
  233. ==========
  234. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  235. """
  236. if not a or not b:
  237. return []
  238. if not iterable(a) or not iterable(b):
  239. raise TypeError("Expected a sequence of coefficients for convolution")
  240. a = [sympify(arg) for arg in a]
  241. b = [sympify(arg) for arg in b]
  242. n = max(len(a), len(b))
  243. if n&(n - 1): # not a power of 2
  244. n = 2**n.bit_length()
  245. # padding with zeros
  246. a += [S.Zero]*(n - len(a))
  247. b += [S.Zero]*(n - len(b))
  248. c = [S.Zero]*n
  249. for mask in range(n):
  250. smask = mask
  251. while smask > 0:
  252. c[mask] += expand_mul(a[smask] * b[mask^smask])
  253. smask = (smask - 1)&mask
  254. c[mask] += expand_mul(a[smask] * b[mask^smask])
  255. return c
  256. #----------------------------------------------------------------------------#
  257. # #
  258. # Covering Product #
  259. # #
  260. #----------------------------------------------------------------------------#
  261. def covering_product(a, b):
  262. """
  263. Returns the covering product of given sequences.
  264. The indices of each argument, considered as bit strings, correspond to
  265. subsets of a finite set.
  266. The covering product of given sequences is a sequence which contains
  267. the sum of products of the elements of the given sequences grouped by
  268. the *bitwise-OR* of the corresponding indices.
  269. The sequence is automatically padded to the right with zeros, as the
  270. definition of subset based on bitmasks (indices) requires the size of
  271. sequence to be a power of 2.
  272. Parameters
  273. ==========
  274. a, b : iterables
  275. The sequences for which covering product is to be obtained.
  276. Examples
  277. ========
  278. >>> from sympy import symbols, S, I, covering_product
  279. >>> u, v, x, y, z = symbols('u v x y z')
  280. >>> covering_product([u, v], [x, y])
  281. [u*x, u*y + v*x + v*y]
  282. >>> covering_product([u, v, x], [y, z])
  283. [u*y, u*z + v*y + v*z, x*y, x*z]
  284. >>> covering_product([1, S(2)/3], [3, 4 + 5*I])
  285. [3, 26/3 + 25*I/3]
  286. >>> covering_product([1, 3, S(5)/7], [7, 8])
  287. [7, 53, 5, 40/7]
  288. References
  289. ==========
  290. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  291. """
  292. if not a or not b:
  293. return []
  294. a, b = a[:], b[:]
  295. n = max(len(a), len(b))
  296. if n&(n - 1): # not a power of 2
  297. n = 2**n.bit_length()
  298. # padding with zeros
  299. a += [S.Zero]*(n - len(a))
  300. b += [S.Zero]*(n - len(b))
  301. a, b = mobius_transform(a), mobius_transform(b)
  302. a = [expand_mul(x*y) for x, y in zip(a, b)]
  303. a = inverse_mobius_transform(a)
  304. return a
  305. #----------------------------------------------------------------------------#
  306. # #
  307. # Intersecting Product #
  308. # #
  309. #----------------------------------------------------------------------------#
  310. def intersecting_product(a, b):
  311. """
  312. Returns the intersecting product of given sequences.
  313. The indices of each argument, considered as bit strings, correspond to
  314. subsets of a finite set.
  315. The intersecting product of given sequences is the sequence which
  316. contains the sum of products of the elements of the given sequences
  317. grouped by the *bitwise-AND* of the corresponding indices.
  318. The sequence is automatically padded to the right with zeros, as the
  319. definition of subset based on bitmasks (indices) requires the size of
  320. sequence to be a power of 2.
  321. Parameters
  322. ==========
  323. a, b : iterables
  324. The sequences for which intersecting product is to be obtained.
  325. Examples
  326. ========
  327. >>> from sympy import symbols, S, I, intersecting_product
  328. >>> u, v, x, y, z = symbols('u v x y z')
  329. >>> intersecting_product([u, v], [x, y])
  330. [u*x + u*y + v*x, v*y]
  331. >>> intersecting_product([u, v, x], [y, z])
  332. [u*y + u*z + v*y + x*y + x*z, v*z, 0, 0]
  333. >>> intersecting_product([1, S(2)/3], [3, 4 + 5*I])
  334. [9 + 5*I, 8/3 + 10*I/3]
  335. >>> intersecting_product([1, 3, S(5)/7], [7, 8])
  336. [327/7, 24, 0, 0]
  337. References
  338. ==========
  339. .. [1] https://people.csail.mit.edu/rrw/presentations/subset-conv.pdf
  340. """
  341. if not a or not b:
  342. return []
  343. a, b = a[:], b[:]
  344. n = max(len(a), len(b))
  345. if n&(n - 1): # not a power of 2
  346. n = 2**n.bit_length()
  347. # padding with zeros
  348. a += [S.Zero]*(n - len(a))
  349. b += [S.Zero]*(n - len(b))
  350. a, b = mobius_transform(a, subset=False), mobius_transform(b, subset=False)
  351. a = [expand_mul(x*y) for x, y in zip(a, b)]
  352. a = inverse_mobius_transform(a, subset=False)
  353. return a