55
66package org .opensearch .sql .executor ;
77
8+ import java .util .ArrayList ;
89import java .util .Collections ;
10+ import java .util .HashSet ;
911import java .util .List ;
1012import java .util .Optional ;
13+ import java .util .stream .Collectors ;
1114import javax .annotation .Nullable ;
1215import lombok .AllArgsConstructor ;
1316import lombok .Getter ;
3538import org .apache .calcite .sql .SqlBasicCall ;
3639import org .apache .calcite .sql .SqlCall ;
3740import org .apache .calcite .sql .SqlIdentifier ;
41+ import org .apache .calcite .sql .SqlKind ;
42+ import org .apache .calcite .sql .SqlLiteral ;
3843import org .apache .calcite .sql .SqlNode ;
44+ import org .apache .calcite .sql .SqlNodeList ;
45+ import org .apache .calcite .sql .SqlSelect ;
3946import org .apache .calcite .sql .fun .SqlCountAggFunction ;
4047import org .apache .calcite .sql .fun .SqlStdOperatorTable ;
4148import org .apache .calcite .sql .parser .SqlParser ;
49+ import org .apache .calcite .sql .parser .SqlParserPos ;
4250import org .apache .calcite .sql .util .SqlShuttle ;
4351import org .apache .calcite .sql .validate .SqlValidator ;
4452import org .apache .calcite .sql2rel .SqlToRelConverter ;
5866import org .opensearch .sql .calcite .validate .OpenSearchSparkSqlDialect ;
5967import org .opensearch .sql .calcite .validate .PplConvertletTable ;
6068import org .opensearch .sql .calcite .validate .PplRelToSqlNodeConverter ;
61- import org .opensearch .sql .calcite .validate .PplRelToSqlRelShuttle ;
69+ import org .opensearch .sql .calcite .validate .shuttles .PplRelToSqlRelShuttle ;
70+ import org .opensearch .sql .calcite .validate .shuttles .SkipRelValidationShuttle ;
6271import org .opensearch .sql .common .response .ResponseListener ;
6372import org .opensearch .sql .common .setting .Settings ;
6473import org .opensearch .sql .datasource .DataSourceService ;
@@ -306,6 +315,16 @@ public LogicalPlan analyze(UnresolvedPlan plan, QueryType queryType) {
306315 * @return the validated (and potentially modified) relation node
307316 */
308317 private RelNode validate (RelNode relNode , CalcitePlanContext context ) {
318+ SkipRelValidationShuttle skipShuttle = new SkipRelValidationShuttle ();
319+ relNode .accept (skipShuttle );
320+ // WARNING: When a skip pattern is detected (e.g., WIDTH_BUCKET on datetime types),
321+ // we bypass the entire validation pipeline, skipping potentially useful transformation relying
322+ // on rewriting SQL node
323+ // TODO: Make incompatible operations like bin-on-timestamp a validatable UDFs so that they can
324+ // be still be converted to SqlNode and back to RelNode
325+ if (skipShuttle .shouldSkipValidation ()) {
326+ return relNode ;
327+ }
309328 // Fix interval literals before conversion to SQL
310329 RelNode sqlRelNode = relNode .accept (new PplRelToSqlRelShuttle (context .rexBuilder , true ));
311330
@@ -346,7 +365,6 @@ public SqlNode visit(SqlIdentifier id) {
346365 return super .visit (call );
347366 }
348367 });
349-
350368 SqlValidator validator = context .getValidator ();
351369 if (rewritten != null ) {
352370 try {
@@ -361,6 +379,9 @@ public SqlNode visit(SqlIdentifier id) {
361379 return relNode ;
362380 }
363381
382+ // if (rewritten instanceof SqlSelect select) {
383+ // rewritten = rewriteGroupBy(select);
384+ // }
364385 // Convert the validated SqlNode back to RelNode
365386 RelOptTable .ViewExpander viewExpander = context .config .getViewExpander ();
366387 RelOptCluster cluster = context .relBuilder .getCluster ();
@@ -463,4 +484,35 @@ private static RelNode convertToCalcitePlan(RelNode osPlan) {
463484 }
464485 return calcitePlan ;
465486 }
487+
488+ private SqlNode rewriteGroupBy (SqlSelect root ) {
489+ if (root .getGroup () == null ) {
490+ return root ;
491+ }
492+ List <SqlNode > selectList = root .getSelectList ().getList ();
493+ List <SqlNode > groupByList = root .getGroup ().getList ();
494+ List <SqlNode > unwrappedGroupByList = groupByList .stream ().map (QueryService ::unwrapAs ).toList ();
495+ List <SqlNode > unwrappedSelectList = selectList .stream ().map (QueryService ::unwrapAs ).toList ();
496+ if (new HashSet <>(unwrappedSelectList ).containsAll (unwrappedGroupByList )) {
497+ List <Integer > ordinals =
498+ unwrappedGroupByList .stream ().map (unwrappedSelectList ::indexOf ).toList ();
499+ List <SqlNode > groupByOrdinals =
500+ ordinals .stream ()
501+ .map (
502+ ordinal ->
503+ (SqlNode )
504+ SqlLiteral .createExactNumeric (
505+ Integer .toString (ordinal + 1 ), SqlParserPos .ZERO ))
506+ .collect (Collectors .toCollection (ArrayList ::new ));
507+ root .setGroupBy (SqlNodeList .of (root .getGroup ().getParserPosition (), groupByOrdinals ));
508+ }
509+ return root ;
510+ }
511+
512+ private static SqlNode unwrapAs (SqlNode node ) {
513+ if (node .getKind () == SqlKind .AS && node instanceof SqlCall ) {
514+ return ((SqlCall ) node ).getOperandList ().get (0 );
515+ }
516+ return node ;
517+ }
466518}
0 commit comments