Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
17 commits
Select commit Hold shift + click to select a range
9ff5d2d
Add mutual exclusivity and k range validation
mengweieric Apr 8, 2026
7f8ed14
Add LIMIT > k rejection in top-k vector search mode
mengweieric Apr 8, 2026
943cb89
Add comprehensive test coverage for vector search hardening
mengweieric Apr 8, 2026
767c7f0
Add resolver argument count edge case tests
mengweieric Apr 8, 2026
19f0af3
Add radial size policy, sort restriction, and integration tests
mengweieric Apr 8, 2026
21ee6dc
Fill test coverage gaps for vector search hardening
mengweieric Apr 8, 2026
534b5f4
Preserve sort.getCount() limit pushdown contract in pushDownSort
mengweieric Apr 9, 2026
0e02488
Add compound predicate and radial+WHERE test coverage
mengweieric Apr 9, 2026
e1325e8
Add FilterType enum for post|efficient filter placement
mengweieric Apr 9, 2026
dd6450f
Add filter_type to allowed option keys with post|efficient validation
mengweieric Apr 9, 2026
bfaabb4
Strip filter_type from options and pass as typed FilterType to Vector…
mengweieric Apr 9, 2026
4ab27ae
Collapse buildKnnQueryJson to accept optional filter clause
mengweieric Apr 9, 2026
c7c4130
Implement efficient filter pushdown branching and build-time validati…
mengweieric Apr 9, 2026
fdd1810
Wire FilterType and rebuild callback through createScanBuilder
mengweieric Apr 9, 2026
af6e2ba
Add build-time validation and regression tests for LIMIT/sort under e…
mengweieric Apr 9, 2026
5b0373c
Add integration tests for filter_type=post|efficient and spotless for…
mengweieric Apr 9, 2026
bea6607
Reject radial vector search without LIMIT
mengweieric Apr 10, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
431 changes: 431 additions & 0 deletions integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java

Large diffs are not rendered by default.

Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.sql.opensearch.storage;

import java.util.Arrays;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.stream.Collectors;
import org.opensearch.sql.exception.ExpressionEvaluationException;

/** Filter placement strategy for vectorSearch() WHERE clauses. */
public enum FilterType {
  /** WHERE placed in bool.filter outside the knn clause (post-filtering). */
  POST("post"),

  /** WHERE placed inside knn.filter for efficient pre-filtering. */
  EFFICIENT("efficient");

  /**
   * Accepted option strings, kept in enum declaration order. A {@link LinkedHashSet} is used
   * rather than {@code Collectors.toSet()} so the list printed in the validation error is stable
   * instead of depending on hash iteration order.
   */
  private static final Set<String> VALID_VALUES =
      Arrays.stream(values())
          .map(FilterType::getValue)
          .collect(Collectors.toCollection(LinkedHashSet::new));

  /** The string accepted for this constant in the {@code filter_type} option. */
  private final String value;

  FilterType(String value) {
    this.value = value;
  }

  /** Returns the option string that selects this filter type. */
  public String getValue() {
    return value;
  }

  /**
   * Resolves a {@code filter_type} option string to its enum constant. Matching is exact and
   * case-sensitive.
   *
   * @param str the raw option value; {@code null} matches nothing and is rejected
   * @return the constant whose {@link #getValue()} equals {@code str}
   * @throws ExpressionEvaluationException if {@code str} is not one of the accepted values
   */
  public static FilterType fromString(String str) {
    for (FilterType ft : values()) {
      if (ft.value.equals(str)) {
        return ft;
      }
    }
    throw new ExpressionEvaluationException(
        String.format("filter_type must be one of %s, got '%s'", VALID_VALUES, str));
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,18 +26,32 @@ public class VectorSearchIndex extends OpenSearchIndex {
// Name of the vector field; interpolated as the key under "knn" in the generated query JSON.
private final String field;
// Query vector; serialized as a JSON array into the knn clause.
private final float[] vector;
// Search options; the "k" entry, when present, is also used as the default result size.
private final Map<String, String> options;
// Filter placement requested by the query; resolved to POST at scan-build time when absent.
private final FilterType filterType; // null means default (POST)

/**
 * Creates a vector search table over {@code indexName}.
 *
 * @param client OpenSearch client passed to the parent index
 * @param settings plugin settings passed to the parent index
 * @param indexName name of the index to search
 * @param field name of the vector field targeted by the knn clause
 * @param vector query vector
 * @param options search options (for example {@code k})
 * @param filterType explicit filter placement, or {@code null} when the query did not specify
 *     one; {@code null} is treated as {@link FilterType#POST} when the scan is built
 */
public VectorSearchIndex(
    OpenSearchClient client,
    Settings settings,
    String indexName,
    String field,
    float[] vector,
    Map<String, String> options,
    FilterType filterType) {
  super(client, settings, indexName);
  this.field = field;
  this.vector = vector;
  this.options = options;
  this.filterType = filterType;
}

/**
 * Backward-compatible constructor for callers that do not pass a filter type. Delegates to the
 * seven-argument constructor with a {@code null} filter type, meaning none was given explicitly.
 */
public VectorSearchIndex(
    OpenSearchClient client,
    Settings settings,
    String indexName,
    String field,
    float[] vector,
    Map<String, String> options) {
  this(client, settings, indexName, field, vector, options, /* filterType= */ null);
}

@Override
Expand All @@ -46,16 +60,32 @@ public TableScanBuilder createScanBuilder() {
getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE);
var requestBuilder = createRequestBuilder();

// Use VectorSearchQueryBuilder to keep knn in must (scoring) context.
// WHERE filters will be placed in filter (non-scoring) context.
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery());
// Callback for efficient filtering: serialize WHERE QueryBuilder to JSON,
// rebuild knn query with filter embedded. JSON handling stays in this class.
Function<QueryBuilder, QueryBuilder> rebuildWithFilter =
whereQuery -> new WrapperQueryBuilder(buildKnnQueryJson(whereQuery.toString()));

boolean filterTypeExplicit = filterType != null;
FilterType effectiveFilterType = filterType != null ? filterType : FilterType.POST;

var queryBuilder =
new VectorSearchQueryBuilder(
requestBuilder,
buildKnnQuery(),
options,
effectiveFilterType,
filterTypeExplicit,
rebuildWithFilter);
requestBuilder.pushDownTrackedScore(true);

// Top-k mode: default size to k so queries without LIMIT return k results
// instead of falling into the generic large-scan path.
// LIMIT pushdown will further reduce this if present.
// Default size policy: LIMIT pushdown will further reduce if present.
if (options.containsKey("k")) {
// Top-k mode: default size to k so queries without LIMIT return k results.
requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0);
} else {
// Radial mode (max_distance/min_score): cap at maxResultWindow.
// Without an explicit cap, radial queries could return unbounded results.
requestBuilder.pushDownLimitToRequestTotal(getMaxResultWindow(), 0);
}

Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator =
Expand All @@ -68,6 +98,20 @@ public TableScanBuilder createScanBuilder() {
}

/** Wraps the filter-less knn query JSON in a {@link WrapperQueryBuilder}. */
private QueryBuilder buildKnnQuery() {
  String knnJson = buildKnnQueryJson();
  return new WrapperQueryBuilder(knnJson);
}

/** Builds the knn query JSON with no embedded filter clause. Package-private for testing. */
String buildKnnQueryJson() {
  final String noFilter = null;
  return buildKnnQueryJson(noFilter);
}

/**
* Builds knn query JSON, optionally embedding a filter clause for efficient filtering.
*
* @param filterJson serialized filter JSON to embed in knn.field.filter, or null for no filter
*/
String buildKnnQueryJson(String filterJson) {
StringBuilder vectorJson = new StringBuilder("[");
for (int i = 0; i < vector.length; i++) {
if (i > 0) vectorJson.append(",");
Expand All @@ -88,11 +132,14 @@ private QueryBuilder buildKnnQuery() {
}
}

String knnQueryJson =
String.format(
"{\"knn\":{\"%s\":{\"vector\":%s%s}}}",
field, vectorJson.toString(), optionsJson.toString());
return new WrapperQueryBuilder(knnQueryJson);
String filterClause = "";
if (filterJson != null) {
filterClause = String.format(",\"filter\":%s", filterJson);
}

return String.format(
"{\"knn\":{\"%s\":{\"vector\":%s%s%s}}}",
field, vectorJson.toString(), optionsJson.toString(), filterClause);
}

private static boolean isNumeric(String str) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ public class VectorSearchTableFunctionImplementation extends FunctionExpression
implements TableFunctionImplementation {

/**
 * P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection.
 * {@code filter_type} is accepted here as a SQL-layer directive rather than a knn parameter.
 */
static final Set<String> ALLOWED_OPTION_KEYS =
    Set.of("k", "max_distance", "min_score", "filter_type");

/**
* Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores,
Expand Down Expand Up @@ -99,7 +100,14 @@ public Table applyArguments() {
Map<String, String> options = parseOptions(optionStr);
validateOptions(options);

return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options);
// Strip filter_type — it's a SQL-layer directive, not a knn parameter
FilterType filterType = null;
if (options.containsKey("filter_type")) {
filterType = FilterType.fromString(options.remove("filter_type"));
}

return new VectorSearchIndex(
client, settings, tableName, fieldName, vector, options, filterType);
}

private float[] parseVector(String vectorLiteral) {
Expand Down Expand Up @@ -190,16 +198,31 @@ private void validateOptions(Map<String, String> options) {
String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS));
}
}
if (options.containsKey("filter_type")) {
// Validate early — fromString throws if invalid
FilterType.fromString(options.get("filter_type"));
}
boolean hasK = options.containsKey("k");
boolean hasMaxDistance = options.containsKey("max_distance");
boolean hasMinScore = options.containsKey("min_score");
if (!hasK && !hasMaxDistance && !hasMinScore) {
throw new ExpressionEvaluationException(
"Missing required option: one of k, max_distance, or min_score");
}
// Mutual exclusivity: exactly one search mode allowed
int modeCount = (hasK ? 1 : 0) + (hasMaxDistance ? 1 : 0) + (hasMinScore ? 1 : 0);
if (modeCount > 1) {
throw new ExpressionEvaluationException(
"Only one of k, max_distance, or min_score may be specified");
}
// Parse and canonicalize numeric values — closes JSON injection via option values
if (hasK) {
parseIntOption(options, "k");
int k = Integer.parseInt(options.get("k"));
if (k < 1 || k > 10000) {
throw new ExpressionEvaluationException(
String.format("k must be between 1 and 10000, got %d", k));
}
}
if (hasMaxDistance) {
parseDoubleOption(options, "max_distance");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,43 +5,140 @@

package org.opensearch.sql.opensearch.storage.scan;

import java.util.Map;
import java.util.function.Function;
import org.apache.commons.lang3.tuple.Pair;
import org.opensearch.index.query.BoolQueryBuilder;
import org.opensearch.index.query.QueryBuilder;
import org.opensearch.index.query.QueryBuilders;
import org.opensearch.sql.ast.tree.Sort;
import org.opensearch.sql.ast.tree.Sort.SortOption;
import org.opensearch.sql.exception.ExpressionEvaluationException;
import org.opensearch.sql.expression.Expression;
import org.opensearch.sql.expression.ReferenceExpression;
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
import org.opensearch.sql.opensearch.storage.FilterType;
import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder;
import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer;
import org.opensearch.sql.planner.logical.LogicalFilter;
import org.opensearch.sql.planner.logical.LogicalLimit;
import org.opensearch.sql.planner.logical.LogicalSort;

/**
* Query builder for vector search that keeps the knn query in a scoring (must) context and puts
* WHERE filters in a non-scoring (filter) context. This prevents the knn relevance scores from
* being destroyed when a WHERE clause is pushed down.
*
* <p>Without this, the default pushDownFilter wraps both queries into bool.filter, which is a
* non-scoring context.
* <p>Supports two filter placement strategies via {@link FilterType}:
*
* <ul>
* <li>{@code POST} — WHERE in {@code bool.filter} outside knn (post-filtering, default)
* <li>{@code EFFICIENT} — WHERE inside {@code knn.filter} for pre-filtering during ANN search
* </ul>
*/
public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder {

// The knn query; installed as the initial request query and reused in the must clause (POST mode).
private final QueryBuilder knnQuery;
// Vector search options; presence of the "k" key distinguishes top-k from radial mode.
private final Map<String, String> options;
// Where a pushed-down WHERE clause is placed; never null (the constructor defaults it to POST).
private final FilterType filterType;
// True when the query named a filter_type itself; build() then requires a pushed-down WHERE.
private final boolean filterTypeExplicit;
// Produces a knn query with the WHERE query embedded; applied only in EFFICIENT mode.
private final Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter;
// Set once pushDownFilter has run.
private boolean filterPushed = false;
// Set once pushDownLimit has accepted a LIMIT.
private boolean limitPushed = false;

/**
 * Full constructor with filter type support. Installs {@code knnQuery} as the request's initial
 * query so it stays in a scoring context.
 *
 * @param requestBuilder request builder whose source receives the knn query
 * @param knnQuery the knn query to run
 * @param options vector search options (for example {@code k})
 * @param filterType filter placement strategy; {@code null} falls back to {@link FilterType#POST}
 * @param filterTypeExplicit whether the query specified {@code filter_type} itself; when true,
 *     {@link #build()} rejects queries for which no WHERE clause was pushed down
 * @param rebuildKnnWithFilter callback that rebuilds the knn query with a WHERE query embedded;
 *     used only for {@link FilterType#EFFICIENT}
 */
public VectorSearchQueryBuilder(
    OpenSearchRequestBuilder requestBuilder,
    QueryBuilder knnQuery,
    Map<String, String> options,
    FilterType filterType,
    boolean filterTypeExplicit,
    Function<QueryBuilder, QueryBuilder> rebuildKnnWithFilter) {
  super(requestBuilder);
  // Set knn as the initial query (scoring context)
  requestBuilder.getSourceBuilder().query(knnQuery);
  this.knnQuery = knnQuery;
  this.options = options;
  this.filterType = filterType != null ? filterType : FilterType.POST;
  this.filterTypeExplicit = filterTypeExplicit;
  this.rebuildKnnWithFilter = rebuildKnnWithFilter;
}

/**
 * Backward-compatible constructor for callers that do not configure filter placement. Uses
 * {@link FilterType#POST}, marks the filter type as not explicitly requested, and supplies no
 * knn rebuild callback.
 */
public VectorSearchQueryBuilder(
    OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map<String, String> options) {
  this(
      requestBuilder,
      knnQuery,
      options,
      /* filterType= */ FilterType.POST,
      /* filterTypeExplicit= */ false,
      /* rebuildKnnWithFilter= */ null);
}

/**
 * Pushes the WHERE condition into the request according to the configured filter type.
 *
 * <p>In {@code EFFICIENT} mode the knn query is rebuilt with the WHERE query embedded and
 * replaces the request query. Otherwise (POST) the knn query goes in {@code bool.must} and the
 * WHERE query in {@code bool.filter}, so the filter does not affect scoring.
 *
 * @param filter logical filter whose condition is translated to a query
 * @return always {@code true}
 * @throws IllegalStateException if {@code EFFICIENT} mode is selected but this builder was
 *     constructed without a knn rebuild callback
 */
@Override
public boolean pushDownFilter(LogicalFilter filter) {
  FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer());
  Expression queryCondition = filter.getCondition();
  QueryBuilder whereQuery = queryBuilder.build(queryCondition);
  filterPushed = true;

  if (filterType == FilterType.EFFICIENT) {
    // The three-argument constructor leaves the callback null; fail with a clear message
    // instead of a bare NullPointerException.
    if (rebuildKnnWithFilter == null) {
      throw new IllegalStateException(
          "filter_type=efficient requires a knn rebuild callback, but none was provided");
    }
    QueryBuilder rebuiltKnn = rebuildKnnWithFilter.apply(whereQuery);
    requestBuilder.getSourceBuilder().query(rebuiltKnn);
  } else {
    // POST mode: knn in must (scores), WHERE in filter (no scoring impact)
    BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery);
    requestBuilder.getSourceBuilder().query(combined);
  }
  return true;
}

/**
 * Pushes LIMIT down to the request. In top-k mode (a {@code k} option is present) a LIMIT larger
 * than {@code k} is rejected; otherwise the limit is recorded as pushed and handed to the
 * parent implementation.
 *
 * @throws ExpressionEvaluationException if the LIMIT exceeds {@code k} in top-k mode
 */
@Override
public boolean pushDownLimit(LogicalLimit limit) {
  final boolean topKMode = options.containsKey("k");
  if (topKMode) {
    final int topK = Integer.parseInt(options.get("k"));
    final boolean limitExceedsK = limit.getLimit() > topK;
    if (limitExceedsK) {
      String message =
          String.format(
              "LIMIT %d exceeds k=%d in top-k vector search", limit.getLimit(), topK);
      throw new ExpressionEvaluationException(message);
    }
  }

  limitPushed = true;
  return super.pushDownLimit(limit);
}

/**
 * Accepts only {@code ORDER BY _score DESC}. Any other sort expression, or {@code _score} in a
 * non-descending order, is rejected. The sort itself is not forwarded to the request; only a
 * non-zero {@code sort.getCount()} is pushed down as a limit.
 *
 * @throws ExpressionEvaluationException if any sort item is not {@code _score DESC}
 */
@Override
public boolean pushDownSort(LogicalSort sort) {
  for (Pair<SortOption, Expression> entry : sort.getSortList()) {
    Expression sortExpr = entry.getRight();

    // Each item must reference _score ...
    boolean referencesScore =
        sortExpr instanceof ReferenceExpression
            && "_score".equals(((ReferenceExpression) sortExpr).getAttr());
    if (!referencesScore) {
      String message =
          String.format(
              "vectorSearch only supports ORDER BY _score DESC; "
                  + "unsupported sort expression: %s",
              sortExpr);
      throw new ExpressionEvaluationException(message);
    }

    // ... and must be descending.
    boolean descending = Sort.SortOrder.DESC == entry.getLeft().getSortOrder();
    if (!descending) {
      throw new ExpressionEvaluationException(
          "vectorSearch only supports ORDER BY _score DESC; _score ASC is not supported");
    }
  }

  // _score DESC is already the natural knn ordering, so no sort clause is sent to OpenSearch.
  // A zero count means there is no limit folded into the sort node; nothing more to do.
  if (sort.getCount() == 0) {
    return true;
  }

  // Keep the parent's contract that a non-zero sort count is pushed down as a limit.
  requestBuilder.pushDownLimit(sort.getCount(), 0);
  return true;
}

/**
 * Validates the accumulated pushdown state and builds the request.
 *
 * @throws ExpressionEvaluationException if {@code filter_type} was given explicitly but no WHERE
 *     clause was pushed down, or if the search is radial (no {@code k} option) and no LIMIT was
 *     pushed down
 */
@Override
public OpenSearchRequestBuilder build() {
  final boolean explicitFilterTypeUnused = filterTypeExplicit && !filterPushed;
  if (explicitFilterTypeUnused) {
    throw new ExpressionEvaluationException("filter_type requires a pushdownable WHERE clause");
  }

  final boolean topKMode = options.containsKey("k");
  if (!topKMode && !limitPushed) {
    // Radial searches must be bounded by an explicit LIMIT.
    throw new ExpressionEvaluationException(
        "LIMIT is required for radial vector search (max_distance or min_score)."
            + " Without LIMIT, the result set size is unbounded.");
  }

  return super.build();
}
}
Loading
Loading