Skip to content

Commit 63253e3

Browse files
committed
Address review feedback: add validation guards and pushDownFilter test
- Add pushDownFilter() unit test asserting knn stays in bool.must (scoring) and WHERE predicate goes to bool.filter (non-scoring) - Add option key allowlist (k, max_distance, min_score) to reject unknown/unsupported keys before they reach DSL generation - Add field name validation to reject characters that could corrupt the WrapperQueryBuilder JSON (allows alphanumeric, dots, underscores, hyphens) - Add named-arg type guard to reject non-NamedArgumentExpression args early with a clear error message Signed-off-by: Eric Wei <mengwei.eric@gmail.com>
1 parent 5b421fe commit 63253e3

3 files changed

Lines changed: 109 additions & 14 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,8 @@
1313
import java.util.LinkedHashMap;
1414
import java.util.List;
1515
import java.util.Map;
16+
import java.util.Set;
17+
import java.util.regex.Pattern;
1618
import java.util.stream.Collectors;
1719
import org.opensearch.sql.common.setting.Settings;
1820
import org.opensearch.sql.data.model.ExprValue;
@@ -31,6 +33,15 @@
3133
public class VectorSearchTableFunctionImplementation extends FunctionExpression
3234
implements TableFunctionImplementation {
3335

36+
/** P0 allowed option keys. Rejects unknown/future keys to prevent unvalidated DSL injection. */
37+
static final Set<String> ALLOWED_OPTION_KEYS = Set.of("k", "max_distance", "min_score");
38+
39+
/**
40+
* Field names must be safe for JSON interpolation: alphanumeric, dots (nested), underscores,
41+
* hyphens. Rejects characters that could corrupt the WrapperQueryBuilder JSON.
42+
*/
43+
private static final Pattern SAFE_FIELD_NAME = Pattern.compile("^[a-zA-Z0-9._\\-]+$");
44+
3445
private final FunctionName functionName;
3546
private final List<Expression> arguments;
3647
private final OpenSearchClient client;
@@ -75,8 +86,10 @@ public String toString() {
7586

7687
@Override
7788
public Table applyArguments() {
89+
validateNamedArgs();
7890
String tableName = getArgumentValue(TABLE);
7991
String fieldName = getArgumentValue(FIELD);
92+
validateFieldName(fieldName);
8093
String vectorLiteral = getArgumentValue(VECTOR);
8194
String optionStr = getArgumentValue(OPTION);
8295

@@ -108,7 +121,40 @@ static Map<String, String> parseOptions(String optionStr) {
108121
return options;
109122
}
110123

124+
/** Reject non-named arguments early. vectorSearch() requires named args (key=value). */
125+
private void validateNamedArgs() {
126+
for (Expression arg : arguments) {
127+
if (!(arg instanceof NamedArgumentExpression)) {
128+
throw new ExpressionEvaluationException(
129+
"vectorSearch() requires named arguments (e.g., table='index'), "
130+
+ "but received: "
131+
+ arg.getClass().getSimpleName());
132+
}
133+
}
134+
}
135+
136+
/**
137+
* Reject field names with characters that could corrupt the WrapperQueryBuilder JSON. Allows
138+
* alphanumeric, dots (nested fields), underscores, and hyphens.
139+
*/
140+
private void validateFieldName(String fieldName) {
141+
if (!SAFE_FIELD_NAME.matcher(fieldName).matches()) {
142+
throw new ExpressionEvaluationException(
143+
String.format(
144+
"Invalid field name '%s': must contain only alphanumeric characters,"
145+
+ " dots, underscores, or hyphens",
146+
fieldName));
147+
}
148+
}
149+
111150
private void validateOptions(Map<String, String> options) {
151+
// Reject unknown option keys — only P0 keys are allowed
152+
for (String key : options.keySet()) {
153+
if (!ALLOWED_OPTION_KEYS.contains(key)) {
154+
throw new ExpressionEvaluationException(
155+
String.format("Unknown option key '%s'. Supported keys: %s", key, ALLOWED_OPTION_KEYS));
156+
}
157+
}
112158
boolean hasK = options.containsKey("k");
113159
boolean hasMaxDistance = options.containsKey("max_distance");
114160
boolean hasMinScore = options.containsKey("min_score");

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,13 @@ void testApplyArgumentsWithUnbracketedVector() {
7878
}
7979

8080
@Test
81-
void testApplyArgumentsWithComplexOptions() {
81+
void testUnknownOptionKeyThrows() {
8282
VectorSearchTableFunctionImplementation impl =
8383
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "k=10,method.ef_search=100");
84-
Table table = impl.applyArguments();
85-
assertTrue(table instanceof VectorSearchIndex);
84+
ExpressionEvaluationException ex =
85+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
86+
assertTrue(ex.getMessage().contains("Unknown option key"));
87+
assertTrue(ex.getMessage().contains("method.ef_search"));
8688
}
8789

8890
@Test
@@ -104,18 +106,18 @@ void testApplyArgumentsWithMinScore() {
104106
@Test
105107
void testMissingSearchModeOptionThrows() {
106108
VectorSearchTableFunctionImplementation impl =
107-
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "method.ef_search=100");
109+
createImplWithArgs("my-index", "embedding", "[1.0, 2.0]", "not_a_key=100");
108110
ExpressionEvaluationException ex =
109111
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
110-
assertTrue(ex.getMessage().contains("one of k, max_distance, or min_score"));
112+
assertTrue(ex.getMessage().contains("Unknown option key"));
111113
}
112114

113115
@Test
114116
void testParseOptionsMultiple() {
115117
Map<String, String> opts =
116-
VectorSearchTableFunctionImplementation.parseOptions("k=5,method.ef_search=100");
118+
VectorSearchTableFunctionImplementation.parseOptions("k=5,max_distance=10.0");
117119
assertEquals("5", opts.get("k"));
118-
assertEquals("100", opts.get("method.ef_search"));
120+
assertEquals("10.0", opts.get("max_distance"));
119121
}
120122

121123
@Test
@@ -133,6 +135,34 @@ void testMissingArgumentThrows() {
133135
assertEquals("Missing required argument: option", ex.getMessage());
134136
}
135137

138+
@Test
139+
void testInvalidFieldNameThrows() {
140+
VectorSearchTableFunctionImplementation impl =
141+
createImplWithArgs("my-index", "field\"injection", "[1.0, 2.0]", "k=5");
142+
ExpressionEvaluationException ex =
143+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
144+
assertTrue(ex.getMessage().contains("Invalid field name"));
145+
}
146+
147+
@Test
148+
void testNestedFieldNameAllowed() {
149+
VectorSearchTableFunctionImplementation impl =
150+
createImplWithArgs("my-index", "doc.embedding", "[1.0, 2.0]", "k=5");
151+
Table table = impl.applyArguments();
152+
assertTrue(table instanceof VectorSearchIndex);
153+
}
154+
155+
@Test
156+
void testNonNamedArgThrows() {
157+
FunctionName functionName = FunctionName.of("vectorsearch");
158+
List<Expression> args = List.of(DSL.literal("my-index"));
159+
VectorSearchTableFunctionImplementation impl =
160+
new VectorSearchTableFunctionImplementation(functionName, args, client, settings);
161+
ExpressionEvaluationException ex =
162+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
163+
assertTrue(ex.getMessage().contains("requires named arguments"));
164+
}
165+
136166
private VectorSearchTableFunctionImplementation createImpl() {
137167
return createImplWithArgs("my-index", "embedding", "[1.0, 2.0, 3.0]", "k=5");
138168
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/scan/VectorSearchQueryBuilderTest.java

Lines changed: 26 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,30 @@
55

66
package org.opensearch.sql.opensearch.storage.scan;
77

8+
import static org.junit.jupiter.api.Assertions.assertEquals;
89
import static org.junit.jupiter.api.Assertions.assertTrue;
910
import static org.mockito.Mockito.mock;
11+
import static org.opensearch.sql.data.type.ExprCoreType.STRING;
1012

13+
import java.util.Collections;
1114
import org.junit.jupiter.api.Test;
15+
import org.opensearch.index.query.BoolQueryBuilder;
1216
import org.opensearch.index.query.QueryBuilder;
1317
import org.opensearch.index.query.WrapperQueryBuilder;
1418
import org.opensearch.sql.common.setting.Settings;
19+
import org.opensearch.sql.expression.DSL;
20+
import org.opensearch.sql.expression.ReferenceExpression;
1521
import org.opensearch.sql.opensearch.data.value.OpenSearchExprValueFactory;
1622
import org.opensearch.sql.opensearch.request.OpenSearchRequestBuilder;
23+
import org.opensearch.sql.planner.logical.LogicalFilter;
24+
import org.opensearch.sql.planner.logical.LogicalValues;
1725

1826
class VectorSearchQueryBuilderTest {
1927

2028
@Test
2129
void knnQuerySetAsScoringQuery() {
2230
var requestBuilder = createRequestBuilder();
23-
var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ==");
31+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
2432

2533
new VectorSearchQueryBuilder(requestBuilder, knnQuery);
2634

@@ -31,16 +39,27 @@ void knnQuerySetAsScoringQuery() {
3139
}
3240

3341
@Test
34-
void knnQueryNotWrappedInFilterWhenNoWhere() {
42+
void pushDownFilterKeepsKnnInScoringContext() {
3543
var requestBuilder = createRequestBuilder();
36-
var knnQuery = new WrapperQueryBuilder("eyJrbm4iOnt9fQ==");
44+
var knnQuery = new WrapperQueryBuilder("{\"knn\":{}}");
45+
var builder = new VectorSearchQueryBuilder(requestBuilder, knnQuery);
3746

38-
new VectorSearchQueryBuilder(requestBuilder, knnQuery);
47+
// Simulate WHERE name = 'John'
48+
var condition = DSL.equal(new ReferenceExpression("name", STRING), DSL.literal("John"));
49+
var dummyChild = new LogicalValues(Collections.emptyList());
50+
var filter = new LogicalFilter(dummyChild, condition);
3951

40-
QueryBuilder query = requestBuilder.getSourceBuilder().query();
52+
boolean pushed = builder.pushDownFilter(filter);
53+
54+
assertTrue(pushed, "pushDownFilter should succeed");
55+
QueryBuilder resultQuery = requestBuilder.getSourceBuilder().query();
56+
assertTrue(resultQuery instanceof BoolQueryBuilder, "Result should be a BoolQuery");
57+
BoolQueryBuilder boolQuery = (BoolQueryBuilder) resultQuery;
58+
assertEquals(1, boolQuery.must().size(), "knn query should be in must (scoring context)");
59+
assertEquals(1, boolQuery.filter().size(), "WHERE predicate should be in filter (non-scoring)");
4160
assertTrue(
42-
query instanceof WrapperQueryBuilder,
43-
"Without WHERE clause, knn query should NOT be wrapped in bool.filter");
61+
boolQuery.must().get(0) instanceof WrapperQueryBuilder,
62+
"must clause should contain the original knn WrapperQueryBuilder");
4463
}
4564

4665
private OpenSearchRequestBuilder createRequestBuilder() {

0 commit comments

Comments
 (0)