13 #include "../stdafx.h" 37 template <
typename T,
typename TxyFunc,
typename CoordT,
typename DistT>
59 if (this->free_list.size() == 0) {
60 this->nodes.emplace_back(element);
61 return this->nodes.size() - 1;
63 size_t newidx = this->free_list.back();
64 this->free_list.pop_back();
65 this->nodes[newidx] =
node{ element };
71 template <
typename It>
74 It mid = begin + (end - begin) / 2;
75 std::nth_element(begin, mid, end, [&](T a, T b) {
return this->
xyfunc(a, level % 2) < this->
xyfunc(b, level % 2); });
76 return this->
xyfunc(*mid, level % 2);
80 template <
typename It>
83 ptrdiff_t count = end - begin;
87 }
else if (count == 1) {
89 }
else if (count > 1) {
91 It split = std::partition(begin, end, [&](T v) {
return this->
xyfunc(v, level % 2) < split_coord; });
92 size_t newidx = this->
AddNode(*split);
93 this->nodes[newidx].left = this->
BuildSubtree(begin, split, level + 1);
94 this->nodes[newidx].right = this->
BuildSubtree(split + 1, end, level + 1);
102 bool Rebuild(
const T *include_element,
const T *exclude_element)
104 size_t initial_count = this->
Count();
105 if (initial_count < 8)
return false;
107 T root_element = this->nodes[this->
root].element;
108 std::vector<T> elements = this->
FreeSubtree(this->root);
109 elements.push_back(root_element);
111 if (include_element !=
nullptr) {
112 elements.push_back(*include_element);
115 if (exclude_element !=
nullptr) {
116 typename std::vector<T>::iterator removed = std::remove(elements.begin(), elements.end(), *exclude_element);
117 elements.erase(removed, elements.end());
121 this->
Build(elements.begin(), elements.end());
122 assert(initial_count == this->
Count());
132 node &n = this->nodes[node_idx];
137 CoordT ec = this->
xyfunc(element, dim);
139 size_t &next = (ec < nc) ? n.
left : n.
right;
141 if (next == INVALID_NODE) {
143 size_t newidx = this->
AddNode(element);
145 node &nn = this->nodes[node_idx];
146 if (ec < nc) nn.
left = newidx;
else nn.
right = newidx;
158 std::vector<T> subtree_elements;
159 node &n = this->nodes[node_idx];
162 size_t first_free = this->free_list.size();
164 if (n.
left != INVALID_NODE) this->free_list.push_back(n.
left);
165 if (n.
right != INVALID_NODE) this->free_list.push_back(n.
right);
169 for (
size_t i = first_free; i < this->free_list.size(); i++) {
170 node &fn = this->nodes[this->free_list[i]];
171 subtree_elements.push_back(fn.
element);
172 if (fn.
left != INVALID_NODE) this->free_list.push_back(fn.
left);
173 if (fn.
right != INVALID_NODE) this->free_list.push_back(fn.
right);
177 return subtree_elements;
190 node &n = this->nodes[node_idx];
194 this->free_list.push_back(node_idx);
195 if (n.
left == INVALID_NODE && n.
right == INVALID_NODE) {
200 std::vector<T> subtree_elements = this->
FreeSubtree(node_idx);
201 return this->
BuildSubtree(subtree_elements.begin(), subtree_elements.end(), level);;
210 CoordT ec = this->
xyfunc(element, dim);
212 size_t next = (ec < nc) ? n.
left : n.
right;
213 assert(next != INVALID_NODE);
216 if (new_branch != next) {
218 node &nn = this->nodes[node_idx];
219 if (ec < nc) nn.
left = new_branch;
else nn.
right = new_branch;
226 DistT ManhattanDistance(
const T &
element, CoordT x, CoordT y)
const 228 return abs((DistT)this->
xyfunc(element, 0) - (DistT)x) +
abs((DistT)this->
xyfunc(element, 1) - (DistT)y);
236 if (a.second < b.second)
return a;
237 if (b.second < a.second)
return b;
238 if (a.first < b.first)
return a;
239 if (b.first < a.first)
return b;
248 const node &n = this->nodes[node_idx];
253 DistT thisdist = ManhattanDistance(n.
element, xy[0], xy[1]);
258 size_t next = (xy[dim] < c) ? n.
left : n.
right;
259 if (next != INVALID_NODE) {
264 limit =
min(best.second, limit);
268 size_t opposite = (xy[dim] >= c) ? n.
left : n.
right;
269 if (opposite != INVALID_NODE && limit >=
abs((
int)xy[dim] - (
int)c)) {
277 template <
typename Outputter>
278 void FindContainedRecursive(CoordT p1[2], CoordT p2[2],
size_t node_idx,
int level, Outputter outputter)
const 283 const node &n = this->nodes[node_idx];
291 if (ec >= p1[dim] && ec < p2[dim] && oc >= p1[1 - dim] && oc < p2[1 - dim]) outputter(n.
element);
294 if (p1[dim] < ec && n.
left != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.
left, level + 1, outputter);
297 if (p2[dim] > ec && n.
right != INVALID_NODE) this->FindContainedRecursive(p1, p2, n.
right, level + 1, outputter);
303 if (node_idx == INVALID_NODE)
return 0;
304 const node &n = this->nodes[node_idx];
308 void IncrementUnbalanced(
size_t amount = 1)
310 this->unbalanced += amount;
316 size_t count = this->
Count();
317 if (count < 8)
return false;
318 return this->unbalanced > this->
Count() / 4;
322 void CheckInvariant(
size_t node_idx,
int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
324 if (node_idx == INVALID_NODE)
return;
326 const node &n = this->nodes[node_idx];
335 if (level % 2 == 0) {
356 Kdtree(TxyFunc xyfunc) : root(INVALID_NODE), xyfunc(xyfunc), unbalanced(0) { }
364 template <
typename It>
368 this->free_list.clear();
369 this->unbalanced = 0;
370 if (begin == end)
return;
371 this->nodes.reserve(end - begin);
383 this->free_list.clear();
384 this->unbalanced = 0;
393 this->
Rebuild(
nullptr,
nullptr);
403 if (this->
Count() == 0) {
404 this->root = this->
AddNode(element);
408 this->IncrementUnbalanced();
422 size_t count = this->
Count();
423 if (count == 0)
return;
427 this->IncrementUnbalanced();
435 assert(this->free_list.size() <= this->nodes.size());
436 return this->nodes.size() - this->free_list.size();
446 assert(this->
Count() > 0);
448 CoordT xy[2] = { x, y };
461 template <
typename Outputter>
462 void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter)
const 467 if (this->
Count() == 0)
return;
469 CoordT p1[2] = { x1, y1 };
470 CoordT p2[2] = { x2, y2 };
471 this->FindContainedRecursive(p1, p2, this->root, 0, outputter);
478 std::vector<T>
FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2)
const 480 std::vector<T> result;
481 this->
FindContained(x1, y1, x2, y2, [&result](T e) {result.push_back(e); });
std::vector< T > FreeSubtree(size_t node_idx)
Free all children of the given node, returning the elements they held.
void CheckInvariant(size_t node_idx, int level, CoordT min_x, CoordT max_x, CoordT min_y, CoordT max_y)
Verify that the invariant is true for a sub-tree, assert if not.
std::vector< node > nodes
Pool of all nodes in the tree.
bool Rebuild(const T *include_element, const T *exclude_element)
Rebuild the tree with all existing elements, optionally adding or removing one more.
std::vector< size_t > free_list
List of dead indices in the nodes vector.
static node_distance SelectNearestNodeDistance(const node_distance &a, const node_distance &b)
Ordering function for node_distance objects; elements with equal distance are ordered by less-than comparison of the elements.
size_t RemoveRecursive(const T &element, size_t node_idx, int level)
Find and remove one element from the tree.
void FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2, Outputter outputter) const
Find all items contained within the given rectangle.
T FindNearest(CoordT x, CoordT y) const
Find the element closest to the given coordinate, in Manhattan distance.
size_t BuildSubtree(It begin, It end, int level)
Construct a subtree from elements between begin and end iterators, return index of root node of the subtree.
static T max(const T a, const T b)
Returns the maximum of two values.
bool IsUnbalanced()
Check if the entire tree is in need of rebuilding.
std::pair< T, DistT > node_distance
A data element and its distance to a searched-for point.
size_t left
Index of node to the left, INVALID_NODE if none.
size_t Count() const
Get number of elements stored in tree.
size_t AddNode(const T &element)
Create one new node in the tree, return its index in the pool.
size_t unbalanced
Number approximating how unbalanced the tree might be.
CoordT SelectSplitCoord(It begin, It end, int level)
Find a coordinate value to split a range of elements at.
static T min(const T a, const T b)
Returns the minimum of two values.
size_t right
Index of node to the right, INVALID_NODE if none.
static const size_t INVALID_NODE
Index value indicating no-such-node.
std::vector< T > FindContained(CoordT x1, CoordT y1, CoordT x2, CoordT y2) const
Find all items contained within the given rectangle.
T element
Element stored at node.
void Build(It begin, It end)
Clear and rebuild the tree from a new sequence of elements.
void Remove(const T &element)
Remove a single element from the tree, if it exists.
Kdtree(TxyFunc xyfunc)
Construct a new Kdtree with the given xyfunc.
size_t root
Index of root node.
static T abs(const T a)
Returns the absolute value of a (scalar) variable.
node_distance FindNearestRecursive(CoordT xy[2], size_t node_idx, int level, DistT limit=std::numeric_limits< DistT >::max()) const
Search a sub-tree for the element nearest to a given point.
K-dimensional tree, specialised for 2-dimensional space.
void CheckInvariant()
Verify the invariant for the entire tree, does nothing unless KDTREE_DEBUG is defined.
void Clear()
Clear the tree.
size_t CountValue(const T &element, size_t node_idx) const
Debugging function, counts number of occurrences of an element regardless of its correct position in the tree.
void InsertRecursive(const T &element, size_t node_idx, int level)
Insert one element in the tree somewhere below node_idx.
Type of a node in the tree.
void Rebuild()
Reconstruct the tree with the same elements, letting it be fully balanced.
TxyFunc xyfunc
Functor to extract a coordinate from an element.
void Insert(const T &element)
Insert a single element in the tree.