diff --git a/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java new file mode 100644 index 00000000000..408d734e138 --- /dev/null +++ b/integ-test/src/test/java/org/opensearch/sql/sql/VectorSearchIT.java @@ -0,0 +1,431 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.sql; + +import static org.hamcrest.Matchers.containsString; + +import java.io.IOException; +import org.junit.Test; +import org.opensearch.client.ResponseException; +import org.opensearch.sql.legacy.SQLIntegTestCase; +import org.opensearch.sql.legacy.TestsConstants; + +/** + * Integration tests for vectorSearch SQL table function. These tests verify DSL push-down shape via + * _explain and validation error paths. They do NOT require the k-NN plugin since _explain only + * parses and plans the query without executing it against a knn index. + */ +public class VectorSearchIT extends SQLIntegTestCase { + + @Override + protected void init() throws Exception { + // _explain needs the index to exist for field resolution. + loadIndex(Index.ACCOUNT); + } + + private static final String TEST_INDEX = TestsConstants.TEST_INDEX_ACCOUNT; + + // ── DSL shape verification via _explain ─────────────────────────────── + + @Test + public void testExplainTopKProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=5') AS v " + + "LIMIT 5"); + + // WrapperQueryBuilder wraps the knn JSON — verify the wrapper is present + // and track_scores is enabled for score preservation. 
+ assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + assertTrue( + "Explain should contain track_scores:\n" + explain, explain.contains("track_scores")); + } + + @Test + public void testExplainRadialMaxDistanceProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "LIMIT 100"); + + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + } + + @Test + public void testExplainRadialMinScoreProducesKnnQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='min_score=0.8') AS v " + + "LIMIT 100"); + + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + } + + @Test + public void testExplainPostFilterProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + // ── Validation error paths ──────────────────────────────────────────── + + @Test + public void testMutualExclusivityRejectsKAndMaxDistance() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', 
field='f', " + + "vector='[1.0]', option='k=5,max_distance=10') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testMutualExclusivityRejectsKAndMinScore() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,min_score=0.5') AS v")); + + assertThat(ex.getMessage(), containsString("Only one of")); + } + + @Test + public void testKTooLargeRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=10001') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testKZeroRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=0') AS v")); + + assertThat(ex.getMessage(), containsString("k must be between 1 and 10000")); + } + + @Test + public void testUnknownOptionKeyRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,method.ef_search=100') AS v")); + + assertThat(ex.getMessage(), containsString("Unknown option key")); + } + + @Test + public void testEmptyVectorRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("must not be empty")); + } + + @Test + public void testInvalidFieldNameRejects() throws IOException { + ResponseException ex 
= + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', " + + "field='field\\\"injection', vector='[1.0]', option='k=5') AS v")); + + assertThat(ex.getMessage(), containsString("Invalid field name")); + } + + @Test + public void testMissingRequiredOptionRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='') AS v")); + + assertThat(ex.getMessage(), containsString("Missing required option")); + } + + @Test + public void testRadialWithoutLimitRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v")); + + assertThat(ex.getMessage(), containsString("LIMIT is required for radial vector search")); + } + + // ── Sort restriction validation ───────────────────────────────────────── + + @Test + public void testOrderByScoreDescExplainSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score DESC " + + "LIMIT 5"); + + assertTrue( + "Explain should succeed with ORDER BY _score DESC:\n" + explain, + explain.contains("wrapper")); + } + + @Test + public void testOrderByNonScoreFieldRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v.firstname ASC " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("unsupported sort expression")); + } + + @Test + 
public void testOrderByScoreAscRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5') AS v " + + "ORDER BY v._score ASC " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("_score ASC is not supported")); + } + + // ── Compound predicate and radial + WHERE ─────────────────────────────── + + @Test + public void testExplainCompoundPredicateProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10') AS v " + + "WHERE v.state = 'TX' AND v.age > 30 " + + "LIMIT 10"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (compound WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + @Test + public void testExplainRadialWithWhereProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='max_distance=10.5') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 100"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue( + "Explain should contain must clause (knn in scoring context):\n" + explain, + explain.contains("must")); + assertTrue( + "Explain should contain filter clause (WHERE in non-scoring context):\n" + explain, + explain.contains("filter")); + } + + // ── LIMIT validation ─────────────────────────────────────────────────── + + @Test + public void 
testExplainLimitWithinKSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=10') AS v " + + "LIMIT 5"); + + assertTrue("Explain should succeed with LIMIT <= k:\n" + explain, explain.contains("wrapper")); + } + + // ── filter_type validation and explain ───────────────────────────── + + @Test + public void testExplainFilterTypePostProducesBoolQuery() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0, 3.0]', option='k=10,filter_type=post') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 10"); + + assertTrue("Explain should contain bool query:\n" + explain, explain.contains("bool")); + assertTrue("Explain should contain must:\n" + explain, explain.contains("must")); + assertTrue("Explain should contain filter:\n" + explain, explain.contains("filter")); + } + + @Test + public void testExplainFilterTypeEfficientProducesKnnWithFilter() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "WHERE v.state = 'TX' " + + "LIMIT 5"); + + // Efficient mode: knn rebuilt with filter inside, wrapped in WrapperQueryBuilder + assertTrue("Explain should contain wrapper query:\n" + explain, explain.contains("wrapper")); + } + + @Test + public void testFilterTypeEfficientWithoutWhereRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "LIMIT 5")); + + assertThat(ex.getMessage(), 
containsString("filter_type requires a pushdownable WHERE clause")); + } + + @Test + public void testFilterTypePostWithoutWhereRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=post') AS v " + + "LIMIT 5")); + + assertThat(ex.getMessage(), containsString("filter_type requires a pushdownable WHERE clause")); + } + + @Test + public void testInvalidFilterTypeRejects() throws IOException { + ResponseException ex = + expectThrows( + ResponseException.class, + () -> + executeQuery( + "SELECT v._id FROM vectorSearch(table='t', field='f', " + + "vector='[1.0]', option='k=5,filter_type=bogus') AS v")); + + assertThat(ex.getMessage(), containsString("filter_type must be one of")); + } + + @Test + public void testEfficientFilterWithOrderByScoreDescSucceeds() throws IOException { + String explain = + explainQuery( + "SELECT v._id, v._score " + + "FROM vectorSearch(table='" + + TEST_INDEX + + "', field='embedding', " + + "vector='[1.0, 2.0]', option='k=5,filter_type=efficient') AS v " + + "WHERE v.state = 'TX' " + + "ORDER BY v._score DESC " + + "LIMIT 5"); + + assertTrue( + "Explain should succeed with efficient + ORDER BY _score DESC:\n" + explain, + explain.contains("wrapper")); + } +}
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java
new file mode 100644
index 00000000000..cc42bb35f58
--- /dev/null
+++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/FilterType.java
@@ -0,0 +1,59 @@
+/*
+ * Copyright OpenSearch Contributors
+ * SPDX-License-Identifier: Apache-2.0
+ */
+
+package org.opensearch.sql.opensearch.storage;
+
+import java.util.Arrays;
+import java.util.LinkedHashSet;
+import java.util.Set;
+import java.util.stream.Collectors;
+import org.opensearch.sql.exception.ExpressionEvaluationException;
+
+/** Filter placement strategy for vectorSearch() WHERE clauses. */
+public enum FilterType {
+  /** WHERE placed in bool.filter outside the knn clause (post-filtering). */
+  POST("post"),
+
+  /** WHERE placed inside knn.filter for efficient pre-filtering. */
+  EFFICIENT("efficient");
+
+  /** The string that selects this constant in the {@code filter_type} option. */
+  private final String value;
+
+  FilterType(String value) {
+    this.value = value;
+  }
+
+  /** Returns the option string that selects this filter type. */
+  public String getValue() {
+    return value;
+  }
+
+  /**
+   * Accepted option strings. Collected into an insertion-ordered set so the error message in
+   * {@link #fromString} lists them in declaration order rather than in unspecified hash order.
+   */
+  private static final Set<String> VALID_VALUES =
+      Arrays.stream(values())
+          .map(FilterType::getValue)
+          .collect(Collectors.toCollection(LinkedHashSet::new));
+
+  /**
+   * Resolves an option string to a filter type using an exact, case-sensitive match.
+   *
+   * @param str the raw {@code filter_type} value; {@code null} matches nothing
+   * @return the constant whose value equals {@code str}
+   * @throws ExpressionEvaluationException if {@code str} matches no constant
+   */
+  public static FilterType fromString(String str) {
+    for (FilterType ft : values()) {
+      if (ft.value.equals(str)) {
+        return ft;
+      }
+    }
+    throw new ExpressionEvaluationException(
+        String.format("filter_type must be one of %s, got '%s'", VALID_VALUES, str));
+  }
+}
diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java index f33d5f2fa73..4ab0e2861cb 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java @@ -26,6 +26,7 @@ public class VectorSearchIndex extends OpenSearchIndex { private final String field; private final float[] vector; private final Map options; + private final FilterType filterType; // null means default (POST) public VectorSearchIndex( OpenSearchClient client, @@ -33,11 +34,24 @@ public VectorSearchIndex( String indexName, String field, float[] vector, - Map options) { + Map options, + FilterType filterType) { super(client, settings, indexName); this.field = field; this.vector = vector; this.options = options; + this.filterType = filterType; + } + + /** Backward-compatible constructor — defaults to no explicit filter type.
*/ + public VectorSearchIndex( + OpenSearchClient client, + Settings settings, + String indexName, + String field, + float[] vector, + Map options) { + this(client, settings, indexName, field, vector, options, null); } @Override @@ -46,16 +60,32 @@ public TableScanBuilder createScanBuilder() { getSettings().getSettingValue(Settings.Key.SQL_CURSOR_KEEP_ALIVE); var requestBuilder = createRequestBuilder(); - // Use VectorSearchQueryBuilder to keep knn in must (scoring) context. - // WHERE filters will be placed in filter (non-scoring) context. - var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery()); + // Callback for efficient filtering: serialize WHERE QueryBuilder to JSON, + // rebuild knn query with filter embedded. JSON handling stays in this class. + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder(buildKnnQueryJson(whereQuery.toString())); + + boolean filterTypeExplicit = filterType != null; + FilterType effectiveFilterType = filterType != null ? filterType : FilterType.POST; + + var queryBuilder = + new VectorSearchQueryBuilder( + requestBuilder, + buildKnnQuery(), + options, + effectiveFilterType, + filterTypeExplicit, + rebuildWithFilter); requestBuilder.pushDownTrackedScore(true); - // Top-k mode: default size to k so queries without LIMIT return k results - // instead of falling into the generic large-scan path. - // LIMIT pushdown will further reduce this if present. + // Default size policy: LIMIT pushdown will further reduce if present. if (options.containsKey("k")) { + // Top-k mode: default size to k so queries without LIMIT return k results. requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0); + } else { + // Radial mode (max_distance/min_score): cap at maxResultWindow. + // Without an explicit cap, radial queries could return unbounded results. 
+ requestBuilder.pushDownLimitToRequestTotal(getMaxResultWindow(), 0); } Function createScanOperator = @@ -68,6 +98,20 @@ public TableScanBuilder createScanBuilder() { } private QueryBuilder buildKnnQuery() { + return new WrapperQueryBuilder(buildKnnQueryJson()); + } + + // Package-private for testing + String buildKnnQueryJson() { + return buildKnnQueryJson(null); + } + + /** + * Builds knn query JSON, optionally embedding a filter clause for efficient filtering. + * + * @param filterJson serialized filter JSON to embed in knn.field.filter, or null for no filter + */ + String buildKnnQueryJson(String filterJson) { StringBuilder vectorJson = new StringBuilder("["); for (int i = 0; i < vector.length; i++) { if (i > 0) vectorJson.append(","); @@ -88,11 +132,14 @@ private QueryBuilder buildKnnQuery() { } } - String knnQueryJson = - String.format( - "{\"knn\":{\"%s\":{\"vector\":%s%s}}}", - field, vectorJson.toString(), optionsJson.toString()); - return new WrapperQueryBuilder(knnQueryJson); + String filterClause = ""; + if (filterJson != null) { + filterClause = String.format(",\"filter\":%s", filterJson); + } + + return String.format( + "{\"knn\":{\"%s\":{\"vector\":%s%s%s}}}", + field, vectorJson.toString(), optionsJson.toString(), filterClause); } private static boolean isNumeric(String str) { diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java index c4c383f9623..7eda27e7c56 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java @@ -34,7 +34,8 @@ public class VectorSearchTableFunctionImplementation extends FunctionExpression implements TableFunctionImplementation { /** P0 allowed option keys. 
Rejects unknown/future keys to prevent unvalidated DSL injection. */ - static final Set ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score"); + static final Set ALLOWED_OPTION_KEYS = + Set.of("k", "max_distance", "min_score", "filter_type"); /** * Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores, @@ -99,7 +100,14 @@ public Table applyArguments() { Map options = parseOptions(optionStr); validateOptions(options); - return new VectorSearchIndex(client, settings, tableName, fieldName, vector, options); + // Strip filter_type — it's a SQL-layer directive, not a knn parameter + FilterType filterType = null; + if (options.containsKey("filter_type")) { + filterType = FilterType.fromString(options.remove("filter_type")); + } + + return new VectorSearchIndex( + client, settings, tableName, fieldName, vector, options, filterType); } private float[] parseVector(String vectorLiteral) { @@ -190,6 +198,10 @@ private void validateOptions(Map options) { String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS)); } } + if (options.containsKey("filter_type")) { + // Validate early — fromString throws if invalid + FilterType.fromString(options.get("filter_type")); + } boolean hasK = options.containsKey("k"); boolean hasMaxDistance = options.containsKey("max_distance"); boolean hasMinScore = options.containsKey("min_score"); @@ -197,9 +209,20 @@ private void validateOptions(Map options) { throw new ExpressionEvaluationException( "Missing required option: one of k, max_distance, or min_score"); } + // Mutual exclusivity: exactly one search mode allowed + int modeCount = (hasK ? 1 : 0) + (hasMaxDistance ? 1 : 0) + (hasMinScore ? 
1 : 0); + if (modeCount > 1) { + throw new ExpressionEvaluationException( + "Only one of k, max_distance, or min_score may be specified"); + } // Parse and canonicalize numeric values — closes JSON injection via option values if (hasK) { parseIntOption(options, "k"); + int k = Integer.parseInt(options.get("k")); + if (k < 1 || k > 10000) { + throw new ExpressionEvaluationException( + String.format("k must be between 1 and 10000, got %d", k)); + } } if (hasMaxDistance) { parseDoubleOption(options, "max_distance"); diff --git a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java index efc2f333b0d..92b8f9cd469 100644 --- a/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java +++ b/opensearch/src/main/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilder.java @@ -5,32 +5,68 @@ package org.opensearch.sql.opensearch.storage.scan; +import java.util.Map; +import java.util.function.Function; +import org.apache.commons.lang3.tuple.Pair; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.sql.ast.tree.Sort; +import org.opensearch.sql.ast.tree.Sort.SortOption; +import org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.Expression; +import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.FilterType; import org.opensearch.sql.opensearch.storage.script.filter.FilterQueryBuilder; import org.opensearch.sql.opensearch.storage.serde.DefaultExpressionSerializer; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; +import 
org.opensearch.sql.planner.logical.LogicalSort; /** * Query builder for vector search that keeps the knn query in a scoring (must) context and puts * WHERE filters in a non-scoring (filter) context. This prevents the knn relevance scores from * being destroyed when a WHERE clause is pushed down. * - *

<p>Without this, the default pushDownFilter wraps both queries into bool.filter, which is a
- * non-scoring context.
+ * <p>Supports two filter placement strategies via {@link FilterType}:
+ *
+ * <ul>
+ *   <li>{@code POST} — WHERE in {@code bool.filter} outside knn (post-filtering, default)
+ *   <li>{@code EFFICIENT} — WHERE inside {@code knn.filter} for pre-filtering during ANN search
+ * </ul>
*/ public class VectorSearchQueryBuilder extends OpenSearchIndexScanQueryBuilder { private final QueryBuilder knnQuery; + private final Map options; + private final FilterType filterType; + private final boolean filterTypeExplicit; + private final Function rebuildKnnWithFilter; + private boolean filterPushed = false; + private boolean limitPushed = false; - public VectorSearchQueryBuilder(OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery) { + /** Full constructor with filter type support. */ + public VectorSearchQueryBuilder( + OpenSearchRequestBuilder requestBuilder, + QueryBuilder knnQuery, + Map options, + FilterType filterType, + boolean filterTypeExplicit, + Function rebuildKnnWithFilter) { super(requestBuilder); - // Set knn as the initial query (scoring context) requestBuilder.getSourceBuilder().query(knnQuery); this.knnQuery = knnQuery; + this.options = options; + this.filterType = filterType != null ? filterType : FilterType.POST; + this.filterTypeExplicit = filterTypeExplicit; + this.rebuildKnnWithFilter = rebuildKnnWithFilter; + } + + /** Backward-compatible constructor — defaults to POST, not explicit. 
*/ + public VectorSearchQueryBuilder( + OpenSearchRequestBuilder requestBuilder, QueryBuilder knnQuery, Map options) { + this(requestBuilder, knnQuery, options, FilterType.POST, false, null); } @Override
@@ -38,10 +74,101 @@ public boolean pushDownFilter(LogicalFilter filter) {
     FilterQueryBuilder queryBuilder = new FilterQueryBuilder(new DefaultExpressionSerializer());
     Expression queryCondition = filter.getCondition();
     QueryBuilder whereQuery = queryBuilder.build(queryCondition);
+    filterPushed = true;
 
-    // Combine: knn in must (scores), WHERE in filter (no scoring impact)
-    BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery);
-    requestBuilder.getSourceBuilder().query(combined);
+    if (filterType == FilterType.EFFICIENT) {
+      if (rebuildKnnWithFilter == null) {
+        // The three-argument constructor passes no callback; fail clearly instead of with an NPE.
+        throw new IllegalStateException(
+            "filter_type=efficient requires a callback to rebuild the knn query");
+      }
+      QueryBuilder rebuiltKnn = rebuildKnnWithFilter.apply(whereQuery);
+      requestBuilder.getSourceBuilder().query(rebuiltKnn);
+    } else {
+      // POST mode: knn in must (scores), WHERE in filter (no scoring impact)
+      BoolQueryBuilder combined = QueryBuilders.boolQuery().must(knnQuery).filter(whereQuery);
+      requestBuilder.getSourceBuilder().query(combined);
+    }
     return true;
   }
+
+  /**
+   * Pushes LIMIT down to the request, first checking it against {@code k} in top-k mode.
+   *
+   * @param limit the logical LIMIT node
+   * @return the result of the parent's limit pushdown
+   * @throws ExpressionEvaluationException if the {@code k} option is set and LIMIT exceeds it
+   */
+  @Override
+  public boolean pushDownLimit(LogicalLimit limit) {
+    if (options.containsKey("k")) {
+      int k = Integer.parseInt(options.get("k"));
+      if (limit.getLimit() > k) {
+        throw new ExpressionEvaluationException(
+            String.format("LIMIT %d exceeds k=%d in top-k vector search", limit.getLimit(), k));
+      }
+    }
+    limitPushed = true;
+    return super.pushDownLimit(limit);
+  }
+
+  /**
+   * Accepts only {@code ORDER BY _score DESC} and does not push the sort itself to OpenSearch.
+   *
+   * @param sort the logical sort node
+   * @return {@code true} once every sort item has been validated
+   * @throws ExpressionEvaluationException if an item is not a {@code _score} reference or is not
+   *     descending
+   */
+  @Override
+  public boolean pushDownSort(LogicalSort sort) {
+    // Vector search returns results sorted by _score DESC by default.
+    // Only _score DESC is meaningful; reject all other sort expressions.
+    for (Pair<SortOption, Expression> sortItem : sort.getSortList()) {
+      Expression expr = sortItem.getRight();
+      if (!(expr instanceof ReferenceExpression)
+          || !"_score".equals(((ReferenceExpression) expr).getAttr())) {
+        throw new ExpressionEvaluationException(
+            String.format(
+                "vectorSearch only supports ORDER BY _score DESC; "
+                    + "unsupported sort expression: %s",
+                expr));
+      }
+      if (sortItem.getLeft().getSortOrder() != Sort.SortOrder.DESC) {
+        throw new ExpressionEvaluationException(
+            "vectorSearch only supports ORDER BY _score DESC; _score ASC is not supported");
+      }
+    }
+    // _score DESC is the natural knn order — no need to push the sort itself to OpenSearch.
+    // Preserve the parent's sort.getCount() → limit pushdown contract: SQL always sets count=0,
+    // but PPL or future callers may set a non-zero count to combine sort+limit in one node.
+    if (sort.getCount() != 0) {
+      requestBuilder.pushDownLimit(sort.getCount(), 0);
+      // The pushed count bounds the result set the same way LIMIT does, so record it for the
+      // radial-search check in build().
+      limitPushed = true;
+    }
+    return true;
+  }
+
+  /**
+   * Builds the request after verifying that the pushdowns the chosen options need took place.
+   *
+   * @return the request builder produced by the parent
+   * @throws ExpressionEvaluationException if {@code filter_type} was explicit but no WHERE
+   *     clause was pushed down, or if a radial search (no {@code k}) has no pushed-down limit
+   */
+  @Override
+  public OpenSearchRequestBuilder build() {
+    if (filterTypeExplicit && !filterPushed) {
+      throw new ExpressionEvaluationException("filter_type requires a pushdownable WHERE clause");
+    }
+    boolean isRadial = !options.containsKey("k");
+    if (isRadial && !limitPushed) {
+      throw new ExpressionEvaluationException(
+          "LIMIT is required for radial vector search (max_distance or min_score)."
+ + " Without LIMIT, the result set size is unbounded."); + } + return super.build(); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java new file mode 100644 index 00000000000..43ebbb5eabf --- /dev/null +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchIndexTest.java @@ -0,0 +1,229 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.sql.opensearch.storage; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.LinkedHashMap; +import java.util.Map; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.extension.ExtendWith; +import org.mockito.Mock; +import org.mockito.junit.jupiter.MockitoExtension; +import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.opensearch.client.OpenSearchClient; + +@ExtendWith(MockitoExtension.class) +class VectorSearchIndexTest { + + @Mock private OpenSearchClient client; + + @Mock private Settings settings; + + @Test + void buildKnnQueryJsonTopK() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, 2.0f, 3.0f}, + Map.of("k", "5")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0,3.0],\"k\":5}}}", json); + } + + @Test + void buildKnnQueryJsonRadialMaxDistance() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, 2.0f}, + Map.of("max_distance", "10.5")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[1.0,2.0],\"max_distance\":10.5}}}", json); + } + + @Test + 
void buildKnnQueryJsonRadialMinScore() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {0.5f}, + Map.of("min_score", "0.8")); + + String json = index.buildKnnQueryJson(); + assertEquals("{\"knn\":{\"embedding\":{\"vector\":[0.5],\"min_score\":0.8}}}", json); + } + + @Test + void buildKnnQueryJsonNestedFieldName() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "doc.embedding", + new float[] {1.0f, 2.0f}, + Map.of("k", "10")); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"doc.embedding\""), "Should contain nested field name with dot"); + } + + @Test + void buildKnnQueryJsonMultiElementVector() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, -2.5f, 0.0f, 3.14f, 100.0f}, + Map.of("k", "3")); + + String json = index.buildKnnQueryJson(); + assertTrue( + json.contains("[1.0,-2.5,0.0,3.14,100.0]"), + "Should contain all vector components with correct comma separation"); + } + + @Test + void buildKnnQueryJsonSingleElementVector() { + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {42.0f}, Map.of("k", "1")); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("[42.0]"), "Should contain single-element vector"); + } + + @Test + void buildKnnQueryJsonNumericOptionRenderedUnquoted() { + LinkedHashMap options = new LinkedHashMap<>(); + options.put("k", "5"); + + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, options); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted"); + } + + @Test + void buildKnnQueryJsonNonNumericOptionRenderedQuoted() { + LinkedHashMap options = new LinkedHashMap<>(); + options.put("k", "5"); + 
options.put("method", "hnsw"); + + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, options); + + String json = index.buildKnnQueryJson(); + assertTrue(json.contains("\"method\":\"hnsw\""), "Non-numeric option should be quoted"); + assertTrue(json.contains("\"k\":5"), "Numeric option should be unquoted"); + } + + @Test + void buildKnnQueryJsonWithFilterEmbeds() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f, 2.0f}, + Map.of("k", "5"), + FilterType.EFFICIENT); + + String filterJson = "{\"term\":{\"city\":{\"value\":\"Miami\"}}}"; + String json = index.buildKnnQueryJson(filterJson); + + assertTrue(json.contains("\"filter\""), "Should contain filter field"); + assertTrue(json.contains("\"term\""), "Should contain the filter content"); + assertTrue(json.contains("\"k\":5"), "Should still contain k"); + assertTrue(json.contains("\"vector\":[1.0,2.0]"), "Should contain vector"); + } + + @Test + void buildKnnQueryJsonWithFilterRadial() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f}, + Map.of("max_distance", "10.5"), + FilterType.EFFICIENT); + + String filterJson = "{\"range\":{\"rating\":{\"gte\":4.0}}}"; + String json = index.buildKnnQueryJson(filterJson); + + assertTrue(json.contains("\"max_distance\":10.5"), "Should contain max_distance"); + assertTrue(json.contains("\"filter\""), "Should contain filter"); + } + + @Test + void buildKnnQueryJsonNullFilterProducesBaseJson() { + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f}, + Map.of("k", "5"), + null); + + String json = index.buildKnnQueryJson(null); + String baseJson = index.buildKnnQueryJson(); + + assertEquals(baseJson, json, "null filter should produce same JSON as no-arg version"); + 
assertFalse(json.contains("\"filter\""), "Should not contain filter field"); + } + + @Test + void buildKnnQueryJsonExcludesFilterType() { + LinkedHashMap options = new LinkedHashMap<>(); + options.put("k", "5"); + + VectorSearchIndex index = + new VectorSearchIndex( + client, + settings, + "test-index", + "embedding", + new float[] {1.0f}, + options, + FilterType.EFFICIENT); + + String json = index.buildKnnQueryJson(); + assertFalse(json.contains("filter_type"), "filter_type should not appear in knn JSON"); + assertTrue(json.contains("\"k\":5"), "k should still be present"); + } + + @Test + void isInstanceOfOpenSearchIndex() { + VectorSearchIndex index = + new VectorSearchIndex( + client, settings, "test-index", "embedding", new float[] {1.0f}, Map.of("k", "5")); + assertTrue(index instanceof OpenSearchIndex); + } +} diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java index 71f0bfa80af..76658ae94b9 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java @@ -233,6 +233,65 @@ void testInfiniteMinScoreThrows() { assertTrue(ex.getMessage().contains("must be a finite number")); } + @Test + void testMutualExclusivityKAndMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testMutualExclusivityKAndMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", 
"[1.0, 2.0]", "k=5,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testMutualExclusivityAllThreeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs( + "my-index", "embedding", "[1.0, 2.0]", "k=5,max_distance=10.0,min_score=0.5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("Only one of")); + } + + @Test + void testKTooSmallThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=0"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKTooLargeThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10001"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testKBoundaryValuesAllowed() { + // k=1 should work + VectorSearchTableFunctionImplementation impl1 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=1"); + assertTrue(impl1.applyArguments() instanceof VectorSearchIndex); + + // k=10000 should work + VectorSearchTableFunctionImplementation impl2 = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10000"); + assertTrue(impl2.applyArguments() instanceof VectorSearchIndex); + } + @Test void testNonNamedArgThrows() { FunctionName functionName = FunctionName.of("vectorsearch"); @@ -260,6 +319,114 @@ void testNullArgNameThrows() { assertTrue(ex.getMessage().contains("requires named 
arguments")); } + @Test + void testNaNVectorComponentThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, NaN, 3.0]", "k=5"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testEmptyOptionKeyThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("=value")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testEmptyOptionValueThrows() { + ExpressionEvaluationException ex = + assertThrows( + ExpressionEvaluationException.class, + () -> VectorSearchTableFunctionImplementation.parseOptions("key=")); + assertTrue(ex.getMessage().contains("Malformed option segment")); + } + + @Test + void testNegativeKThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=-1"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("k must be between 1 and 10000")); + } + + @Test + void testNaNMaxDistanceThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "max_distance=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be a finite number")); + } + + @Test + void testNaNMinScoreThrows() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "min_score=NaN"); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments()); + assertTrue(ex.getMessage().contains("must be 
a finite number")); + } + + @Test + void testCaseInsensitiveArgLookup() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("TABLE", DSL.literal("my-index")), + DSL.namedArgument("FIELD", DSL.literal("embedding")), + DSL.namedArgument("VECTOR", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("OPTION", DSL.literal("k=5"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testInvalidFilterTypeRejects() { + FunctionName functionName = FunctionName.of("vectorsearch"); + List args = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0, 2.0]")), + DSL.namedArgument("option", DSL.literal("k=5,filter_type=invalid"))); + VectorSearchTableFunctionImplementation impl = + new VectorSearchTableFunctionImplementation(functionName, args, client, settings); + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, impl::applyArguments); + assertTrue(ex.getMessage().contains("filter_type must be one of")); + } + + @Test + void testFilterTypePostAccepted() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,filter_type=post"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testFilterTypeEfficientAccepted() { + VectorSearchTableFunctionImplementation impl = + createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=5,filter_type=efficient"); + Table table = impl.applyArguments(); + assertTrue(table instanceof VectorSearchIndex); + } + + @Test + void testParseOptionsPreservesFilterTypeValue() { + Map options = + 
VectorSearchTableFunctionImplementation.parseOptions("k=5,filter_type=post"); + assertEquals("post", options.get("filter_type")); + } + private VectorSearchTableFunctionImplementation createImpl() { return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5"); } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java index 77efd0a6d88..4816dd17fdb 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionResolverTest.java @@ -83,4 +83,48 @@ void testWrongArgumentCount() { IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); assertTrue(ex.getMessage().contains("requires 4 arguments")); } + + @Test + void testTooManyArguments() { + VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = + List.of( + DSL.namedArgument("table", DSL.literal("my-index")), + DSL.namedArgument("field", DSL.literal("embedding")), + DSL.namedArgument("vector", DSL.literal("[1.0]")), + DSL.namedArgument("option", DSL.literal("k=5")), + DSL.namedArgument("extra", DSL.literal("unexpected"))); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } + + @Test + void testZeroArguments() { + 
VectorSearchTableFunctionResolver resolver = + new VectorSearchTableFunctionResolver(client, settings); + FunctionName functionName = FunctionName.of("vectorsearch"); + List expressions = List.of(); + FunctionSignature functionSignature = + new FunctionSignature( + functionName, expressions.stream().map(Expression::type).collect(Collectors.toList())); + + Pair resolution = resolver.resolve(functionSignature); + FunctionBuilder builder = resolution.getValue(); + + IllegalArgumentException ex = + assertThrows( + IllegalArgumentException.class, () -> builder.apply(functionProperties, expressions)); + assertTrue(ex.getMessage().contains("requires 4 arguments")); + } } diff --git a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java index 2df785b41a0..381fcf646d8 100644 --- a/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java +++ b/opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java @@ -6,21 +6,30 @@ package org.opensearch.sql.opensearch.storage.scan; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.mockito.Mockito.mock; import static org.opensearch.sql.data.type.ExprCoreType.STRING; import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; import org.junit.jupiter.api.Test; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilder; import org.opensearch.index.query.WrapperQueryBuilder; import org.opensearch.sql.common.setting.Settings; +import org.opensearch.sql.data.type.ExprCoreType; +import 
org.opensearch.sql.exception.ExpressionEvaluationException; import org.opensearch.sql.expression.DSL; import org.opensearch.sql.expression.ReferenceExpression; import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory; import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder; +import org.opensearch.sql.opensearch.storage.FilterType; import org.opensearch.sql.planner.logical.LogicalFilter; +import org.opensearch.sql.planner.logical.LogicalLimit; import org.opensearch.sql.planner.logical.LogicalValues; class VectorSearchQueryBuilderTest { @@ -30,7 +39,7 @@ void knnQuerySetAsScoringQuery() { var requestBuilder = createRequestBuilder(); var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); - new VectorSearchQueryBuilder(requestBuilder, knnQuery); + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); QueryBuilder query = requestBuilder.getSourceBuilder().query(); assertTrue( @@ -42,7 +51,7 @@ void knnQuerySetAsScoringQuery() { void pushDownFilterKeepsKnnInScoringContext() { var requestBuilder = createRequestBuilder(); var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); - var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); // Simulate WHERE name = 'John' var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); @@ -62,6 +71,449 @@ void pushDownFilterKeepsKnnInScoringContext() { "must clause should contain the original knn WrapperQueryBuilder"); } + @Test + void pushDownLimitWithinKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 3, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT 
within k should succeed"); + } + + @Test + void pushDownLimitExceedingKThrows() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 10, 0); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownLimit(limit)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownLimitEqualToKSucceeds() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 5, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "LIMIT equal to k should succeed"); + } + + @Test + void pushDownLimitRadialModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean pushed = builder.pushDownLimit(limit); + assertTrue(pushed, "Radial mode should not restrict LIMIT"); + } + + @Test + void pushDownLimitMinScoreModeNoRestriction() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("min_score", "0.5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 100, 0); + + boolean pushed = 
builder.pushDownLimit(limit); + assertTrue(pushed, "min_score mode should not restrict LIMIT"); + } + + @Test + void pushDownSortScoreDescAccepted() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC should be accepted"); + } + + @Test + void pushDownSortPreservesSortCountAsLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "10")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + // LogicalSort with count=7 simulates a sort+limit combined node (PPL path) + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + 7, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC with count should be accepted"); + assertEquals( + 7, + requestBuilder.getMaxResponseSize(), + "sort.getCount() should be pushed down as request size"); + } + + @Test + void pushDownSortNonScoreFieldRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new 
LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortMultipleExpressionsRejectsNonScore() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)), + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + + @Test + void pushDownSortScoreAscRejected() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("_score", 
ExprCoreType.FLOAT)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("_score ASC is not supported")); + } + + @Test + void pushDownFilterCompoundPredicateSurvives() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + // Simulate WHERE name = 'John' AND age > 30 + var condition = + DSL.and( + DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")), + DSL.greater(new ReferenceExpression("age", ExprCoreType.INTEGER), DSL.literal(30))); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter with compound predicate should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery"); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)"); + assertEquals(1, boolQuery.filter().size(), "compound WHERE should be in filter (non-scoring)"); + } + + @Test + void pushDownFilterEfficientPlacesInsideKnn() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + // Callback simulates VectorSearchIndex rebuilding knn with filter + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{\"filter\":\"embedded\"}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var condition = DSL.equal(new ReferenceExpression("city", STRING), DSL.literal("Miami")); + var dummyChild 
= new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed, "pushDownFilter should succeed"); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue( + resultQuery instanceof WrapperQueryBuilder, + "Efficient filter should produce a WrapperQueryBuilder (rebuilt knn), not BoolQuery"); + } + + @Test + void pushDownFilterExplicitPostProducesBool() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + var filter = new LogicalFilter(dummyChild, condition); + + boolean pushed = builder.pushDownFilter(filter); + + assertTrue(pushed); + QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query(); + assertTrue(resultQuery instanceof BoolQueryBuilder); + BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery; + assertEquals(1, boolQuery.must().size()); + assertEquals(1, boolQuery.filter().size()); + } + + // ── Build-time validation ──────────────────────────────────────────── + + @Test + void buildRejectsExplicitFilterTypePostWithoutWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("filter_type requires a pushdownable WHERE clause")); + } + + @Test + void buildRejectsExplicitFilterTypeEfficientWithoutWhere() { + var requestBuilder = createRequestBuilder(); 
+ var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{\"filter\":\"embedded\"}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("filter_type requires a pushdownable WHERE clause")); + } + + @Test + void buildSucceedsWithNoFilterTypeAndNoWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + @Test + void buildSucceedsWithFilterTypeAndPushedWhere() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, knnQuery, Map.of("k", "5"), FilterType.POST, true, null); + + var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John")); + var dummyChild = new LogicalValues(Collections.emptyList()); + builder.pushDownFilter(new LogicalFilter(dummyChild, condition)); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + // ── Radial without LIMIT rejection ───────────────────────────────── + + @Test + void buildRejectsRadialMaxDistanceWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("LIMIT is required for radial vector 
search")); + } + + @Test + void buildRejectsRadialMinScoreWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("min_score", "0.5")); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, builder::build); + assertTrue(ex.getMessage().contains("LIMIT is required for radial vector search")); + } + + @Test + void buildSucceedsRadialWithLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("max_distance", "10.0")); + + var dummyChild = new LogicalValues(Collections.emptyList()); + builder.pushDownLimit(new LogicalLimit(dummyChild, 50, 0)); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + @Test + void buildSucceedsTopKWithoutLimit() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery, Map.of("k", "5")); + + OpenSearchRequestBuilder result = builder.build(); + assertNotNull(result); + } + + // ── Regression: LIMIT and sort invariants under efficient mode ────── + + @Test + void pushDownLimitExceedingKThrowsUnderEfficientMode() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var limit = new LogicalLimit(dummyChild, 10, 0); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> 
builder.pushDownLimit(limit)); + assertTrue(ex.getMessage().contains("LIMIT 10 exceeds k=5")); + } + + @Test + void pushDownSortScoreDescAcceptedUnderEfficientMode() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_DESC, + new ReferenceExpression("_score", ExprCoreType.FLOAT)))); + + boolean pushed = builder.pushDownSort(sort); + assertTrue(pushed, "ORDER BY _score DESC should be accepted under efficient mode"); + } + + @Test + void pushDownSortNonScoreRejectedUnderEfficientMode() { + var requestBuilder = createRequestBuilder(); + var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}"); + Function rebuildWithFilter = + whereQuery -> new WrapperQueryBuilder("{\"knn\":{}}"); + var builder = + new VectorSearchQueryBuilder( + requestBuilder, + knnQuery, + Map.of("k", "5"), + FilterType.EFFICIENT, + true, + rebuildWithFilter); + + var dummyChild = new LogicalValues(Collections.emptyList()); + var sort = + new org.opensearch.sql.planner.logical.LogicalSort( + dummyChild, + List.of( + org.apache.commons.lang3.tuple.ImmutablePair.of( + org.opensearch.sql.ast.tree.Sort.SortOption.DEFAULT_ASC, + new ReferenceExpression("name", STRING)))); + + ExpressionEvaluationException ex = + assertThrows(ExpressionEvaluationException.class, () -> builder.pushDownSort(sort)); + assertTrue(ex.getMessage().contains("unsupported sort expression")); + } + private OpenSearchRequestBuilder createRequestBuilder() { return new 
OpenSearchRequestBuilder( mock(OpenSearchExprValueFactory.class), 10000, mock(Settings.class));