Skip to content

Commit 3c1d19b

Browse files
committed
Fix kdtree.IndexOf, add kNearestIndex
Bugfixes, indexof was flawed. Add kNearestIndex to return indices for .data rather than the element itself.
1 parent 8140cac commit 3c1d19b

2 files changed

Lines changed: 175 additions & 26 deletions

File tree

Source/script/imports/simba.import_kdtree.pas

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ function TKDTree.Copy(): TKDTree;
5656
procedure TKDTree.Init(const AData: TKDItems);
5757
function TKDTree.IndexOf(const Value: TSingleArray): Integer;
5858
function TKDTree.KNearest(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TKDItems;
59+
function TKDTree.KNearestIndex(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
5960
function TKDTree.RangeQuery(Low, High: TSingleArray): TKDItems;
6061
function TKDTree.RangeQueryEx(Center: TSingleArray; Radii: TSingleArray; Hide: Boolean): TKDItems;
6162
function TKDTree.KNearestClassify(Vector: TSingleArray; K: Integer): Integer;
@@ -132,6 +133,20 @@ procedure _LapeKDTreeKNearest(const Params: PParamArray; const Result: Pointer);
132133
TKDItems(Result^) := TKDTree(Params^[0]^).KNearest(TSingleArray(Params^[1]^), Integer(Params^[2]^), Boolean(Params^[3]^));
133134
end;
134135

136+
(*
137+
TKDTree.KNearestIndex
138+
---------------------
139+
```
140+
function TKDTree.KNearestIndex(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
141+
```
142+
143+
Returns an array that with indices that can be used to access the kdtree data directly if needed. As an alterantive to getting the vector itself.
144+
*)
145+
procedure _LapeKDTreeKNearestIndex(const Params: PParamArray; const Result: Pointer); LAPE_WRAPPER_CALLING_CONV
146+
begin
147+
TIntegerArray(Result^) := TKDTree(Params^[0]^).KNearestIndex(TSingleArray(Params^[1]^), Integer(Params^[2]^), Boolean(Params^[3]^));
148+
end;
149+
135150
procedure _LapeKDTreeRangeQuery(const Params: PParamArray; const Result: Pointer); LAPE_WRAPPER_CALLING_CONV
136151
begin
137152
TKDItems(Result^) := TKDTree(Params^[0]^).RangeQuery(TSingleArray(Params^[1]^), TSingleArray(Params^[2]^));
@@ -223,6 +238,7 @@ procedure ImportKDTree(Script: TSimbaScript);
223238
addGlobalFunc('function TKDTree.Copy(): TKDTree;', @_LapeKDTreeCopy);
224239
addGlobalFunc('function TKDTree.IndexOf(const Value: TSingleArray): Integer;', @_LapeKDTreeIndexOf);
225240
addGlobalFunc('function TKDTree.KNearest(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TKDItems;', @_LapeKDTreeKNearest);
241+
addGlobalFunc('function TKDTree.KNearestIndex(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;', @_LapeKDTreeKNearestIndex);
226242
addGlobalFunc('function TKDTree.RangeQuery(Low, High: TSingleArray): TKDItems;', @_LapeKDTreeRangeQuery);
227243
addGlobalFunc('function TKDTree.RangeQueryEx(Center: TSingleArray; Radii: TSingleArray; Hide: Boolean): TKDItems;', @_LapeKDTreeRangeQueryEx);
228244
addGlobalFunc('function TKDTree.KNearestClassify(Vector: TSingleArray; K: Integer): Integer;', @_LapeKDTreeKNearestClassify);

Source/simba.container_kdtree.pas

Lines changed: 159 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ TKDTree = record
5656
function SqDistance(A, B: TSingleArray; Limit: Single = High(UInt32)): Single; inline;
5757
function IndexOf(const Value: TSingleArray): Integer;
5858
function KNearest(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TKDItems;
59+
function KNearestIndex(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
5960
function RangeQuery(Low, High: TSingleArray): TKDItems;
6061
function RangeQueryEx(Center: TSingleArray; Radii: TSingleArray; Hide: Boolean): TKDItems;
6162
function KNearestClassify(Vector: TSingleArray; K: Integer): Integer;
@@ -74,6 +75,13 @@ implementation
7475
const
7576
NONE = -1;
7677

78+
function FloatEqual(const A, B: Single): Boolean;
79+
const
80+
EPS = 1e-6;
81+
begin
82+
Result := Abs(A - B) <= EPS * Max(1.0, Max(Abs(A), Abs(B)));
83+
end;
84+
7785

7886
(*
7987
Quick select for the KDTree build process
@@ -275,46 +283,56 @@ function TKDTree.SaveToFile(FileName: string): Boolean;
275283
(*
276284
Returns an Index that can be used to access TKDTree.Data directly.
277285
278-
Average Time Complexity: O(log n)
279-
280-
Note:
281-
The time complexity is typically closer to O(log n) when K is significantly smaller than n.
286+
Average Time Complexity: O(log n), worst case O(n)
282287
*)
283288
function TKDTree.IndexOf(const Value: TSingleArray): Integer;
284-
var
285-
Node: Integer;
286-
Depth: UInt8;
287-
Axis: Integer;
288-
i: Integer;
289-
Found: Boolean;
290-
begin
291-
Node := 0;
292-
Depth := 0;
293-
Result := NONE;
294289

295-
while Node <> NONE do
290+
const
291+
EPS = 0.000001;
292+
293+
function FloatEqual(const A, B: Single): Boolean;
294+
begin
295+
Result := Abs(A - B) <= EPS;
296+
end;
297+
298+
function FindNode(Node: Integer; Depth: Integer): Integer;
299+
var
300+
Axis, i: Integer;
301+
SplitVal, Val: Single;
302+
Equal: Boolean;
296303
begin
297-
Found := True;
304+
if Node = NONE then
305+
Exit(NONE);
306+
307+
Equal := True;
298308
for i := 0 to Self.Dimensions - 1 do
299-
begin
300-
if Self.Data[Node].Split.Vector[i] <> Value[i] then
309+
if not FloatEqual(Self.Data[Node].Split.Vector[i], Value[i]) then
301310
begin
302-
Found := False;
311+
Equal := False;
303312
Break;
304313
end;
305-
end;
306314

307-
if Found then
315+
if Equal then
308316
Exit(Node);
309317

310318
Axis := Depth mod Self.Dimensions;
311-
if Value[Axis] < Self.Data[Node].Split.Vector[Axis] then
312-
Node := Self.Data[Node].L
313-
else
314-
Node := Self.Data[Node].R;
319+
SplitVal := Self.Data[Node].Split.Vector[Axis];
320+
Val := Value[Axis];
315321

316-
Inc(Depth);
322+
if FloatEqual(Val, SplitVal) then
323+
begin
324+
Result := FindNode(Self.Data[Node].L, Depth + 1);
325+
if Result = NONE then
326+
Result := FindNode(Self.Data[Node].R, Depth + 1);
327+
end
328+
else if Val < SplitVal then
329+
Result := FindNode(Self.Data[Node].L, Depth + 1)
330+
else
331+
Result := FindNode(Self.Data[Node].R, Depth + 1);
317332
end;
333+
334+
begin
335+
Result := FindNode(0, 0);
318336
end;
319337

320338

@@ -444,6 +462,121 @@ TNearestItem = record
444462
end;
445463
end;
446464

465+
// same as the above, but returns an array of indices.
466+
function TKDTree.KNearestIndex(Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
467+
type
468+
TNearestItem = record
469+
Node: Integer;
470+
DistSq: Single;
471+
end;
472+
TNearestHeap = array of TNearestItem;
473+
474+
var
475+
Heap: TNearestHeap;
476+
477+
procedure Heapify(Index: Integer);
478+
var
479+
Largest: Integer;
480+
Left, Right: Integer;
481+
Temp: TNearestItem;
482+
begin
483+
Largest := Index;
484+
Left := 2 * Index + 1;
485+
Right := 2 * Index + 2;
486+
487+
if (Left < Length(Heap)) and (Heap[Left].DistSq > Heap[Largest].DistSq) then
488+
Largest := Left;
489+
490+
if (Right < Length(Heap)) and (Heap[Right].DistSq > Heap[Largest].DistSq) then
491+
Largest := Right;
492+
493+
if Largest <> Index then
494+
begin
495+
Temp := Heap[Index];
496+
Heap[Index] := Heap[Largest];
497+
Heap[Largest] := Temp;
498+
Heapify(Largest);
499+
end;
500+
end;
501+
502+
procedure FindKNearest(Node: Integer; Depth: UInt8);
503+
var
504+
Delta, DistSq: Single;
505+
Test: Integer;
506+
This: PKDNode;
507+
Axis: Integer;
508+
Temp: TNearestItem;
509+
I: Integer;
510+
begin
511+
if Node = NONE then Exit;
512+
513+
This := @Self.Data[Node];
514+
Axis := Depth mod Self.Dimensions;
515+
516+
Delta := This^.Split.Vector[Axis] - Vector[Axis];
517+
518+
if Length(Heap) < K then
519+
DistSq := Self.SqDistance(This^.Split.Vector, Vector, High(UInt32)) // No limit if heap is not full
520+
else
521+
DistSq := Self.SqDistance(This^.Split.Vector, Vector, Heap[0].DistSq); // Limit is the furthest distance in the heap
522+
523+
if not((DistSq = 0) and NotEqual) then
524+
begin
525+
if Length(Heap) < K then
526+
begin
527+
// heap not full, add current node
528+
SetLength(Heap, Length(Heap) + 1);
529+
Heap[High(Heap)].Node := Node;
530+
Heap[High(Heap)].DistSq := DistSq;
531+
532+
// heapify upwards
533+
i := High(Heap);
534+
while (i > 0) and (Heap[(i - 1) div 2].DistSq < Heap[i].DistSq) do
535+
begin
536+
Temp := Heap[i];
537+
Heap[i] := Heap[(i - 1) div 2];
538+
Heap[(i - 1) div 2] := Temp;
539+
i := (i - 1) div 2;
540+
end;
541+
end
542+
else
543+
begin
544+
if DistSq < Heap[0].DistSq then
545+
begin
546+
// replace the furthest node with the current node
547+
Heap[0].Node := Node;
548+
Heap[0].DistSq := DistSq;
549+
// heapify downwards
550+
Heapify(0);
551+
end;
552+
end;
553+
end;
554+
555+
if Delta > 0 then Test := This^.L else Test := This^.R;
556+
FindKNearest(Test, Depth + 1);
557+
558+
if (Length(Heap) < K) or (Sqr(Delta) < Heap[0].DistSq) then
559+
begin
560+
if Delta > 0 then Test := This^.R else Test := This^.L;
561+
FindKNearest(Test, Depth + 1);
562+
end;
563+
end;
564+
565+
var
566+
i,j: Integer;
567+
begin
568+
SetLength(Heap, 0);
569+
FindKNearest(0, 0);
570+
571+
SetLength(Result, Length(Heap));
572+
j := 0;
573+
for i := High(Heap) downto 0 do
574+
begin
575+
Result[j] := Heap[i].Node;
576+
Inc(j);
577+
end;
578+
end;
579+
447580

448581
(*
449582
Returns the vectors that within the hyperrectangle defined by low and high vectors.

0 commit comments

Comments
 (0)