|
6 | 6 | package org.opensearch.sql.opensearch.storage.scan; |
7 | 7 |
|
8 | 8 | import java.util.Map; |
| 9 | +import java.util.function.Function; |
9 | 10 | import org.apache.commons.lang3.tuple.Pair; |
10 | 11 | import org.opensearch.index.query.BoolQueryBuilder; |
11 | 12 | import org.opensearch.index.query.QueryBuilder; |
|
16 | 17 | import org.opensearch.sql.expression.Expression; |
17 | 18 | import org.opensearch.sql.expression.ReferenceExpression; |
18 | 19 | import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; |
| 20 | +import org.opensearch.sql.opensearch.storage.FilterType; |
19 | 21 | import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; |
20 | 22 | import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer; |
21 | 23 | import org.opensearch.sql.planner.logical.LogicalFilter; |
|
27 | 29 | * WHERE filters in a non-scoring (filter) context. This prevents the knn relevance scores from |
28 | 30 | * being destroyed when a WHERE clause is pushed down. |
29 | 31 | * |
30 | | - * <p>Without this, the default pushDownFilter wraps both queries into bool.filter, which is a |
31 | | - * non-scoring context. |
| 32 | + * <p>Supports two filter placement strategies via {@link FilterType}: |
| 33 | + * |
| 34 | + * <ul> |
| 35 | + * <li>{@code POST} — WHERE in {@code bool.filter} outside knn (post-filtering, default) |
| 36 | + * <li>{@code EFFICIENT} — WHERE inside {@code knn.filter} for pre-filtering during ANN search |
| 37 | + * </ul> |
32 | 38 | */ |
33 | 39 | public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder { |
34 | 40 |
|
35 | 41 | private final QueryBuilder knnQuery; |
36 | 42 | private final Map<String, String> options; |
| 43 | + private final FilterType filterType; |
| 44 | + private final boolean filterTypeExplicit; |
| 45 | + private final Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter; |
| 46 | + private boolean filterPushed = false; |
37 | 47 |
|
/**
 * Creates a builder with an explicit filter placement strategy and installs {@code knnQuery} as
 * the query of the request's source builder.
 *
 * @param requestBuilder request builder whose source query is set to {@code knnQuery}
 * @param knnQuery the knn query
 * @param options options, stored as given
 * @param filterType filter placement strategy; {@code null} is treated as {@code POST}
 * @param filterTypeExplicit when true, {@link #build()} fails unless a filter was pushed down
 * @param rebuildKnnWithFilter applied to the WHERE query in {@code EFFICIENT} mode; must be
 *     non-null if a filter is pushed down in that mode
 */
public VectorSearchQueryBuilder(
    OpenSearchRequestBuilder requestBuilder,
    QueryBuilder knnQuery,
    Map<String, String> options,
    FilterType filterType,
    boolean filterTypeExplicit,
    Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter) {
  super(requestBuilder);
  this.knnQuery = knnQuery;
  this.options = options;
  this.filterTypeExplicit = filterTypeExplicit;
  this.rebuildKnnWithFilter = rebuildKnnWithFilter;
  if (filterType == null) {
    // An absent strategy falls back to post-filtering.
    this.filterType = FilterType.POST;
  } else {
    this.filterType = filterType;
  }
  requestBuilder.getSourceBuilder().query(knnQuery);
}
| 64 | + |
/**
 * Three-argument constructor kept for existing callers. Delegates to the full constructor with
 * {@code FilterType.POST}, a non-explicit filter type, and no knn rebuild function.
 *
 * @param requestBuilder request builder whose source query is set to {@code knnQuery}
 * @param knnQuery the knn query
 * @param options options, stored as given
 */
public VectorSearchQueryBuilder(
    OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map<String, String> options) {
  this(requestBuilder, knnQuery, options, FilterType.POST, false, null);
}
45 | 72 |
|
46 | 73 | @Override |
47 | 74 | public boolean pushDownFilter(LogicalFilter filter) { |
48 | 75 | FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer()); |
49 | 76 | Expression queryCondition = filter.getCondition(); |
50 | 77 | QueryBuilder whereQuery = queryBuilder.build(queryCondition); |
| 78 | + filterPushed = true; |
51 | 79 |
|
52 | | - // Combine: knn in must (scores), WHERE in filter (no scoring impact) |
53 | | - BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery); |
54 | | - requestBuilder.getSourceBuilder().query(combined); |
| 80 | + if (filterType == FilterType.EFFICIENT) { |
| 81 | + QueryBuilder rebuiltKnn = rebuildKnnWithFilter.apply(whereQuery); |
| 82 | + requestBuilder.getSourceBuilder().query(rebuiltKnn); |
| 83 | + } else { |
| 84 | + // POST mode: knn in must (scores), WHERE in filter (no scoring impact) |
| 85 | + BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery); |
| 86 | + requestBuilder.getSourceBuilder().query(combined); |
| 87 | + } |
55 | 88 | return true; |
56 | 89 | } |
57 | 90 |
|
@@ -94,4 +127,13 @@ public boolean pushDownSort(LogicalSort sort) { |
94 | 127 | } |
95 | 128 | return true; |
96 | 129 | } |
| 130 | + |
| 131 | + @Override |
| 132 | + public OpenSearchRequestBuilder build() { |
| 133 | + if (filterTypeExplicit && !filterPushed) { |
| 134 | + throw new ExpressionEvaluationException( |
| 135 | + "filter_type requires a pushdownable WHERE clause"); |
| 136 | + } |
| 137 | + return super.build(); |
| 138 | + } |
97 | 139 | } |
0 commit comments