bsplines.py 9.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351
  1. from sympy.core import S, sympify
  2. from sympy.core.symbol import (Dummy, symbols)
  3. from sympy.functions import Piecewise, piecewise_fold
  4. from sympy.sets.sets import Interval
  5. from functools import lru_cache
  6. def _ivl(cond, x):
  7. """return the interval corresponding to the condition
  8. Conditions in spline's Piecewise give the range over
  9. which an expression is valid like (lo <= x) & (x <= hi).
  10. This function returns (lo, hi).
  11. """
  12. from sympy.logic.boolalg import And
  13. if isinstance(cond, And) and len(cond.args) == 2:
  14. a, b = cond.args
  15. if a.lts == x:
  16. a, b = b, a
  17. return a.lts, b.gts
  18. raise TypeError('unexpected cond type: %s' % cond)
def _add_splines(c, b1, d, b2, x):
    """Construct c*b1 + d*b2.

    Helper for ``bspline_basis``: combines the two lower-degree splines of
    the recursion into one ``Piecewise``, summing the expressions of pieces
    whose conditions are identical.

    Parameters
    ==========

    c, d : Expr
        Coefficients multiplying ``b1`` and ``b2`` respectively.
    b1, b2 : Piecewise or S.Zero
        Spline pieces to combine. ``S.Zero`` (for either a coefficient or a
        spline) means that term is absent. A ``Piecewise`` is assumed to end
        with the default ``(0, True)`` piece and to have its other pieces
        sorted by interval.
    x : Symbol
        Variable used in the piece conditions; forwarded to ``_ivl``.

    Returns
    =======

    ``(c*b1 + d*b2).expand()`` represented as a ``Piecewise``.
    """
    if S.Zero in (b1, c):
        # First term vanishes: the result is just d*b2.
        rv = piecewise_fold(d * b2)
    elif S.Zero in (b2, d):
        # Second term vanishes: the result is just c*b1.
        rv = piecewise_fold(c * b1)
    else:
        new_args = []
        # Just combining the Piecewise without any fancy optimization
        p1 = piecewise_fold(c * b1)
        p2 = piecewise_fold(d * b2)

        # Search all Piecewise arguments except (0, True)
        p2args = list(p2.args[:-1])

        # This merging algorithm assumes the conditions in
        # p1 and p2 are sorted
        for arg in p1.args[:-1]:
            expr = arg.expr
            cond = arg.cond

            lower = _ivl(cond, x)[0]

            # Check p2 for matching conditions that can be merged.
            # Deleting from p2args while enumerating it is safe here only
            # because every deletion is immediately followed by ``break``.
            for i, arg2 in enumerate(p2args):
                expr2 = arg2.expr
                cond2 = arg2.cond

                lower_2, upper_2 = _ivl(cond2, x)
                if cond2 == cond:
                    # Conditions match, join expressions
                    expr += expr2
                    # Remove matching element
                    del p2args[i]
                    # No need to check the rest
                    break
                elif lower_2 < lower and upper_2 <= lower:
                    # Check if arg2 condition smaller than arg1,
                    # add to new_args by itself (no match expected
                    # in p1)
                    # NOTE(review): because of the ``break``, at most one
                    # such earlier p2 piece is emitted per p1 piece.
                    new_args.append(arg2)
                    del p2args[i]
                    break

            # Checked all, add expr and cond
            new_args.append((expr, cond))

        # Add remaining items from p2args (pieces lying past the end of p1)
        new_args.extend(p2args)

        # Add final (0, True)
        new_args.append((0, True))

        # evaluate=False keeps the pieces exactly in the order built above.
        rv = Piecewise(*new_args, evaluate=False)

    return rv.expand()
  65. @lru_cache(maxsize=128)
  66. def bspline_basis(d, knots, n, x):
  67. """
  68. The $n$-th B-spline at $x$ of degree $d$ with knots.
  69. Explanation
  70. ===========
  71. B-Splines are piecewise polynomials of degree $d$. They are defined on a
  72. set of knots, which is a sequence of integers or floats.
  73. Examples
  74. ========
  75. The 0th degree splines have a value of 1 on a single interval:
  76. >>> from sympy import bspline_basis
  77. >>> from sympy.abc import x
  78. >>> d = 0
  79. >>> knots = tuple(range(5))
  80. >>> bspline_basis(d, knots, 0, x)
  81. Piecewise((1, (x >= 0) & (x <= 1)), (0, True))
  82. For a given ``(d, knots)`` there are ``len(knots)-d-1`` B-splines
  83. defined, that are indexed by ``n`` (starting at 0).
  84. Here is an example of a cubic B-spline:
  85. >>> bspline_basis(3, tuple(range(5)), 0, x)
  86. Piecewise((x**3/6, (x >= 0) & (x <= 1)),
  87. (-x**3/2 + 2*x**2 - 2*x + 2/3,
  88. (x >= 1) & (x <= 2)),
  89. (x**3/2 - 4*x**2 + 10*x - 22/3,
  90. (x >= 2) & (x <= 3)),
  91. (-x**3/6 + 2*x**2 - 8*x + 32/3,
  92. (x >= 3) & (x <= 4)),
  93. (0, True))
  94. By repeating knot points, you can introduce discontinuities in the
  95. B-splines and their derivatives:
  96. >>> d = 1
  97. >>> knots = (0, 0, 2, 3, 4)
  98. >>> bspline_basis(d, knots, 0, x)
  99. Piecewise((1 - x/2, (x >= 0) & (x <= 2)), (0, True))
  100. It is quite time consuming to construct and evaluate B-splines. If
  101. you need to evaluate a B-spline many times, it is best to lambdify them
  102. first:
  103. >>> from sympy import lambdify
  104. >>> d = 3
  105. >>> knots = tuple(range(10))
  106. >>> b0 = bspline_basis(d, knots, 0, x)
  107. >>> f = lambdify(x, b0)
  108. >>> y = f(0.5)
  109. Parameters
  110. ==========
  111. d : integer
  112. degree of bspline
  113. knots : list of integer values
  114. list of knots points of bspline
  115. n : integer
  116. $n$-th B-spline
  117. x : symbol
  118. See Also
  119. ========
  120. bspline_basis_set
  121. References
  122. ==========
  123. .. [1] https://en.wikipedia.org/wiki/B-spline
  124. """
  125. # make sure x has no assumptions so conditions don't evaluate
  126. xvar = x
  127. x = Dummy()
  128. knots = tuple(sympify(k) for k in knots)
  129. d = int(d)
  130. n = int(n)
  131. n_knots = len(knots)
  132. n_intervals = n_knots - 1
  133. if n + d + 1 > n_intervals:
  134. raise ValueError("n + d + 1 must not exceed len(knots) - 1")
  135. if d == 0:
  136. result = Piecewise(
  137. (S.One, Interval(knots[n], knots[n + 1]).contains(x)), (0, True)
  138. )
  139. elif d > 0:
  140. denom = knots[n + d + 1] - knots[n + 1]
  141. if denom != S.Zero:
  142. B = (knots[n + d + 1] - x) / denom
  143. b2 = bspline_basis(d - 1, knots, n + 1, x)
  144. else:
  145. b2 = B = S.Zero
  146. denom = knots[n + d] - knots[n]
  147. if denom != S.Zero:
  148. A = (x - knots[n]) / denom
  149. b1 = bspline_basis(d - 1, knots, n, x)
  150. else:
  151. b1 = A = S.Zero
  152. result = _add_splines(A, b1, B, b2, x)
  153. else:
  154. raise ValueError("degree must be non-negative: %r" % n)
  155. # return result with user-given x
  156. return result.xreplace({x: xvar})
  157. def bspline_basis_set(d, knots, x):
  158. """
  159. Return the ``len(knots)-d-1`` B-splines at *x* of degree *d*
  160. with *knots*.
  161. Explanation
  162. ===========
  163. This function returns a list of piecewise polynomials that are the
  164. ``len(knots)-d-1`` B-splines of degree *d* for the given knots.
  165. This function calls ``bspline_basis(d, knots, n, x)`` for different
  166. values of *n*.
  167. Examples
  168. ========
  169. >>> from sympy import bspline_basis_set
  170. >>> from sympy.abc import x
  171. >>> d = 2
  172. >>> knots = range(5)
  173. >>> splines = bspline_basis_set(d, knots, x)
  174. >>> splines
  175. [Piecewise((x**2/2, (x >= 0) & (x <= 1)),
  176. (-x**2 + 3*x - 3/2, (x >= 1) & (x <= 2)),
  177. (x**2/2 - 3*x + 9/2, (x >= 2) & (x <= 3)),
  178. (0, True)),
  179. Piecewise((x**2/2 - x + 1/2, (x >= 1) & (x <= 2)),
  180. (-x**2 + 5*x - 11/2, (x >= 2) & (x <= 3)),
  181. (x**2/2 - 4*x + 8, (x >= 3) & (x <= 4)),
  182. (0, True))]
  183. Parameters
  184. ==========
  185. d : integer
  186. degree of bspline
  187. knots : list of integers
  188. list of knots points of bspline
  189. x : symbol
  190. See Also
  191. ========
  192. bspline_basis
  193. """
  194. n_splines = len(knots) - d - 1
  195. return [bspline_basis(d, tuple(knots), i, x) for i in range(n_splines)]
  196. def interpolating_spline(d, x, X, Y):
  197. """
  198. Return spline of degree *d*, passing through the given *X*
  199. and *Y* values.
  200. Explanation
  201. ===========
  202. This function returns a piecewise function such that each part is
  203. a polynomial of degree not greater than *d*. The value of *d*
  204. must be 1 or greater and the values of *X* must be strictly
  205. increasing.
  206. Examples
  207. ========
  208. >>> from sympy import interpolating_spline
  209. >>> from sympy.abc import x
  210. >>> interpolating_spline(1, x, [1, 2, 4, 7], [3, 6, 5, 7])
  211. Piecewise((3*x, (x >= 1) & (x <= 2)),
  212. (7 - x/2, (x >= 2) & (x <= 4)),
  213. (2*x/3 + 7/3, (x >= 4) & (x <= 7)))
  214. >>> interpolating_spline(3, x, [-2, 0, 1, 3, 4], [4, 2, 1, 1, 3])
  215. Piecewise((7*x**3/117 + 7*x**2/117 - 131*x/117 + 2, (x >= -2) & (x <= 1)),
  216. (10*x**3/117 - 2*x**2/117 - 122*x/117 + 77/39, (x >= 1) & (x <= 4)))
  217. Parameters
  218. ==========
  219. d : integer
  220. Degree of Bspline strictly greater than equal to one
  221. x : symbol
  222. X : list of strictly increasing integer values
  223. list of X coordinates through which the spline passes
  224. Y : list of strictly increasing integer values
  225. list of Y coordinates through which the spline passes
  226. See Also
  227. ========
  228. bspline_basis_set, interpolating_poly
  229. """
  230. from sympy.solvers.solveset import linsolve
  231. from sympy.matrices.dense import Matrix
  232. # Input sanitization
  233. d = sympify(d)
  234. if not (d.is_Integer and d.is_positive):
  235. raise ValueError("Spline degree must be a positive integer, not %s." % d)
  236. if len(X) != len(Y):
  237. raise ValueError("Number of X and Y coordinates must be the same.")
  238. if len(X) < d + 1:
  239. raise ValueError("Degree must be less than the number of control points.")
  240. if not all(a < b for a, b in zip(X, X[1:])):
  241. raise ValueError("The x-coordinates must be strictly increasing.")
  242. X = [sympify(i) for i in X]
  243. # Evaluating knots value
  244. if d.is_odd:
  245. j = (d + 1) // 2
  246. interior_knots = X[j:-j]
  247. else:
  248. j = d // 2
  249. interior_knots = [
  250. (a + b)/2 for a, b in zip(X[j : -j - 1], X[j + 1 : -j])
  251. ]
  252. knots = [X[0]] * (d + 1) + list(interior_knots) + [X[-1]] * (d + 1)
  253. basis = bspline_basis_set(d, knots, x)
  254. A = [[b.subs(x, v) for b in basis] for v in X]
  255. coeff = linsolve((Matrix(A), Matrix(Y)), symbols("c0:{}".format(len(X)), cls=Dummy))
  256. coeff = list(coeff)[0]
  257. intervals = {c for b in basis for (e, c) in b.args if c != True}
  258. # Sorting the intervals
  259. # ival contains the end-points of each interval
  260. ival = [_ivl(c, x) for c in intervals]
  261. com = zip(ival, intervals)
  262. com = sorted(com, key=lambda x: x[0])
  263. intervals = [y for x, y in com]
  264. basis_dicts = [{c: e for (e, c) in b.args} for b in basis]
  265. spline = []
  266. for i in intervals:
  267. piece = sum(
  268. [c * d.get(i, S.Zero) for (c, d) in zip(coeff, basis_dicts)], S.Zero
  269. )
  270. spline.append((piece, i))
  271. return Piecewise(*spline)