Skip to content

Commit cba9f44

Browse files
committed
Skip validations for bin-on-timestamps (1799/2027)
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 1ffd721 commit cba9f44

5 files changed

Lines changed: 131 additions & 14 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/validate/PplRelToSqlRelShuttle.java renamed to core/src/main/java/org/opensearch/sql/calcite/validate/shuttles/PplRelToSqlRelShuttle.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
* SPDX-License-Identifier: Apache-2.0
44
*/
55

6-
package org.opensearch.sql.calcite.validate;
6+
package org.opensearch.sql.calcite.validate.shuttles;
77

88
import java.math.BigDecimal;
99
import org.apache.calcite.avatica.util.TimeUnit;
Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,72 @@
1+
/*
2+
* Copyright OpenSearch Contributors
3+
* SPDX-License-Identifier: Apache-2.0
4+
*/
5+
6+
package org.opensearch.sql.calcite.validate.shuttles;
7+
8+
import java.util.List;
9+
import java.util.function.Predicate;
10+
import org.apache.calcite.rel.RelNode;
11+
import org.apache.calcite.rel.RelShuttleImpl;
12+
import org.apache.calcite.rex.RexCall;
13+
import org.apache.calcite.rex.RexNode;
14+
import org.apache.calcite.rex.RexShuttle;
15+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
16+
17+
/**
18+
* A RelShuttle that detects if validation should be skipped for certain operations. Currently, it
19+
* detects the following patterns:
20+
*
21+
* <ul>
22+
* <li>binning on datetime types, which is only executable after pushdown.
23+
* </ul>
24+
*/
25+
public class SkipRelValidationShuttle extends RelShuttleImpl {
26+
private boolean shouldSkip = false;
27+
private final RexShuttle rexShuttle;
28+
29+
/** Predicates about patterns of calls that should not be validated. */
30+
public static final List<Predicate<RexCall>> SKIP_CALLS;
31+
32+
static {
33+
Predicate<RexCall> binOnTimestamp =
34+
call -> {
35+
if ("WIDTH_BUCKET".equalsIgnoreCase(call.getOperator().getName())) {
36+
if (!call.getOperands().isEmpty()) {
37+
RexNode firstOperand = call.getOperands().get(0);
38+
return OpenSearchTypeFactory.isDatetime(firstOperand.getType());
39+
}
40+
}
41+
return false;
42+
};
43+
SKIP_CALLS = List.of(binOnTimestamp);
44+
}
45+
46+
public SkipRelValidationShuttle() {
47+
this.rexShuttle =
48+
new RexShuttle() {
49+
@Override
50+
public RexNode visitCall(RexCall call) {
51+
for (Predicate<RexCall> skipCall : SKIP_CALLS) {
52+
if (skipCall.test(call)) {
53+
shouldSkip = true;
54+
return call;
55+
}
56+
}
57+
return super.visitCall(call);
58+
}
59+
};
60+
}
61+
62+
/** Returns true if validation should be skipped based on detected conditions. */
63+
public boolean shouldSkipValidation() {
64+
return shouldSkip;
65+
}
66+
67+
@Override
68+
protected RelNode visitChild(RelNode parent, int i, RelNode child) {
69+
RelNode newChild = super.visitChild(parent, i, child);
70+
return newChild.accept(rexShuttle);
71+
}
72+
}

core/src/main/java/org/opensearch/sql/executor/QueryService.java

Lines changed: 54 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,12 @@
55

66
package org.opensearch.sql.executor;
77

8+
import java.util.ArrayList;
89
import java.util.Collections;
10+
import java.util.HashSet;
911
import java.util.List;
1012
import java.util.Optional;
13+
import java.util.stream.Collectors;
1114
import javax.annotation.Nullable;
1215
import lombok.AllArgsConstructor;
1316
import lombok.Getter;
@@ -35,10 +38,15 @@
3538
import org.apache.calcite.sql.SqlBasicCall;
3639
import org.apache.calcite.sql.SqlCall;
3740
import org.apache.calcite.sql.SqlIdentifier;
41+
import org.apache.calcite.sql.SqlKind;
42+
import org.apache.calcite.sql.SqlLiteral;
3843
import org.apache.calcite.sql.SqlNode;
44+
import org.apache.calcite.sql.SqlNodeList;
45+
import org.apache.calcite.sql.SqlSelect;
3946
import org.apache.calcite.sql.fun.SqlCountAggFunction;
4047
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
4148
import org.apache.calcite.sql.parser.SqlParser;
49+
import org.apache.calcite.sql.parser.SqlParserPos;
4250
import org.apache.calcite.sql.util.SqlShuttle;
4351
import org.apache.calcite.sql.validate.SqlValidator;
4452
import org.apache.calcite.sql2rel.SqlToRelConverter;
@@ -58,7 +66,8 @@
5866
import org.opensearch.sql.calcite.validate.OpenSearchSparkSqlDialect;
5967
import org.opensearch.sql.calcite.validate.PplConvertletTable;
6068
import org.opensearch.sql.calcite.validate.PplRelToSqlNodeConverter;
61-
import org.opensearch.sql.calcite.validate.PplRelToSqlRelShuttle;
69+
import org.opensearch.sql.calcite.validate.shuttles.PplRelToSqlRelShuttle;
70+
import org.opensearch.sql.calcite.validate.shuttles.SkipRelValidationShuttle;
6271
import org.opensearch.sql.common.response.ResponseListener;
6372
import org.opensearch.sql.common.setting.Settings;
6473
import org.opensearch.sql.datasource.DataSourceService;
@@ -306,6 +315,16 @@ public LogicalPlan analyze(UnresolvedPlan plan, QueryType queryType) {
306315
* @return the validated (and potentially modified) relation node
307316
*/
308317
private RelNode validate(RelNode relNode, CalcitePlanContext context) {
318+
SkipRelValidationShuttle skipShuttle = new SkipRelValidationShuttle();
319+
relNode.accept(skipShuttle);
320+
// WARNING: When a skip pattern is detected (e.g., WIDTH_BUCKET on datetime types),
321+
// we bypass the entire validation pipeline, skipping potentially useful transformation relying
322+
// on rewriting SQL node
323+
// TODO: Make incompatible operations like bin-on-timestamp a validatable UDFs so that they can
324+
// be still be converted to SqlNode and back to RelNode
325+
if (skipShuttle.shouldSkipValidation()) {
326+
return relNode;
327+
}
309328
// Fix interval literals before conversion to SQL
310329
RelNode sqlRelNode = relNode.accept(new PplRelToSqlRelShuttle(context.rexBuilder, true));
311330

@@ -346,7 +365,6 @@ public SqlNode visit(SqlIdentifier id) {
346365
return super.visit(call);
347366
}
348367
});
349-
350368
SqlValidator validator = context.getValidator();
351369
if (rewritten != null) {
352370
try {
@@ -361,6 +379,9 @@ public SqlNode visit(SqlIdentifier id) {
361379
return relNode;
362380
}
363381

382+
// if (rewritten instanceof SqlSelect select) {
383+
// rewritten = rewriteGroupBy(select);
384+
// }
364385
// Convert the validated SqlNode back to RelNode
365386
RelOptTable.ViewExpander viewExpander = context.config.getViewExpander();
366387
RelOptCluster cluster = context.relBuilder.getCluster();
@@ -463,4 +484,35 @@ private static RelNode convertToCalcitePlan(RelNode osPlan) {
463484
}
464485
return calcitePlan;
465486
}
487+
488+
private SqlNode rewriteGroupBy(SqlSelect root) {
489+
if (root.getGroup() == null) {
490+
return root;
491+
}
492+
List<SqlNode> selectList = root.getSelectList().getList();
493+
List<SqlNode> groupByList = root.getGroup().getList();
494+
List<SqlNode> unwrappedGroupByList = groupByList.stream().map(QueryService::unwrapAs).toList();
495+
List<SqlNode> unwrappedSelectList = selectList.stream().map(QueryService::unwrapAs).toList();
496+
if (new HashSet<>(unwrappedSelectList).containsAll(unwrappedGroupByList)) {
497+
List<Integer> ordinals =
498+
unwrappedGroupByList.stream().map(unwrappedSelectList::indexOf).toList();
499+
List<SqlNode> groupByOrdinals =
500+
ordinals.stream()
501+
.map(
502+
ordinal ->
503+
(SqlNode)
504+
SqlLiteral.createExactNumeric(
505+
Integer.toString(ordinal + 1), SqlParserPos.ZERO))
506+
.collect(Collectors.toCollection(ArrayList::new));
507+
root.setGroupBy(SqlNodeList.of(root.getGroup().getParserPosition(), groupByOrdinals));
508+
}
509+
return root;
510+
}
511+
512+
private static SqlNode unwrapAs(SqlNode node) {
513+
if (node.getKind() == SqlKind.AS && node instanceof SqlCall) {
514+
return ((SqlCall) node).getOperandList().get(0);
515+
}
516+
return node;
517+
}
466518
}

core/src/main/java/org/opensearch/sql/expression/function/udf/binning/WidthBucketFunction.java

Lines changed: 2 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,8 +16,7 @@
1616
import org.apache.calcite.rex.RexCall;
1717
import org.apache.calcite.sql.type.SqlReturnTypeInference;
1818
import org.apache.calcite.sql.type.SqlTypeName;
19-
import org.opensearch.sql.calcite.type.ExprSqlType;
20-
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;
19+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
2120
import org.opensearch.sql.calcite.utils.PPLOperandTypes;
2221
import org.opensearch.sql.calcite.utils.binning.BinConstants;
2322
import org.opensearch.sql.expression.function.ImplementorUDF;
@@ -51,19 +50,13 @@ public SqlReturnTypeInference getReturnTypeInference() {
5150
return (opBinding) -> {
5251
RelDataTypeFactory typeFactory = opBinding.getTypeFactory();
5352
RelDataType arg0Type = opBinding.getOperandType(0);
54-
return dateRelatedType(arg0Type)
53+
return OpenSearchTypeFactory.isDatetime(arg0Type)
5554
? arg0Type
5655
: typeFactory.createTypeWithNullability(
5756
typeFactory.createSqlType(SqlTypeName.VARCHAR, 2000), true);
5857
};
5958
}
6059

61-
public static boolean dateRelatedType(RelDataType type) {
62-
return type instanceof ExprSqlType exprSqlType
63-
&& List.of(ExprUDT.EXPR_DATE, ExprUDT.EXPR_TIME, ExprUDT.EXPR_TIMESTAMP)
64-
.contains(exprSqlType.getUdt());
65-
}
66-
6760
@Override
6861
public UDFOperandMetadata getOperandMetadata() {
6962
return PPLOperandTypes.WIDTH_BUCKET_OPERAND;

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/rules/AggregateIndexScanRule.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -31,9 +31,9 @@
3131
import org.immutables.value.Value;
3232
import org.opensearch.sql.ast.expression.Argument;
3333
import org.opensearch.sql.calcite.plan.OpenSearchRuleConfig;
34+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
3435
import org.opensearch.sql.calcite.utils.PlanUtils;
3536
import org.opensearch.sql.expression.function.BuiltinFunctionName;
36-
import org.opensearch.sql.expression.function.udf.binning.WidthBucketFunction;
3737
import org.opensearch.sql.opensearch.storage.scan.AbstractCalciteIndexScan;
3838
import org.opensearch.sql.opensearch.storage.scan.CalciteLogicalIndexScan;
3939

@@ -306,7 +306,7 @@ static boolean containsWidthBucketFuncOnDate(LogicalProject project) {
306306
expr ->
307307
expr instanceof RexCall rexCall
308308
&& rexCall.getOperator().equals(WIDTH_BUCKET)
309-
&& WidthBucketFunction.dateRelatedType(
309+
&& OpenSearchTypeFactory.isDatetime(
310310
rexCall.getOperands().getFirst().getType()));
311311
}
312312
}

0 commit comments

Comments
 (0)