Skip to content

Commit 354635c

Browse files
committed
Add LIMIT > k rejection in top-k vector search mode
VectorSearchQueryBuilder now accepts an options map, and in top-k mode its pushDownLimit throws when LIMIT exceeds k. Radial modes (max_distance, min_score) have no LIMIT restriction.
Signed-off-by: Eric Wei <mengwei.eric@gmail.com>
1 parent 23474a7 commit 354635c

3 files changed

Lines changed: 90 additions & 5 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public TableScanBuilder createScanBuilder() {
4848

4949
// Use VectorSearchQueryBuilder to keep knn in must (scoring) context.
5050
// WHERE filters will be placed in filter (non-scoring) context.
51-
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery());
51+
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery(), options);
5252
requestBuilder.pushDownTrackedScore(true);
5353

5454
// Top-k mode: default size to k so queries without LIMIT return k results

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,18 @@
55

66
package org.opensearch.sql.opensearch.storage.scan;
77

8+
import java.util.Map;
89
import org.opensearch.index.query.BoolQueryBuilder;
910
import org.opensearch.index.query.QueryBuilder;
1011
import org.opensearch.index.query.QueryBuilders;
12+
import org.opensearch.sql.exception.ExpressionEvaluationException;
1113
import org.opensearch.sql.expression.Expression;
1214
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
1315
import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder;
1416
import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer;
1517
import org.opensearch.sql.planner.logical.LogicalFilter;
18+
import org.opensearch.sql.planner.logical.LogicalLimit;
19+
import org.opensearch.sql.planner.logical.LogicalSort;
1620

1721
/**
1822
* Query builder for vector search that keeps the knn query in a scoring (must) context and puts
@@ -25,12 +29,14 @@
2529
public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder {
2630

2731
private final QueryBuilder knnQuery;
32+
private final Map<String, String> options;
2833

29-
public VectorSearchQueryBuilder(OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery) {
34+
public VectorSearchQueryBuilder(
35+
OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map<String, String> options) {
3036
super(requestBuilder);
31-
// Set knn as the initial query (scoring context)
3237
requestBuilder.getSourceBuilder().query(knnQuery);
3338
this.knnQuery = knnQuery;
39+
this.options = options;
3440
}
3541

3642
@Override
@@ -44,4 +50,25 @@ public boolean pushDownFilter(LogicalFilter filter) {
4450
requestBuilder.getSourceBuilder().query(combined);
4551
return true;
4652
}
53+
54+
/**
 * Pushes a LIMIT down to the request, first validating it against {@code k} in top-k mode.
 *
 * <p>When the options map contains the key {@code "k"}, its value is parsed as an int and a
 * LIMIT greater than that value is rejected. When the key is absent (for example, options that
 * only carry {@code max_distance}), no check is made. Only {@code limit.getLimit()} is compared;
 * the offset is not consulted here.
 *
 * @param limit the logical LIMIT operator to push down
 * @return whatever {@code super.pushDownLimit(limit)} returns
 * @throws ExpressionEvaluationException if {@code "k"} is present and the LIMIT exceeds it
 * @throws NumberFormatException if the {@code "k"} option is not a parsable integer
 *     (NOTE(review): presumably validated before this builder is created; confirm)
 */
@Override
55+
public boolean pushDownLimit(LogicalLimit limit) {
56+
if (options.containsKey("k")) {
57+
int k = Integer.parseInt(options.get("k"));
58+
if (limit.getLimit() > k) {
59+
throw new ExpressionEvaluationException(
60+
String.format("LIMIT %d exceeds k=%d in top-k vector search", limit.getLimit(), k));
61+
}
62+
}
63+
return super.pushDownLimit(limit);
64+
}
65+
66+
/**
 * Delegates sort push-down to the parent builder without any vector-search-specific checks.
 *
 * @param sort the logical sort operator to push down
 * @return whatever {@code super.pushDownSort(sort)} returns
 */
@Override
67+
public boolean pushDownSort(LogicalSort sort) {
68+
// Vector search returns results sorted by _score DESC by default.
69+
// TODO: reject sorts other than _score DESC; no sort is rejected yet.
70+
// For now every sort is passed unchanged to the parent; rejecting unsupported sorts is
71+
// deferred until we can inspect the sort expression for _score references.
72+
return super.pushDownSort(sort);
73+
}
4774
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,25 @@
66
package org.opensearch.sql.opensearch.storage.scan;
77

88
import static org.junit.jupiter.api.Assertions.assertEquals;
9+
import static org.junit.jupiter.api.Assertions.assertThrows;
910
import static org.junit.jupiter.api.Assertions.assertTrue;
1011
import static org.mockito.Mockito.mock;
1112
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
1213

1314
import java.util.Collections;
15+
import java.util.Map;
1416
import org.junit.jupiter.api.Test;
1517
import org.opensearch.index.query.BoolQueryBuilder;
1618
import org.opensearch.index.query.QueryBuilder;
1719
import org.opensearch.index.query.WrapperQueryBuilder;
1820
import org.opensearch.sql.common.setting.Settings;
21+
import org.opensearch.sql.exception.ExpressionEvaluationException;
1922
import org.opensearch.sql.expression.DSL;
2023
import org.opensearch.sql.expression.ReferenceExpression;
2124
import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory;
2225
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
2326
import org.opensearch.sql.planner.logical.LogicalFilter;
27+
import org.opensearch.sql.planner.logical.LogicalLimit;
2428
import org.opensearch.sql.planner.logical.LogicalValues;
2529

2630
class VectorSearchQueryBuilderTest {
@@ -30,7 +34,7 @@ void knnQuerySetAsScoringQuery() {
3034
var requestBuilder = createRequestBuilder();
3135
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
3236

33-
new VectorSearchQueryBuilder(requestBuilder, knnQuery);
37+
new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5"));
3438

3539
QueryBuilder query = requestBuilder.getSourceBuilder().query();
3640
assertTrue(
@@ -42,7 +46,7 @@ void knnQuerySetAsScoringQuery() {
4246
void pushDownFilterKeepsKnnInScoringContext() {
4347
var requestBuilder = createRequestBuilder();
4448
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
45-
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery);
49+
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5"));
4650

4751
// Simulate WHERE name = 'John'
4852
var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John"));
@@ -62,6 +66,60 @@ void pushDownFilterKeepsKnnInScoringContext() {
6266
"must clause should contain the original knn WrapperQueryBuilder");
6367
}
6468

69+
// Top-k mode with k=5: a LIMIT of 3 is below k, so pushDownLimit must return true.
@Test
70+
void pushDownLimitWithinKSucceeds() {
71+
var requestBuilder = createRequestBuilder();
72+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
73+
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5"));
74+
75+
var dummyChild = new LogicalValues(Collections.emptyList());
76+
// LIMIT 3; the trailing 0 is presumably the offset (confirm against LogicalLimit).
var limit = new LogicalLimit(dummyChild, 3, 0);
77+
78+
boolean pushed = builder.pushDownLimit(limit);
79+
assertTrue(pushed, "LIMIT within k should succeed");
80+
}
81+
82+
// Top-k mode with k=5: a LIMIT of 10 exceeds k, so pushDownLimit must throw
// ExpressionEvaluationException with a message naming both values.
@Test
83+
void pushDownLimitExceedingKThrows() {
84+
var requestBuilder = createRequestBuilder();
85+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
86+
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5"));
87+
88+
var dummyChild = new LogicalValues(Collections.emptyList());
89+
var limit = new LogicalLimit(dummyChild, 10, 0);
90+
91+
ExpressionEvaluationException ex =
92+
assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit));
93+
assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5"));
94+
}
95+
96+
// Boundary case: LIMIT equal to k (5) is allowed because the check is a strict "greater than".
@Test
97+
void pushDownLimitEqualToKSucceeds() {
98+
var requestBuilder = createRequestBuilder();
99+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
100+
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5"));
101+
102+
var dummyChild = new LogicalValues(Collections.emptyList());
103+
var limit = new LogicalLimit(dummyChild, 5, 0);
104+
105+
boolean pushed = builder.pushDownLimit(limit);
106+
assertTrue(pushed, "LIMIT equal to k should succeed");
107+
}
108+
109+
// Radial mode: the options carry max_distance and no "k" key, so the LIMIT check is skipped
// and a LIMIT of 100 is pushed down.
@Test
110+
void pushDownLimitRadialModeNoRestriction() {
111+
var requestBuilder = createRequestBuilder();
112+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
113+
var builder =
114+
new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0"));
115+
116+
var dummyChild = new LogicalValues(Collections.emptyList());
117+
var limit = new LogicalLimit(dummyChild, 100, 0);
118+
119+
boolean pushed = builder.pushDownLimit(limit);
120+
assertTrue(pushed, "Radial mode should not restrict LIMIT");
121+
}
122+
65123
private OpenSearchRequestBuilder createRequestBuilder() {
66124
return new OpenSearchRequestBuilder(
67125
mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class));

0 commit comments

Comments
 (0)