@@ -56,6 +56,7 @@ TKDTree = record
5656 function SqDistance (A, B: TSingleArray; Limit: Single = High(UInt32)): Single; inline;
5757 function IndexOf (const Value : TSingleArray): Integer;
5858 function KNearest (Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TKDItems;
59+ function KNearestIndex (Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
5960 function RangeQuery (Low, High: TSingleArray): TKDItems;
6061 function RangeQueryEx (Center: TSingleArray; Radii: TSingleArray; Hide: Boolean): TKDItems;
6162 function KNearestClassify (Vector: TSingleArray; K: Integer): Integer;
@@ -74,6 +75,13 @@ implementation
7475const
7576 NONE = -1 ;
7677
78+ function FloatEqual (const A, B: Single): Boolean;
79+ const
80+ EPS = 1e-6 ;
81+ begin
82+ Result := Abs(A - B) <= EPS * Max(1.0 , Max(Abs(A), Abs(B)));
83+ end ;
84+
7785
7886(*
7987 Quick select for the KDTree build process
@@ -275,46 +283,56 @@ function TKDTree.SaveToFile(FileName: string): Boolean;
275283(*
276284 Returns an Index that can be used to access TKDTree.Data directly.
277285
278- Average Time Complexity: O(log n)
279-
280- Note:
281- The time complexity is typically closer to O(log n) when K is significantly smaller than n.
286+ Average Time Complexity: O(log n), worst case O(n)
282287*)
283288function TKDTree.IndexOf (const Value : TSingleArray): Integer;
284- var
285- Node: Integer;
286- Depth: UInt8;
287- Axis: Integer;
288- i: Integer;
289- Found: Boolean;
290- begin
291- Node := 0 ;
292- Depth := 0 ;
293- Result := NONE;
294289
295- while Node <> NONE do
290+ const
291+ EPS = 0.000001 ;
292+
293+ function FloatEqual (const A, B: Single): Boolean;
294+ begin
295+ Result := Abs(A - B) <= EPS;
296+ end ;
297+
298+ function FindNode (Node: Integer; Depth: Integer): Integer;
299+ var
300+ Axis, i: Integer;
301+ SplitVal, Val: Single;
302+ Equal: Boolean;
296303 begin
297- Found := True;
304+ if Node = NONE then
305+ Exit(NONE);
306+
307+ Equal := True;
298308 for i := 0 to Self.Dimensions - 1 do
299- begin
300- if Self.Data[Node].Split.Vector[i] <> Value [i] then
309+ if not FloatEqual(Self.Data[Node].Split.Vector[i], Value [i]) then
301310 begin
302- Found := False;
311+ Equal := False;
303312 Break;
304313 end ;
305- end ;
306314
307- if Found then
315+ if Equal then
308316 Exit(Node);
309317
310318 Axis := Depth mod Self.Dimensions;
311- if Value [Axis] < Self.Data[Node].Split.Vector[Axis] then
312- Node := Self.Data[Node].L
313- else
314- Node := Self.Data[Node].R;
319+ SplitVal := Self.Data[Node].Split.Vector[Axis];
320+ Val := Value [Axis];
315321
316- Inc(Depth);
322+ if FloatEqual(Val, SplitVal) then
323+ begin
324+ Result := FindNode(Self.Data[Node].L, Depth + 1 );
325+ if Result = NONE then
326+ Result := FindNode(Self.Data[Node].R, Depth + 1 );
327+ end
328+ else if Val < SplitVal then
329+ Result := FindNode(Self.Data[Node].L, Depth + 1 )
330+ else
331+ Result := FindNode(Self.Data[Node].R, Depth + 1 );
317332 end ;
333+
334+ begin
335+ Result := FindNode(0 , 0 );
318336end ;
319337
320338
@@ -444,6 +462,121 @@ TNearestItem = record
444462 end ;
445463end ;
446464
465+ // same as the above, but returns an array of indices.
466+ function TKDTree.KNearestIndex (Vector: TSingleArray; K: Integer; NotEqual: Boolean = False): TIntegerArray;
467+ type
468+ TNearestItem = record
469+ Node: Integer;
470+ DistSq: Single;
471+ end ;
472+ TNearestHeap = array of TNearestItem;
473+
474+ var
475+ Heap: TNearestHeap;
476+
477+ procedure Heapify (Index: Integer);
478+ var
479+ Largest: Integer;
480+ Left, Right: Integer;
481+ Temp: TNearestItem;
482+ begin
483+ Largest := Index;
484+ Left := 2 * Index + 1 ;
485+ Right := 2 * Index + 2 ;
486+
487+ if (Left < Length(Heap)) and (Heap[Left].DistSq > Heap[Largest].DistSq) then
488+ Largest := Left;
489+
490+ if (Right < Length(Heap)) and (Heap[Right].DistSq > Heap[Largest].DistSq) then
491+ Largest := Right;
492+
493+ if Largest <> Index then
494+ begin
495+ Temp := Heap[Index];
496+ Heap[Index] := Heap[Largest];
497+ Heap[Largest] := Temp;
498+ Heapify(Largest);
499+ end ;
500+ end ;
501+
502+ procedure FindKNearest (Node: Integer; Depth: UInt8);
503+ var
504+ Delta, DistSq: Single;
505+ Test: Integer;
506+ This: PKDNode;
507+ Axis: Integer;
508+ Temp: TNearestItem;
509+ I: Integer;
510+ begin
511+ if Node = NONE then Exit;
512+
513+ This := @Self.Data[Node];
514+ Axis := Depth mod Self.Dimensions;
515+
516+ Delta := This^.Split.Vector[Axis] - Vector[Axis];
517+
518+ if Length(Heap) < K then
519+ DistSq := Self.SqDistance(This^.Split.Vector, Vector, High(UInt32)) // No limit if heap is not full
520+ else
521+ DistSq := Self.SqDistance(This^.Split.Vector, Vector, Heap[0 ].DistSq); // Limit is the furthest distance in the heap
522+
523+ if not ((DistSq = 0 ) and NotEqual) then
524+ begin
525+ if Length(Heap) < K then
526+ begin
527+ // heap not full, add current node
528+ SetLength(Heap, Length(Heap) + 1 );
529+ Heap[High(Heap)].Node := Node;
530+ Heap[High(Heap)].DistSq := DistSq;
531+
532+ // heapify upwards
533+ i := High(Heap);
534+ while (i > 0 ) and (Heap[(i - 1 ) div 2 ].DistSq < Heap[i].DistSq) do
535+ begin
536+ Temp := Heap[i];
537+ Heap[i] := Heap[(i - 1 ) div 2 ];
538+ Heap[(i - 1 ) div 2 ] := Temp;
539+ i := (i - 1 ) div 2 ;
540+ end ;
541+ end
542+ else
543+ begin
544+ if DistSq < Heap[0 ].DistSq then
545+ begin
546+ // replace the furthest node with the current node
547+ Heap[0 ].Node := Node;
548+ Heap[0 ].DistSq := DistSq;
549+ // heapify downwards
550+ Heapify(0 );
551+ end ;
552+ end ;
553+ end ;
554+
555+ if Delta > 0 then Test := This^.L else Test := This^.R;
556+ FindKNearest(Test, Depth + 1 );
557+
558+ if (Length(Heap) < K) or (Sqr(Delta) < Heap[0 ].DistSq) then
559+ begin
560+ if Delta > 0 then Test := This^.R else Test := This^.L;
561+ FindKNearest(Test, Depth + 1 );
562+ end ;
563+ end ;
564+
565+ var
566+ i,j: Integer;
567+ begin
568+ SetLength(Heap, 0 );
569+ FindKNearest(0 , 0 );
570+
571+ SetLength(Result, Length(Heap));
572+ j := 0 ;
573+ for i := High(Heap) downto 0 do
574+ begin
575+ Result[j] := Heap[i].Node;
576+ Inc(j);
577+ end ;
578+ end ;
579+
447580
448581(*
449582 Returns the vectors that within the hyperrectangle defined by low and high vectors.
0 commit comments