Skip to content

Commit e51d04b

Browse files
committed
Add intersection tree implementation and tests
1 parent ebbb9ff commit e51d04b

2 files changed

Lines changed: 827 additions & 58 deletions

File tree

Lines changed: 196 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,196 @@
1+
'''
2+
Implementation of an intersection tree approach for efficiently finding
3+
intersections between a set of query intervals and a database of intervals.
4+
'''
5+
6+
import random
7+
import typing
8+
9+
10+
# Half-open integer interval [start, end), stored as a (start, end) tuple.
Interval: typing.TypeAlias = tuple[int, int]
11+
# A batch of query intervals.
Queries: typing.TypeAlias = list[Interval]
12+
# Pairs of intervals; execute_queries fills this with (query interval, database interval) tuples.
QueryResult: typing.TypeAlias = list[tuple[Interval, Interval]]
13+
14+
class Node:
    '''Class representing a node in an intersection tree.

    Each node stores one half-open interval [start, end) of the dataset and
    a reference to its left and right child.  Intervals whose start is smaller
    than the start of the node go into the left subtree, all other intervals
    go into the right subtree.  For indexing purposes, each node also stores
    the maximum end value over its own interval and all intervals below it.

    Insertion, search and string conversion are iterative, so a skewed tree
    (e.g., built from intervals inserted in order of increasing start) does
    not run into Python's recursion limit.
    '''

    _start: int
    _end: int
    _max_end: int
    # Class-level defaults: a child is only stored on the instance once an
    # interval has actually been inserted on that side.
    _left: typing.Optional['Node'] = None
    _right: typing.Optional['Node'] = None

    def __init__(self, interval: 'Interval') -> None:
        '''Initialize node representing the interval [start, end).

        Parameters
        ----------
        interval: Interval
            the interval represented by this node
        '''
        self._start = interval[0]
        self._end = interval[1]
        self._max_end = interval[1]

    def insert(self, interval: 'Interval') -> None:
        '''Insert a new interval [start, end) in the tree.

        Parameters
        ----------
        interval: Interval
            the interval to insert
        '''
        start, end = interval[0], interval[1]
        node = self
        while True:
            # Every node on the path gets the new interval in its subtree,
            # so its max_end has to account for the new end value.
            if end > node._max_end:
                node._max_end = end
            if start < node._start:
                if node._left is None:
                    node._left = Node(interval)
                    return
                node = node._left
            else:
                if node._right is None:
                    node._right = Node(interval)
                    return
                node = node._right

    def search(self, interval: 'Interval', results: 'list[Interval]') -> None:
        '''Search for all intervals in the tree that intersect with [start, end)
        and append them to results.

        Matches are appended in pre-order: the node itself, then its left
        subtree, then its right subtree.

        Parameters
        ----------
        interval: Interval
            the interval to search for intersections
        results: list[Interval]
            list to append the results to
        '''
        query_start, query_end = interval[0], interval[1]
        pending = [self]
        while pending:
            node = pending.pop()
            if node._start < query_end and query_start < node._end:
                results.append((node._start, node._end))
            # The right child is pushed first so that the left subtree is
            # popped, and hence reported, before the right subtree.
            right = node._right
            # All starts in the right subtree are >= node._start, so nothing
            # there can start before query_end once node._start >= query_end.
            if right is not None and node._start < query_end:
                pending.append(right)
            left = node._left
            # No interval in the left subtree ends after left._max_end, so
            # nothing there can reach past query_start if it is <= query_start.
            if left is not None and left._max_end > query_start:
                pending.append(left)

    def to_str(self, prefix: str = '') -> str:
        '''Return a string representation of the tree.

        Each node is written on its own line, in pre-order; children get the
        prefix of their parent extended by one space.

        Parameters
        ----------
        prefix: str
            prefix to add to each line, default is empty string

        Returns
        -------
        str
            string representation of the tree
        '''
        lines = []
        pending = [(self, prefix)]
        while pending:
            node, indent = pending.pop()
            lines.append(
                f'{indent}[{node._start}, {node._end}] (max_end={node._max_end})\n'
            )
            child_indent = indent + ' '
            # Right is pushed before left so that left is emitted first.
            if node._right is not None:
                pending.append((node._right, child_indent))
            if node._left is not None:
                pending.append((node._left, child_indent))
        return ''.join(lines)

    def __repr__(self) -> str:
        '''Return a string representation of the node.

        Returns
        -------
        str
            string representation of the node
        '''
        return f'Node({self._start}, {self._end}, max_end={self._max_end})'

    def __str__(self) -> str:
        '''Return a string representation of the tree.

        Returns
        -------
        str
            string representation of the tree
        '''
        return self.to_str()
117+
118+
119+
def generate_interval(max_end: int = 1_000_000_000) -> 'Interval':
    '''Generate a random half-open interval [start, end) of length at least 2.

    Parameters
    ----------
    max_end: int
        largest end value of the interval, default value 1_000_000_000;
        must be at least 2

    Returns
    -------
    Interval
        Tuple (start, end) such that 0 <= start, end <= max_end and
        end - start > 1

    Raises
    ------
    ValueError
        if max_end is smaller than 2, so that no such interval exists
    '''
    # Without this check random.randint would fail with an "empty range"
    # message that does not point at the offending argument.
    if max_end < 2:
        raise ValueError(f'max_end must be at least 2, got {max_end}')
    start = random.randint(0, max_end - 2)
    end = random.randint(start + 2, max_end)
    return start, end
135+
136+
137+
def create_db(size: int, max_end: int = 1_000_000) -> Node:
    '''Build an intersection tree filled with randomly generated intervals.

    The root is always created from one generated interval, so the tree
    holds at least one interval even if size is smaller than 1.

    Parameters
    ----------
    size: int
        number of intervals in the database
    max_end: int
        largest end value of the intervals, default value 1_000_000

    Returns
    -------
    Node
        root of the intersection tree
    '''
    root = Node(generate_interval(max_end))
    # The root already accounts for one interval; add the remaining ones.
    remaining = size - 1
    while remaining > 0:
        root.insert(generate_interval(max_end))
        remaining -= 1
    return root
156+
157+
158+
def execute_queries(queries: Queries, db: Node) -> QueryResult:
    '''Run every query interval against the database tree.

    Parameters
    ----------
    queries: Queries
        query intervals to be executed
    db: Node
        root of the intersection tree to search

    Returns
    -------
    QueryResult
        list of (query interval, database interval) tuples that intersect,
        grouped by query in the order the queries were given
    '''
    pairs: QueryResult = []
    for query in queries:
        # search() appends its matches to the list it is handed.
        hits: list[Interval] = []
        db.search(query, hits)
        for hit in hits:
            pairs.append((query, hit))
    return pairs
179+
180+
181+
def create_queries(size: int = 1_000, max_end: int = 1_000_000) -> Queries:
    '''Generate a list of random query intervals.

    Parameters
    ----------
    size: int
        number of intervals to generate, default value 1_000
    max_end: int
        largest end value of the intervals, default value 1_000_000

    Returns
    -------
    Queries
        a list of half-open intervals
    '''
    generated: Queries = []
    for _ in range(size):
        generated.append(generate_interval(max_end))
    return generated

0 commit comments

Comments
 (0)