Skip to content

Commit 5248af1

Browse files
committed
Override deriveType in validator to allow type checking on UDTs
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent e5d9984 commit 5248af1

1 file changed

Lines changed: 79 additions & 0 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/validate/PplValidator.java

Lines changed: 79 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,10 +5,25 @@
55

66
package org.opensearch.sql.calcite.validate;
77

8+
import java.util.List;
9+
import java.util.function.Function;
10+
import org.apache.calcite.rel.type.RelDataType;
811
import org.apache.calcite.rel.type.RelDataTypeFactory;
12+
import org.apache.calcite.rel.type.RelDataTypeField;
13+
import org.apache.calcite.rel.type.RelRecordType;
14+
import org.apache.calcite.sql.SqlNode;
915
import org.apache.calcite.sql.SqlOperatorTable;
16+
import org.apache.calcite.sql.type.SqlTypeName;
1017
import org.apache.calcite.sql.validate.SqlValidatorCatalogReader;
1118
import org.apache.calcite.sql.validate.SqlValidatorImpl;
19+
import org.apache.calcite.sql.validate.SqlValidatorScope;
20+
import org.checkerframework.checker.nullness.qual.Nullable;
21+
import org.opensearch.sql.calcite.type.AbstractExprRelDataType;
22+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
23+
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT;
24+
import org.opensearch.sql.calcite.utils.UserDefinedFunctionUtils;
25+
import org.opensearch.sql.data.type.ExprCoreType;
26+
import org.opensearch.sql.data.type.ExprType;
1227

1328
/**
1429
* Custom SQL validator for PPL queries.
@@ -33,4 +48,68 @@ protected PplValidator(
3348
Config config) {
3449
super(opTab, catalogReader, typeFactory, config);
3550
}
51+
52+
/**
53+
* Overrides the deriveType method to map user-defined types (UDTs) to SqlTypes so that they can
54+
* be validated
55+
*/
56+
@Override
57+
public RelDataType deriveType(SqlValidatorScope scope, SqlNode expr) {
58+
RelDataType type = super.deriveType(scope, expr);
59+
return userDefinedTypeToSqlType(type);
60+
}
61+
62+
@Override
63+
public @Nullable RelDataType getValidatedNodeTypeIfKnown(SqlNode node) {
64+
RelDataType type = super.getValidatedNodeTypeIfKnown(node);
65+
return sqlTypeToUserDefinedType(type);
66+
}
67+
68+
private RelDataType userDefinedTypeToSqlType(RelDataType type) {
69+
return convertType(
70+
type,
71+
t -> {
72+
if (OpenSearchTypeFactory.isUserDefinedType(t)) {
73+
AbstractExprRelDataType<?> exprType = (AbstractExprRelDataType<?>) t;
74+
ExprType udtType = exprType.getExprType();
75+
OpenSearchTypeFactory typeFactory = (OpenSearchTypeFactory) this.getTypeFactory();
76+
return switch (udtType) {
77+
case ExprCoreType.TIMESTAMP ->
78+
typeFactory.createSqlType(SqlTypeName.TIMESTAMP, t.isNullable());
79+
case ExprCoreType.TIME -> typeFactory.createSqlType(SqlTypeName.TIME, t.isNullable());
80+
case ExprCoreType.DATE -> typeFactory.createSqlType(SqlTypeName.DATE, t.isNullable());
81+
case ExprCoreType.BINARY ->
82+
typeFactory.createSqlType(SqlTypeName.BINARY, t.isNullable());
83+
case ExprCoreType.IP -> UserDefinedFunctionUtils.NULLABLE_IP_UDT;
84+
default -> t;
85+
};
86+
}
87+
return t;
88+
});
89+
}
90+
91+
private RelDataType sqlTypeToUserDefinedType(RelDataType type) {
92+
return convertType(
93+
type,
94+
t -> {
95+
OpenSearchTypeFactory typeFactory = (OpenSearchTypeFactory) this.getTypeFactory();
96+
return switch (t.getSqlTypeName()) {
97+
case TIMESTAMP -> typeFactory.createUDT(ExprUDT.EXPR_TIMESTAMP, t.isNullable());
98+
case TIME -> typeFactory.createUDT(ExprUDT.EXPR_TIME, t.isNullable());
99+
case DATE -> typeFactory.createUDT(ExprUDT.EXPR_DATE, t.isNullable());
100+
case BINARY -> typeFactory.createUDT(ExprUDT.EXPR_BINARY, t.isNullable());
101+
default -> t;
102+
};
103+
});
104+
}
105+
106+
private RelDataType convertType(RelDataType type, Function<RelDataType, RelDataType> convert) {
107+
if (type == null) return null;
108+
if (type instanceof RelRecordType recordType) {
109+
List<RelDataType> subTypes =
110+
recordType.getFieldList().stream().map(RelDataTypeField::getType).map(convert).toList();
111+
return typeFactory.createStructType(subTypes, recordType.getFieldNames());
112+
}
113+
return convert.apply(type);
114+
}
36115
}

0 commit comments

Comments
 (0)