Skip to content

Commit 611efca

Browse files
committed
Override coerceOperandType and castTo methods to apply casting to udf logics
Signed-off-by: Yuanchun Shen <yuanchu@amazon.com>
1 parent 04fcd46 commit 611efca

6 files changed

Lines changed: 131 additions & 15 deletions

File tree

core/src/main/java/org/opensearch/sql/calcite/utils/OpenSearchTypeFactory.java

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,6 @@
55

66
package org.opensearch.sql.calcite.utils;
77

8-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_DATE;
9-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIME;
10-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP;
118
import static org.opensearch.sql.data.type.ExprCoreType.ARRAY;
129
import static org.opensearch.sql.data.type.ExprCoreType.BINARY;
1310
import static org.opensearch.sql.data.type.ExprCoreType.BOOLEAN;
@@ -385,8 +382,9 @@ public static boolean isNumericType(RelDataType fieldType) {
385382
* @param fieldType the RelDataType to check
386383
* @return true if the type is time-based, false otherwise
387384
*/
388-
public static boolean isTimeBasedType(RelDataType fieldType) {
385+
public static boolean isDatetime(RelDataType fieldType) {
389386
// Check standard SQL time types
387+
// TODO: Optimize with SqlTypeUtil.isDatetime
390388
SqlTypeName sqlType = fieldType.getSqlTypeName();
391389
if (sqlType == SqlTypeName.TIMESTAMP
392390
|| sqlType == SqlTypeName.TIMESTAMP_WITH_LOCAL_TIME_ZONE
@@ -408,4 +406,12 @@ public static boolean isTimeBasedType(RelDataType fieldType) {
408406
// Fallback check if type string contains EXPR_TIMESTAMP
409407
return fieldType.toString().contains("EXPR_TIMESTAMP");
410408
}
409+
410+
/**
411+
* This method should be used in place for {@link SqlTypeUtil#isCharacter(RelDataType)} because
412+
* user-defined types also have VARCHAR as their SqlTypeName.
413+
*/
414+
public static boolean isCharacter(RelDataType type) {
415+
return !isUserDefinedType(type) && SqlTypeUtil.isCharacter(type);
416+
}
411417
}

core/src/main/java/org/opensearch/sql/calcite/utils/binning/BinnableField.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ public BinnableField(RexNode fieldExpr, RelDataType fieldType, String fieldName)
3333
this.fieldType = fieldType;
3434
this.fieldName = fieldName;
3535

36-
this.isTimeBased = OpenSearchTypeFactory.isTimeBasedType(fieldType);
36+
this.isTimeBased = OpenSearchTypeFactory.isDatetime(fieldType);
3737
this.isNumeric = OpenSearchTypeFactory.isNumericType(fieldType);
3838

3939
// Reject truly unsupported types (e.g., BOOLEAN, ARRAY, MAP)

core/src/main/java/org/opensearch/sql/calcite/validate/PplTypeCoercion.java

Lines changed: 101 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -5,16 +5,23 @@
55

66
package org.opensearch.sql.calcite.validate;
77

8+
import static java.util.Objects.requireNonNull;
89
import static org.opensearch.sql.calcite.validate.ValidationUtils.createUDTWithAttributes;
910

1011
import java.util.List;
1112
import java.util.Map;
1213
import java.util.Set;
1314
import java.util.stream.IntStream;
15+
import org.apache.calcite.adapter.java.JavaTypeFactory;
1416
import org.apache.calcite.rel.type.RelDataType;
1517
import org.apache.calcite.rel.type.RelDataTypeFactory;
18+
import org.apache.calcite.rel.type.RelDataTypeFactoryImpl;
19+
import org.apache.calcite.sql.SqlCall;
1620
import org.apache.calcite.sql.SqlCallBinding;
21+
import org.apache.calcite.sql.SqlDynamicParam;
1722
import org.apache.calcite.sql.SqlNode;
23+
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
24+
import org.apache.calcite.sql.type.SqlTypeCoercionRule;
1825
import org.apache.calcite.sql.type.SqlTypeFamily;
1926
import org.apache.calcite.sql.type.SqlTypeMappingRule;
2027
import org.apache.calcite.sql.type.SqlTypeName;
@@ -24,6 +31,9 @@
2431
import org.apache.calcite.sql.validate.implicit.TypeCoercionImpl;
2532
import org.checkerframework.checker.nullness.qual.Nullable;
2633
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
34+
import org.opensearch.sql.data.type.ExprCoreType;
35+
import org.opensearch.sql.data.type.ExprType;
36+
import org.opensearch.sql.expression.function.PPLBuiltinOperators;
2737

2838
/**
2939
* Custom type coercion implementation for PPL that extends Calcite's default type coercion with
@@ -86,13 +96,16 @@ private boolean isBlacklistedCoercion(RelDataType operandType, SqlTypeFamily exp
8696
@Override
8797
public @Nullable RelDataType implicitCast(RelDataType in, SqlTypeFamily expected) {
8898
RelDataType casted = super.implicitCast(in, expected);
99+
if (casted == null) {
100+
// String -> DATETIME is converted to String -> TIMESTAMP
101+
if (OpenSearchTypeFactory.isCharacter(in) && expected == SqlTypeFamily.DATETIME) {
102+
return createUDTWithAttributes(factory, in, OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP);
103+
}
104+
return null;
105+
}
89106
return switch (casted.getSqlTypeName()) {
90-
case SqlTypeName.DATE ->
91-
createUDTWithAttributes(factory, in, OpenSearchTypeFactory.ExprUDT.EXPR_DATE);
92-
case SqlTypeName.TIME ->
93-
createUDTWithAttributes(factory, in, OpenSearchTypeFactory.ExprUDT.EXPR_TIME);
94-
case SqlTypeName.TIMESTAMP ->
95-
createUDTWithAttributes(factory, in, OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP);
107+
case SqlTypeName.DATE, SqlTypeName.TIME, SqlTypeName.TIMESTAMP, SqlTypeName.BINARY ->
108+
createUDTWithAttributes(factory, in, casted.getSqlTypeName());
96109
default -> casted;
97110
};
98111
}
@@ -106,9 +119,90 @@ protected boolean needToCast(
106119
SqlValidatorScope scope, SqlNode node, RelDataType toType, SqlTypeMappingRule mappingRule) {
107120
boolean need = super.needToCast(scope, node, toType, mappingRule);
108121
RelDataType fromType = validator.deriveType(scope, node);
109-
if (OpenSearchTypeFactory.isUserDefinedType(toType) && SqlTypeUtil.isCharacter(fromType)) {
122+
if (OpenSearchTypeFactory.isUserDefinedType(toType)
123+
&& OpenSearchTypeFactory.isCharacter(fromType)) {
110124
need = true;
111125
}
112126
return need;
113127
}
128+
129+
@Override
130+
protected boolean dateTimeStringEquality(
131+
SqlCallBinding binding, RelDataType left, RelDataType right) {
132+
if (OpenSearchTypeFactory.isCharacter(left) && OpenSearchTypeFactory.isDatetime(right)) {
133+
// Use user-defined types in place of inbuilt datetime types
134+
RelDataType r =
135+
OpenSearchTypeFactory.isUserDefinedType(right)
136+
? right
137+
: ValidationUtils.createUDTWithAttributes(factory, right, right.getSqlTypeName());
138+
return coerceOperandType(binding.getScope(), binding.getCall(), 0, r);
139+
}
140+
if (OpenSearchTypeFactory.isCharacter(right) && OpenSearchTypeFactory.isDatetime(left)) {
141+
RelDataType l =
142+
OpenSearchTypeFactory.isUserDefinedType(left)
143+
? left
144+
: ValidationUtils.createUDTWithAttributes(factory, left, left.getSqlTypeName());
145+
return coerceOperandType(binding.getScope(), binding.getCall(), 1, l);
146+
}
147+
return false;
148+
}
149+
150+
@Override
151+
protected @Nullable RelDataType commonTypeForComparison(List<RelDataType> dataTypes) {
152+
return super.commonTypeForComparison(dataTypes);
153+
}
154+
155+
/**
156+
* Cast operand at index {@code index} to target type. we do this base on the fact that validate
157+
* happens before type coercion.
158+
*/
159+
protected boolean coerceOperandType(
160+
@Nullable SqlValidatorScope scope, SqlCall call, int index, RelDataType targetType) {
161+
// Transform the JavaType to SQL type because the SqlDataTypeSpec
162+
// does not support deriving JavaType yet.
163+
if (RelDataTypeFactoryImpl.isJavaType(targetType)) {
164+
targetType = ((JavaTypeFactory) factory).toSql(targetType);
165+
}
166+
167+
SqlNode operand = call.getOperandList().get(index);
168+
if (operand instanceof SqlDynamicParam) {
169+
// Do not support implicit type coercion for dynamic param.
170+
return false;
171+
}
172+
requireNonNull(scope, "scope");
173+
RelDataType operandType = validator.deriveType(scope, operand);
174+
if (coerceStringToArray(call, operand, index, operandType, targetType)) {
175+
return true;
176+
}
177+
178+
// Check it early.
179+
if (!needToCast(scope, operand, targetType, SqlTypeCoercionRule.lenientInstance())) {
180+
return false;
181+
}
182+
// Fix up nullable attr.
183+
RelDataType targetType1 = ValidationUtils.syncAttributes(factory, operandType, targetType);
184+
SqlNode desired = castTo(operand, targetType1);
185+
call.setOperand(index, desired);
186+
updateInferredType(desired, targetType1);
187+
return true;
188+
}
189+
190+
private static SqlNode castTo(SqlNode node, RelDataType type) {
191+
if (OpenSearchTypeFactory.isDatetime(type)) {
192+
ExprType exprType = OpenSearchTypeFactory.convertRelDataTypeToExprType(type);
193+
return switch (exprType) {
194+
case ExprCoreType.DATE ->
195+
PPLBuiltinOperators.DATE.createCall(node.getParserPosition(), node);
196+
case ExprCoreType.TIMESTAMP ->
197+
PPLBuiltinOperators.TIMESTAMP.createCall(node.getParserPosition(), node);
198+
case ExprCoreType.TIME ->
199+
PPLBuiltinOperators.TIME.createCall(node.getParserPosition(), node);
200+
default -> throw new UnsupportedOperationException("Unsupported type: " + exprType);
201+
};
202+
}
203+
return SqlStdOperatorTable.CAST.createCall(
204+
node.getParserPosition(),
205+
node,
206+
SqlTypeUtil.convertTypeToSpec(type).withNullable(type.isNullable()));
207+
}
114208
}

core/src/main/java/org/opensearch/sql/calcite/validate/ValidationUtils.java

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import org.apache.calcite.rel.type.RelDataType;
1313
import org.apache.calcite.rel.type.RelDataTypeFactory;
1414
import org.apache.calcite.sql.SqlCollation;
15+
import org.apache.calcite.sql.type.SqlTypeName;
1516
import org.apache.calcite.sql.type.SqlTypeUtil;
1617
import org.opensearch.sql.calcite.utils.OpenSearchTypeFactory;
1718

@@ -57,4 +58,19 @@ public static RelDataType createUDTWithAttributes(
5758
RelDataType type = typeFactory.createUDT(userDefinedType);
5859
return syncAttributes(typeFactory, fromType, type);
5960
}
61+
62+
public static RelDataType createUDTWithAttributes(
63+
RelDataTypeFactory factory, RelDataType fromType, SqlTypeName sqlTypeName) {
64+
return switch (sqlTypeName) {
65+
case SqlTypeName.DATE ->
66+
createUDTWithAttributes(factory, fromType, OpenSearchTypeFactory.ExprUDT.EXPR_DATE);
67+
case SqlTypeName.TIME ->
68+
createUDTWithAttributes(factory, fromType, OpenSearchTypeFactory.ExprUDT.EXPR_TIME);
69+
case SqlTypeName.TIMESTAMP ->
70+
createUDTWithAttributes(factory, fromType, OpenSearchTypeFactory.ExprUDT.EXPR_TIMESTAMP);
71+
case SqlTypeName.BINARY ->
72+
createUDTWithAttributes(factory, fromType, OpenSearchTypeFactory.ExprUDT.EXPR_BINARY);
73+
default -> throw new IllegalArgumentException("Unsupported type: " + sqlTypeName);
74+
};
75+
}
6076
}

core/src/main/java/org/opensearch/sql/expression/function/UDFOperandMetadata.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,7 @@ public List<String> paramNames() {
4848

4949
@Override
5050
public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
51-
return typeChecker.checkOperandTypesWithoutTypeCoercion(callBinding, throwOnFailure);
51+
return typeChecker.checkOperandTypes(callBinding, throwOnFailure);
5252
}
5353

5454
@Override

opensearch/src/main/java/org/opensearch/sql/opensearch/planner/rules/AggregateIndexScanRule.java

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
package org.opensearch.sql.opensearch.planner.rules;
77

8-
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.isTimeBasedType;
8+
import static org.opensearch.sql.calcite.utils.OpenSearchTypeFactory.isDatetime;
99
import static org.opensearch.sql.expression.function.PPLBuiltinOperators.WIDTH_BUCKET;
1010

1111
import java.util.List;
@@ -211,7 +211,7 @@ public interface Config extends OpenSearchRuleConfig {
211211
agg.getGroupSet().stream()
212212
.allMatch(
213213
group ->
214-
isTimeBasedType(
214+
isDatetime(
215215
agg.getInput().getRowType().getFieldList().get(group).getType()));
216216

217217
Config BUCKET_NON_NULL_AGG =

0 commit comments

Comments
 (0)