OpenTTD
kdtree.hpp
Go to the documentation of this file.
1 /*
2  * This file is part of OpenTTD.
3  * OpenTTD is free software; you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, version 2.
4  * OpenTTD is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.
5  * See the GNU General Public License for more details. You should have received a copy of the GNU General Public License along with OpenTTD. If not, see <http://www.gnu.org/licenses/>.
6  */
7 
10 #ifndef KDTREE_HPP
11 #define KDTREE_HPP
12 
13 #include "../stdafx.h"
14 #include <vector>
15 #include <algorithm>
16 #include <limits>
17 
37 template <typename T, typename TxyFunc, typename CoordT, typename DistT>
38 class Kdtree {
40  struct node {
41  T element;
42  size_t left;
43  size_t right;
44 
45  node(T element) : element(element), left(INVALID_NODE), right(INVALID_NODE) { }
46  };
47 
48  static const size_t INVALID_NODE = SIZE_MAX;
49 
50  std::vector<node> nodes;
51  std::vector<size_t> free_list;
52  size_t root;
53  TxyFunc xyfunc;
54  size_t unbalanced;
55 
57  size_t AddNode(const T &element)
58  {
59  if (this->free_list.size() == 0) {
60  this->nodes.emplace_back(element);
61  return this->nodes.size() - 1;
62  } else {
63  size_t newidx = this->free_list.back();
64  this->free_list.pop_back();
65  this->nodes[newidx] = node{ element };
66  return newidx;
67  }
68  }
69 
71  template <typename It>
72  CoordT SelectSplitCoord(It begin, It end, int level)
73  {
74  It mid = begin + (end - begin) / 2;
75  std::nth_element(begin, mid, end, [&](T a, T b) { return this->xyfunc(a, level % 2) < this->xyfunc(b, level % 2); });
76  return this->xyfunc(*mid, level % 2);
77  }
78 
80  template <typename It>
81  size_t BuildSubtree(It begin, It end, int level)
82  {
83  ptrdiff_t count = end - begin;
84 
85  if (count == 0) {
86  return INVALID_NODE;
87  } else if (count == 1) {
88  return this->AddNode(*begin);
89  } else if (count > 1) {
90  CoordT split_coord = SelectSplitCoord(begin, end, level);
91  It split = std::partition(begin, end, [&](T v) { return this->xyfunc(v, level % 2) < split_coord; });
92  size_t newidx = this->AddNode(*split);
93  this->nodes[newidx].left = this->BuildSubtree(begin, split, level + 1);
94  this->nodes[newidx].right = this->BuildSubtree(split + 1, end, level + 1);
95  return newidx;
96  } else {
97  NOT_REACHED();
98  }
99  }
100 
102  bool Rebuild(const T *include_element, const T *exclude_element)
103  {
104  size_t initial_count = this->Count();
105  if (initial_count < 8) return false; // arbitrary value for "not worth rebalancing"
106 
107  T root_element = this->nodes[this->root].element;
108  std::vector<T> elements = this->FreeSubtree(this->root);
109  elements.push_back(root_element);
110 
111  if (include_element != nullptr) {
112  elements.push_back(*include_element);
113  initial_count++;
114  }
115  if (exclude_element != nullptr) {
116  typename std::vector<T>::iterator removed = std::remove(elements.begin(), elements.end(), *exclude_element);
117  elements.erase(removed, elements.end());
118  initial_count--;
119  }
120 
121  this->Build(elements.begin(), elements.end());
122  assert(initial_count == this->Count());
123  return true;
124  }
125 
127  void InsertRecursive(const T &element, size_t node_idx, int level)
128  {
129  /* Dimension index of current level */
130  int dim = level % 2;
131  /* Node reference */
132  node &n = this->nodes[node_idx];
133 
134  /* Coordinate of element splitting at this node */
135  CoordT nc = this->xyfunc(n.element, dim);
136  /* Coordinate of the new element */
137  CoordT ec = this->xyfunc(element, dim);
138  /* Which side to insert on */
139  size_t &next = (ec < nc) ? n.left : n.right;
140 
141  if (next == INVALID_NODE) {
142  /* New leaf */
143  size_t newidx = this->AddNode(element);
144  /* Vector may have been reallocated at this point, n and next are invalid */
145  node &nn = this->nodes[node_idx];
146  if (ec < nc) nn.left = newidx; else nn.right = newidx;
147  } else {
148  this->InsertRecursive(element, next, level + 1);
149  }
150  }
151 
156  std::vector<T> FreeSubtree(size_t node_idx)
157  {
158  std::vector<T> subtree_elements;
159  node &n = this->nodes[node_idx];
160 
161  /* We'll be appending items to the free_list, get index of our first item */
162  size_t first_free = this->free_list.size();
163  /* Prepare the descent with our children */
164  if (n.left != INVALID_NODE) this->free_list.push_back(n.left);
165  if (n.right != INVALID_NODE) this->free_list.push_back(n.right);
166  n.left = n.right = INVALID_NODE;
167 
168  /* Recursively free the nodes being collected */
169  for (size_t i = first_free; i < this->free_list.size(); i++) {
170  node &fn = this->nodes[this->free_list[i]];
171  subtree_elements.push_back(fn.element);
172  if (fn.left != INVALID_NODE) this->free_list.push_back(fn.left);
173  if (fn.right != INVALID_NODE) this->free_list.push_back(fn.right);
174  fn.left = fn.right = INVALID_NODE;
175  }
176 
177  return subtree_elements;
178  }
179 
187  size_t RemoveRecursive(const T &element, size_t node_idx, int level)
188  {
189  /* Node reference */
190  node &n = this->nodes[node_idx];
191 
192  if (n.element == element) {
193  /* Remove this one */
194  this->free_list.push_back(node_idx);
195  if (n.left == INVALID_NODE && n.right == INVALID_NODE) {
196  /* Simple case, leaf, new child node for parent is "none" */
197  return INVALID_NODE;
198  } else {
199  /* Complex case, rebuild the sub-tree */
200  std::vector<T> subtree_elements = this->FreeSubtree(node_idx);
201  return this->BuildSubtree(subtree_elements.begin(), subtree_elements.end(), level);;
202  }
203  } else {
204  /* Search in a sub-tree */
205  /* Dimension index of current level */
206  int dim = level % 2;
207  /* Coordinate of element splitting at this node */
208  CoordT nc = this->xyfunc(n.element, dim);
209  /* Coordinate of the element being removed */
210  CoordT ec = this->xyfunc(element, dim);
211  /* Which side to remove from */
212  size_t next = (ec < nc) ? n.left : n.right;
213  assert(next != INVALID_NODE); // node must exist somewhere and must be found before a leaf is reached
214  /* Descend */
215  size_t new_branch = this->RemoveRecursive(element, next, level + 1);
216  if (new_branch != next) {
217  /* Vector may have been reallocated at this point, n and next are invalid */
218  node &nn = this->nodes[node_idx];
219  if (ec < nc) nn.left = new_branch; else nn.right = new_branch;
220  }
221  return node_idx;
222  }
223  }
224 
225 
226  DistT ManhattanDistance(const T &element, CoordT x, CoordT y) const
227  {
228  return abs((DistT)this->xyfunc(element, 0) - (DistT)x) + abs((DistT)this->xyfunc(element, 1) - (DistT)y);
229  }
230 
232  using node_distance = std::pair<T, DistT>;
235  {
236  if (a.second < b.second) return a;
237  if (b.second < a.second) return b;
238  if (a.first < b.first) return a;
239  if (b.first < a.first) return b;
240  NOT_REACHED(); // a.first == b.first: same element must not be inserted twice
241  }
243  node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit = std::numeric_limits<DistT>::max()) const
244  {
245  /* Dimension index of current level */
246  int dim = level % 2;
247  /* Node reference */
248  const node &n = this->nodes[node_idx];
249 
250  /* Coordinate of element splitting at this node */
251  CoordT c = this->xyfunc(n.element, dim);
252  /* This node's distance to target */
253  DistT thisdist = ManhattanDistance(n.element, xy[0], xy[1]);
254  /* Assume this node is the best choice for now */
255  node_distance best = std::make_pair(n.element, thisdist);
256 
257  /* Next node to visit */
258  size_t next = (xy[dim] < c) ? n.left : n.right;
259  if (next != INVALID_NODE) {
260  /* Check if there is a better node down the tree */
261  best = SelectNearestNodeDistance(best, this->FindNearestRecursive(xy, next, level + 1));
262  }
263 
264  limit = min(best.second, limit);
265 
266  /* Check if the distance from current best is worse than distance from target to splitting line,
267  * if it is we also need to check the other side of the split. */
268  size_t opposite = (xy[dim] >= c) ? n.left : n.right; // reverse of above
269  if (opposite != INVALID_NODE && limit >= abs((int)xy[dim] - (int)c)) {
270  node_distance other_candidate = this->FindNearestRecursive(xy, opposite, level + 1, limit);
271  best = SelectNearestNodeDistance(best, other_candidate);
272  }
273 
274  return best;
275  }
276 
277  template <typename Outputter>
278  void FindContainedRecursive(CoordT p1[2], CoordT p2[2], size_t node_idx, int level, Outputter outputter) const
279  {
280  /* Dimension index of current level */
281  int dim = level % 2;
282  /* Node reference */
283  const node &n = this->nodes[node_idx];
284 
285  /* Coordinate of element splitting at this node */
286  CoordT ec = this->xyfunc(n.element, dim);
287  /* Opposite coordinate of element */
288  CoordT oc = this->xyfunc(n.element, 1 - dim);
289 
290  /* Test if this element is within rectangle */
291  if (ec >= p1[dim] && ec < p2[dim] && oc >= p1[1 - dim] && oc < p2[1 - dim]) outputter(n.element);
292 
293  /* Recurse left if part of rectangle is left of split */
294  if (p1[dim] < ec && n.left != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.left, level + 1, outputter);
295 
296  /* Recurse right if part of rectangle is right of split */
297  if (p2[dim] > ec && n.right != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.right, level + 1, outputter);
298  }
299 
301  size_t CountValue(const T &element, size_t node_idx) const
302  {
303  if (node_idx == INVALID_NODE) return 0;
304  const node &n = this->nodes[node_idx];
305  return CountValue(element, n.left) + CountValue(element, n.right) + ((n.element == element) ? 1 : 0);
306  }
307 
308  void IncrementUnbalanced(size_t amount = 1)
309  {
310  this->unbalanced += amount;
311  }
312 
315  {
316  size_t count = this->Count();
317  if (count < 8) return false;
318  return this->unbalanced > this->Count() / 4;
319  }
320 
322  void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
323  {
324  if (node_idx == INVALID_NODE) return;
325 
326  const node &n = this->nodes[node_idx];
327  CoordT cx = this->xyfunc(n.element, 0);
328  CoordT cy = this->xyfunc(n.element, 1);
329 
330  assert(cx >= min_x);
331  assert(cx < max_x);
332  assert(cy >= min_y);
333  assert(cy < max_y);
334 
335  if (level % 2 == 0) {
336  // split in dimension 0 = x
337  CheckInvariant(n.left, level + 1, min_x, cx, min_y, max_y);
338  CheckInvariant(n.right, level + 1, cx, max_x, min_y, max_y);
339  } else {
340  // split in dimension 1 = y
341  CheckInvariant(n.left, level + 1, min_x, max_x, min_y, cy);
342  CheckInvariant(n.right, level + 1, min_x, max_x, cy, max_y);
343  }
344  }
345 
348  {
349 #ifdef KDTREE_DEBUG
351 #endif
352  }
353 
354 public:
356  Kdtree(TxyFunc xyfunc) : root(INVALID_NODE), xyfunc(xyfunc), unbalanced(0) { }
357 
364  template <typename It>
365  void Build(It begin, It end)
366  {
367  this->nodes.clear();
368  this->free_list.clear();
369  this->unbalanced = 0;
370  if (begin == end) return;
371  this->nodes.reserve(end - begin);
372 
373  this->root = this->BuildSubtree(begin, end, 0);
374  CheckInvariant();
375  }
376 
380  void Clear()
381  {
382  this->nodes.clear();
383  this->free_list.clear();
384  this->unbalanced = 0;
385  return;
386  }
387 
391  void Rebuild()
392  {
393  this->Rebuild(nullptr, nullptr);
394  }
395 
401  void Insert(const T &element)
402  {
403  if (this->Count() == 0) {
404  this->root = this->AddNode(element);
405  } else {
406  if (!this->IsUnbalanced() || !this->Rebuild(&element, nullptr)) {
407  this->InsertRecursive(element, this->root, 0);
408  this->IncrementUnbalanced();
409  }
410  CheckInvariant();
411  }
412  }
413 
420  void Remove(const T &element)
421  {
422  size_t count = this->Count();
423  if (count == 0) return;
424  if (!this->IsUnbalanced() || !this->Rebuild(nullptr, &element)) {
425  /* If the removed element is the root node, this modifies this->root */
426  this->root = this->RemoveRecursive(element, this->root, 0);
427  this->IncrementUnbalanced();
428  }
429  CheckInvariant();
430  }
431 
433  size_t Count() const
434  {
435  assert(this->free_list.size() <= this->nodes.size());
436  return this->nodes.size() - this->free_list.size();
437  }
438 
444  T FindNearest(CoordT x, CoordT y) const
445  {
446  assert(this->Count() > 0);
447 
448  CoordT xy[2] = { x, y };
449  return this->FindNearestRecursive(xy, this->root, 0).first;
450  }
451 
461  template <typename Outputter>
462  void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter) const
463  {
464  assert(x1 < x2);
465  assert(y1 < y2);
466 
467  if (this->Count() == 0) return;
468 
469  CoordT p1[2] = { x1, y1 };
470  CoordT p2[2] = { x2, y2 };
471  this->FindContainedRecursive(p1, p2, this->root, 0, outputter);
472  }
473 
478  std::vector<T> FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
479  {
480  std::vector<T> result;
481  this->FindContained(x1, y1, x2, y2, [&result](T e) {result.push_back(e); });
482  return result;
483  }
484 };
485 
486 #endif
std::vector< T > FreeSubtree(size_t node_idx)
Free all children of the given node.
Definition: kdtree.hpp:156
void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
Verify that the invariant is true for a sub-tree, assert if not.
Definition: kdtree.hpp:322
std::vector< node > nodes
Pool of all nodes in the tree.
Definition: kdtree.hpp:50
bool Rebuild(const T *include_element, const T *exclude_element)
Rebuild the tree with all existing elements, optionally adding or removing one more.
Definition: kdtree.hpp:102
std::vector< size_t > free_list
List of dead indices in the nodes vector.
Definition: kdtree.hpp:51
static node_distance SelectNearestNodeDistance(const node_distance &a, const node_distance &b)
Ordering function for node_distance objects, elements with equal distance are ordered by less-than comparison.
Definition: kdtree.hpp:234
size_t RemoveRecursive(const T &element, size_t node_idx, int level)
Find and remove one element from the tree.
Definition: kdtree.hpp:187
void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter) const
Find all items contained within the given rectangle.
Definition: kdtree.hpp:462
T FindNearest(CoordT x, CoordT y) const
Find the element closest to given coordinate, in Manhattan distance.
Definition: kdtree.hpp:444
size_t BuildSubtree(It begin, It end, int level)
Construct a subtree from elements between begin and end iterators, return index of root.
Definition: kdtree.hpp:81
static T max(const T a, const T b)
Returns the maximum of two values.
Definition: math_func.hpp:24
bool IsUnbalanced()
Check if the entire tree is in need of rebuilding.
Definition: kdtree.hpp:314
std::pair< T, DistT > node_distance
A data element and its distance to a searched-for point.
Definition: kdtree.hpp:232
size_t left
Index of node to the left, INVALID_NODE if none.
Definition: kdtree.hpp:42
size_t Count() const
Get number of elements stored in tree.
Definition: kdtree.hpp:433
size_t AddNode(const T &element)
Create one new node in the tree, return its index in the pool.
Definition: kdtree.hpp:57
size_t unbalanced
Number approximating how unbalanced the tree might be.
Definition: kdtree.hpp:54
CoordT SelectSplitCoord(It begin, It end, int level)
Find a coordinate value to split a range of elements at.
Definition: kdtree.hpp:72
static T min(const T a, const T b)
Returns the minimum of two values.
Definition: math_func.hpp:40
size_t right
Index of node to the right, INVALID_NODE if none.
Definition: kdtree.hpp:43
static const size_t INVALID_NODE
Index value indicating no-such-node.
Definition: kdtree.hpp:48
std::vector< T > FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
Find all items contained within the given rectangle.
Definition: kdtree.hpp:478
T element
Element stored at node.
Definition: kdtree.hpp:41
void Build(It begin, It end)
Clear and rebuild the tree from a new sequence of elements.
Definition: kdtree.hpp:365
void Remove(const T &element)
Remove a single element from the tree, if it exists.
Definition: kdtree.hpp:420
Kdtree(TxyFunc xyfunc)
Construct a new Kdtree with the given xyfunc.
Definition: kdtree.hpp:356
size_t root
Index of root node.
Definition: kdtree.hpp:52
static T abs(const T a)
Returns the absolute value of (scalar) variable.
Definition: math_func.hpp:81
node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit=std::numeric_limits< DistT >::max()) const
Search a sub-tree for the element nearest to a given point.
Definition: kdtree.hpp:243
K-dimensional tree, specialised for 2-dimensional space.
Definition: kdtree.hpp:38
void CheckInvariant()
Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined.
Definition: kdtree.hpp:347
void Clear()
Clear the tree.
Definition: kdtree.hpp:380
size_t CountValue(const T &element, size_t node_idx) const
Debugging function, counts number of occurrences of an element regardless of its correct position in the tree.
Definition: kdtree.hpp:301
void InsertRecursive(const T &element, size_t node_idx, int level)
Insert one element in the tree somewhere below node_idx.
Definition: kdtree.hpp:127
Type of a node in the tree.
Definition: kdtree.hpp:40
void Rebuild()
Reconstruct the tree with the same elements, letting it be fully balanced.
Definition: kdtree.hpp:391
TxyFunc xyfunc
Functor to extract a coordinate from an element.
Definition: kdtree.hpp:53
void Insert(const T &element)
Insert a single element in the tree.
Definition: kdtree.hpp:401