Skip to content

Commit c183c01

Browse files
committed
Harden input validation and add size=k default for top-k mode
- parseOptions: reject malformed segments and duplicate keys
- parseVector: wrap errors in ExpressionEvaluationException, reject non-finite floats (Infinity, NaN)
- VectorSearchIndex: default requestedTotalSize to k via pushDownLimitToRequestTotal so queries without LIMIT return k results
- Add 5 new tests: malformed option, duplicate key, empty vector, malformed vector component, non-finite vector component

Signed-off-by: Eric Wei <mengwei.eric@gmail.com>
1 parent 9a4955a commit c183c01

3 files changed

Lines changed: 79 additions & 4 deletions

File tree

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchIndex.java

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,13 @@ public TableScanBuilder createScanBuilder() {
5353
var queryBuilder = new VectorSearchQueryBuilder(requestBuilder, buildKnnQuery());
5454
requestBuilder.pushDownTrackedScore(true);
5555

56+
// Top-k mode: default size to k so queries without LIMIT return k results
57+
// instead of falling into the generic large-scan path.
58+
// LIMIT pushdown will further reduce this if present.
59+
if (options.containsKey("k")) {
60+
requestBuilder.pushDownLimitToRequestTotal(Integer.parseInt(options.get("k")), 0);
61+
}
62+
5663
Function<OpenSearchRequestBuilder, OpenSearchIndexScan> createScanOperator =
5764
rb ->
5865
new OpenSearchIndexScan(

opensearch/src/main/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementation.java

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -104,21 +104,44 @@ public Table applyArguments() {
104104

105105
private float[] parseVector(String vectorLiteral) {
106106
String cleaned = vectorLiteral.replaceAll("[\\[\\]]", "").trim();
107+
if (cleaned.isEmpty()) {
108+
throw new ExpressionEvaluationException("Vector literal must not be empty");
109+
}
107110
String[] parts = cleaned.split(",");
108111
float[] vector = new float[parts.length];
109112
for (int i = 0; i < parts.length; i++) {
110-
vector[i] = Float.parseFloat(parts[i].trim());
113+
try {
114+
vector[i] = Float.parseFloat(parts[i].trim());
115+
} catch (NumberFormatException e) {
116+
throw new ExpressionEvaluationException(
117+
String.format("Invalid vector component '%s': must be a number", parts[i].trim()));
118+
}
119+
if (!Float.isFinite(vector[i])) {
120+
throw new ExpressionEvaluationException(
121+
String.format(
122+
"Invalid vector component '%s': must be a finite number", parts[i].trim()));
123+
}
111124
}
112125
return vector;
113126
}
114127

115128
static Map<String, String> parseOptions(String optionStr) {
116129
Map<String, String> options = new LinkedHashMap<>();
117130
for (String pair : optionStr.split(",")) {
118-
String[] kv = pair.trim().split("=", 2);
119-
if (kv.length == 2) {
120-
options.put(kv[0].trim(), kv[1].trim());
131+
String trimmed = pair.trim();
132+
if (trimmed.isEmpty()) {
133+
continue;
134+
}
135+
String[] kv = trimmed.split("=", 2);
136+
if (kv.length != 2 || kv[0].trim().isEmpty() || kv[1].trim().isEmpty()) {
137+
throw new ExpressionEvaluationException(
138+
String.format("Malformed option segment '%s': expected key=value", trimmed));
139+
}
140+
String key = kv[0].trim();
141+
if (options.containsKey(key)) {
142+
throw new ExpressionEvaluationException(String.format("Duplicate option key '%s'", key));
121143
}
144+
options.put(key, kv[1].trim());
122145
}
123146
return options;
124147
}

opensearch/src/test/java/org/opensearch/sql/opensearch/storage/VectorSearchTableFunctionImplementationTest.java

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -120,6 +120,51 @@ void testParseOptionsMultiple() {
120120
assertEquals("10.0", opts.get("max_distance"));
121121
}
122122

123+
@Test
124+
void testMalformedOptionSegmentThrows() {
125+
ExpressionEvaluationException ex =
126+
assertThrows(
127+
ExpressionEvaluationException.class,
128+
() -> VectorSearchTableFunctionImplementation.parseOptions("k=5,badoption"));
129+
assertTrue(ex.getMessage().contains("Malformed option segment"));
130+
}
131+
132+
@Test
133+
void testDuplicateOptionKeyThrows() {
134+
ExpressionEvaluationException ex =
135+
assertThrows(
136+
ExpressionEvaluationException.class,
137+
() -> VectorSearchTableFunctionImplementation.parseOptions("k=5,k=10"));
138+
assertTrue(ex.getMessage().contains("Duplicate option key"));
139+
}
140+
141+
@Test
142+
void testEmptyVectorThrows() {
143+
VectorSearchTableFunctionImplementation impl =
144+
createImplWithArgs("my-index", "embedding", "[]", "k=5");
145+
ExpressionEvaluationException ex =
146+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
147+
assertTrue(ex.getMessage().contains("must not be empty"));
148+
}
149+
150+
@Test
151+
void testMalformedVectorComponentThrows() {
152+
VectorSearchTableFunctionImplementation impl =
153+
createImplWithArgs("my-index", "embedding", "[1.0, abc, 3.0]", "k=5");
154+
ExpressionEvaluationException ex =
155+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
156+
assertTrue(ex.getMessage().contains("Invalid vector component"));
157+
}
158+
159+
@Test
160+
void testNonFiniteVectorComponentThrows() {
161+
VectorSearchTableFunctionImplementation impl =
162+
createImplWithArgs("my-index", "embedding", "[1.0, Infinity, 3.0]", "k=5");
163+
ExpressionEvaluationException ex =
164+
assertThrows(ExpressionEvaluationException.class, () -> impl.applyArguments());
165+
assertTrue(ex.getMessage().contains("must be a finite number"));
166+
}
167+
123168
@Test
124169
void testMissingArgumentThrows() {
125170
FunctionName functionName = FunctionName.of("vectorsearch");

0 commit comments

Comments
 (0)