intervaltree.pxi.in 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427
  1. """
  2. Template for intervaltree
  3. WARNING: DO NOT edit .pxi FILE directly, .pxi is generated from .pxi.in
  4. """
  5. from pandas._libs.algos import is_monotonic
# Query-point type accepted by nodes built over int64 bounds: the point may
# be given either as an int64 or as a float64.
ctypedef fused int_scalar_t:
    int64_t
    float64_t

# Query-point type accepted by nodes built over uint64 bounds: the point may
# be given either as a uint64 or as a float64.
ctypedef fused uint_scalar_t:
    uint64_t
    float64_t

# Union of the two fused types above. Used for the target arrays of
# IntervalTree.get_indexer / get_indexer_non_unique and, via an empty
# ``fused_prefix``, as the query-point type of the float64 nodes.
ctypedef fused scalar_t:
    int_scalar_t
    uint_scalar_t
  15. # ----------------------------------------------------------------------
  16. # IntervalTree
  17. # ----------------------------------------------------------------------
  18. cdef class IntervalTree(IntervalMixin):
  19. """A centered interval tree
  20. Based off the algorithm described on Wikipedia:
  21. https://en.wikipedia.org/wiki/Interval_tree
  22. we are emulating the IndexEngine interface
  23. """
  24. cdef readonly:
  25. ndarray left, right
  26. IntervalNode root
  27. object dtype
  28. str closed
  29. object _is_overlapping, _left_sorter, _right_sorter
  30. Py_ssize_t _na_count
  31. def __init__(self, left, right, closed='right', leaf_size=100):
  32. """
  33. Parameters
  34. ----------
  35. left, right : np.ndarray[ndim=1]
  36. Left and right bounds for each interval. Assumed to contain no
  37. NaNs.
  38. closed : {'left', 'right', 'both', 'neither'}, optional
  39. Whether the intervals are closed on the left-side, right-side, both
  40. or neither. Defaults to 'right'.
  41. leaf_size : int, optional
  42. Parameter that controls when the tree switches from creating nodes
  43. to brute-force search. Tune this parameter to optimize query
  44. performance.
  45. """
  46. if closed not in ['left', 'right', 'both', 'neither']:
  47. raise ValueError("invalid option for 'closed': %s" % closed)
  48. left = np.asarray(left)
  49. right = np.asarray(right)
  50. self.dtype = np.result_type(left, right)
  51. self.left = np.asarray(left, dtype=self.dtype)
  52. self.right = np.asarray(right, dtype=self.dtype)
  53. indices = np.arange(len(left), dtype='int64')
  54. self.closed = closed
  55. # GH 23352: ensure no nan in nodes
  56. mask = ~np.isnan(self.left)
  57. self._na_count = len(mask) - mask.sum()
  58. self.left = self.left[mask]
  59. self.right = self.right[mask]
  60. indices = indices[mask]
  61. node_cls = NODE_CLASSES[str(self.dtype), closed]
  62. self.root = node_cls(self.left, self.right, indices, leaf_size)
  63. @property
  64. def left_sorter(self) -> np.ndarray:
  65. """How to sort the left labels; this is used for binary search
  66. """
  67. if self._left_sorter is None:
  68. self._left_sorter = np.argsort(self.left)
  69. return self._left_sorter
  70. @property
  71. def right_sorter(self) -> np.ndarray:
  72. """How to sort the right labels
  73. """
  74. if self._right_sorter is None:
  75. self._right_sorter = np.argsort(self.right)
  76. return self._right_sorter
  77. @property
  78. def is_overlapping(self) -> bool:
  79. """
  80. Determine if the IntervalTree contains overlapping intervals.
  81. Cached as self._is_overlapping.
  82. """
  83. if self._is_overlapping is not None:
  84. return self._is_overlapping
  85. # <= when both sides closed since endpoints can overlap
  86. op = le if self.closed == 'both' else lt
  87. # overlap if start of current interval < end of previous interval
  88. # (current and previous in terms of sorted order by left/start side)
  89. current = self.left[self.left_sorter[1:]]
  90. previous = self.right[self.left_sorter[:-1]]
  91. self._is_overlapping = bool(op(current, previous).any())
  92. return self._is_overlapping
  93. @property
  94. def is_monotonic_increasing(self) -> bool:
  95. """
  96. Return True if the IntervalTree is monotonic increasing (only equal or
  97. increasing values), else False
  98. """
  99. if self._na_count > 0:
  100. return False
  101. values = [self.right, self.left]
  102. sort_order = np.lexsort(values)
  103. return is_monotonic(sort_order, False)[0]
  104. def get_indexer(self, scalar_t[:] target) -> np.ndarray:
  105. """Return the positions corresponding to unique intervals that overlap
  106. with the given array of scalar targets.
  107. """
  108. # TODO: write get_indexer_intervals
  109. cdef:
  110. Py_ssize_t old_len
  111. Py_ssize_t i
  112. Int64Vector result
  113. result = Int64Vector()
  114. old_len = 0
  115. for i in range(len(target)):
  116. try:
  117. self.root.query(result, target[i])
  118. except OverflowError:
  119. # overflow -> no match, which is already handled below
  120. pass
  121. if result.data.n == old_len:
  122. result.append(-1)
  123. elif result.data.n > old_len + 1:
  124. raise KeyError(
  125. 'indexer does not intersect a unique set of intervals')
  126. old_len = result.data.n
  127. return result.to_array().astype('intp')
  128. def get_indexer_non_unique(self, scalar_t[:] target):
  129. """Return the positions corresponding to intervals that overlap with
  130. the given array of scalar targets. Non-unique positions are repeated.
  131. """
  132. cdef:
  133. Py_ssize_t old_len
  134. Py_ssize_t i
  135. Int64Vector result, missing
  136. result = Int64Vector()
  137. missing = Int64Vector()
  138. old_len = 0
  139. for i in range(len(target)):
  140. try:
  141. self.root.query(result, target[i])
  142. except OverflowError:
  143. # overflow -> no match, which is already handled below
  144. pass
  145. if result.data.n == old_len:
  146. result.append(-1)
  147. missing.append(i)
  148. old_len = result.data.n
  149. return (result.to_array().astype('intp'),
  150. missing.to_array().astype('intp'))
  151. def __repr__(self) -> str:
  152. return ('<IntervalTree[{dtype},{closed}]: '
  153. '{n_elements} elements>'.format(
  154. dtype=self.dtype, closed=self.closed,
  155. n_elements=self.root.n_elements))
  156. # compat with IndexEngine interface
  157. def clear_mapping(self) -> None:
  158. pass
  159. cdef take(ndarray source, ndarray indices):
  160. """Take the given positions from a 1D ndarray
  161. """
  162. return PyArray_Take(source, indices, 0)
  163. cdef sort_values_and_indices(all_values, all_indices, subset):
  164. indices = take(all_indices, subset)
  165. values = take(all_values, subset)
  166. sorter = PyArray_ArgSort(values, 0, NPY_QUICKSORT)
  167. sorted_values = take(values, sorter)
  168. sorted_indices = take(indices, sorter)
  169. return sorted_values, sorted_indices
  170. # ----------------------------------------------------------------------
  171. # Nodes
  172. # ----------------------------------------------------------------------
  173. @cython.internal
  174. cdef class IntervalNode:
  175. cdef readonly:
  176. int64_t n_elements, n_center, leaf_size
  177. bint is_leaf_node
  178. def __repr__(self) -> str:
  179. if self.is_leaf_node:
  180. return (
  181. f"<{type(self).__name__}: {self.n_elements} elements (terminal)>"
  182. )
  183. else:
  184. n_left = self.left_node.n_elements
  185. n_right = self.right_node.n_elements
  186. n_center = self.n_elements - n_left - n_right
  187. return (
  188. f"<{type(self).__name__}: "
  189. f"pivot {self.pivot}, {self.n_elements} elements "
  190. f"({n_left} left, {n_right} right, {n_center} overlapping)>"
  191. )
  192. def counts(self):
  193. """
  194. Inspect counts on this node
  195. useful for debugging purposes
  196. """
  197. if self.is_leaf_node:
  198. return self.n_elements
  199. else:
  200. m = len(self.center_left_values)
  201. l = self.left_node.counts()
  202. r = self.right_node.counts()
  203. return (m, (l, r))
  204. # we need specialized nodes and leaves to optimize for different dtype and
  205. # closed values
  206. {{py:
  207. nodes = []
  208. for dtype in ['float64', 'int64', 'uint64']:
  209. for closed, cmp_left, cmp_right in [
  210. ('left', '<=', '<'),
  211. ('right', '<', '<='),
  212. ('both', '<=', '<='),
  213. ('neither', '<', '<')]:
  214. cmp_left_converse = '<' if cmp_left == '<=' else '<='
  215. cmp_right_converse = '<' if cmp_right == '<=' else '<='
  216. if dtype.startswith('int'):
  217. fused_prefix = 'int_'
  218. elif dtype.startswith('uint'):
  219. fused_prefix = 'uint_'
  220. elif dtype.startswith('float'):
  221. fused_prefix = ''
  222. nodes.append((dtype, dtype.title(),
  223. closed, closed.title(),
  224. cmp_left,
  225. cmp_right,
  226. cmp_left_converse,
  227. cmp_right_converse,
  228. fused_prefix))
  229. }}
  230. NODE_CLASSES = {}
  231. {{for dtype, dtype_title, closed, closed_title, cmp_left, cmp_right,
  232. cmp_left_converse, cmp_right_converse, fused_prefix in nodes}}
  233. @cython.internal
  234. cdef class {{dtype_title}}Closed{{closed_title}}IntervalNode(IntervalNode):
  235. """Non-terminal node for an IntervalTree
  236. Categorizes intervals by those that fall to the left, those that fall to
  237. the right, and those that overlap with the pivot.
  238. """
  239. cdef readonly:
  240. {{dtype_title}}Closed{{closed_title}}IntervalNode left_node, right_node
  241. {{dtype}}_t[:] center_left_values, center_right_values, left, right
  242. int64_t[:] center_left_indices, center_right_indices, indices
  243. {{dtype}}_t min_left, max_right
  244. {{dtype}}_t pivot
  245. def __init__(self,
  246. ndarray[{{dtype}}_t, ndim=1] left,
  247. ndarray[{{dtype}}_t, ndim=1] right,
  248. ndarray[int64_t, ndim=1] indices,
  249. int64_t leaf_size):
  250. self.n_elements = len(left)
  251. self.leaf_size = leaf_size
  252. # min_left and min_right are used to speed-up query by skipping
  253. # query on sub-nodes. If this node has size 0, query is cheap,
  254. # so these values don't matter.
  255. if left.size > 0:
  256. self.min_left = left.min()
  257. self.max_right = right.max()
  258. else:
  259. self.min_left = 0
  260. self.max_right = 0
  261. if self.n_elements <= leaf_size:
  262. # make this a terminal (leaf) node
  263. self.is_leaf_node = True
  264. self.left = left
  265. self.right = right
  266. self.indices = indices
  267. self.n_center = 0
  268. else:
  269. # calculate a pivot so we can create child nodes
  270. self.is_leaf_node = False
  271. self.pivot = np.median(left / 2 + right / 2)
  272. left_set, right_set, center_set = self.classify_intervals(
  273. left, right)
  274. self.left_node = self.new_child_node(left, right,
  275. indices, left_set)
  276. self.right_node = self.new_child_node(left, right,
  277. indices, right_set)
  278. self.center_left_values, self.center_left_indices = \
  279. sort_values_and_indices(left, indices, center_set)
  280. self.center_right_values, self.center_right_indices = \
  281. sort_values_and_indices(right, indices, center_set)
  282. self.n_center = len(self.center_left_indices)
  283. @cython.wraparound(False)
  284. @cython.boundscheck(False)
  285. cdef classify_intervals(self, {{dtype}}_t[:] left, {{dtype}}_t[:] right):
  286. """Classify the given intervals based upon whether they fall to the
  287. left, right, or overlap with this node's pivot.
  288. """
  289. cdef:
  290. Int64Vector left_ind, right_ind, overlapping_ind
  291. Py_ssize_t i
  292. left_ind = Int64Vector()
  293. right_ind = Int64Vector()
  294. overlapping_ind = Int64Vector()
  295. for i in range(self.n_elements):
  296. if right[i] {{cmp_right_converse}} self.pivot:
  297. left_ind.append(i)
  298. elif self.pivot {{cmp_left_converse}} left[i]:
  299. right_ind.append(i)
  300. else:
  301. overlapping_ind.append(i)
  302. return (left_ind.to_array(),
  303. right_ind.to_array(),
  304. overlapping_ind.to_array())
  305. cdef new_child_node(self,
  306. ndarray[{{dtype}}_t, ndim=1] left,
  307. ndarray[{{dtype}}_t, ndim=1] right,
  308. ndarray[int64_t, ndim=1] indices,
  309. ndarray[int64_t, ndim=1] subset):
  310. """Create a new child node.
  311. """
  312. left = take(left, subset)
  313. right = take(right, subset)
  314. indices = take(indices, subset)
  315. return {{dtype_title}}Closed{{closed_title}}IntervalNode(
  316. left, right, indices, self.leaf_size)
  317. @cython.wraparound(False)
  318. @cython.boundscheck(False)
  319. @cython.initializedcheck(False)
  320. cpdef query(self, Int64Vector result, {{fused_prefix}}scalar_t point):
  321. """Recursively query this node and its sub-nodes for intervals that
  322. overlap with the query point.
  323. """
  324. cdef:
  325. int64_t[:] indices
  326. {{dtype}}_t[:] values
  327. Py_ssize_t i
  328. if self.is_leaf_node:
  329. # Once we get down to a certain size, it doesn't make sense to
  330. # continue the binary tree structure. Instead, we use linear
  331. # search.
  332. for i in range(self.n_elements):
  333. if self.left[i] {{cmp_left}} point {{cmp_right}} self.right[i]:
  334. result.append(self.indices[i])
  335. else:
  336. # There are child nodes. Based on comparing our query to the pivot,
  337. # look at the center values, then go to the relevant child.
  338. if point < self.pivot:
  339. values = self.center_left_values
  340. indices = self.center_left_indices
  341. for i in range(self.n_center):
  342. if not values[i] {{cmp_left}} point:
  343. break
  344. result.append(indices[i])
  345. if point {{cmp_right}} self.left_node.max_right:
  346. self.left_node.query(result, point)
  347. elif point > self.pivot:
  348. values = self.center_right_values
  349. indices = self.center_right_indices
  350. for i in range(self.n_center - 1, -1, -1):
  351. if not point {{cmp_right}} values[i]:
  352. break
  353. result.append(indices[i])
  354. if self.right_node.min_left {{cmp_left}} point:
  355. self.right_node.query(result, point)
  356. else:
  357. result.extend(self.center_left_indices)
  358. NODE_CLASSES['{{dtype}}',
  359. '{{closed}}'] = {{dtype_title}}Closed{{closed_title}}IntervalNode
  360. {{endfor}}