Skip to content

Commit 7d61282

Browse files
committed
refactor: add in time type for duckdb casting in contract and fix spark test involving regexp for spark casting in contract
1 parent 78fa448 commit 7d61282

3 files changed

Lines changed: 13 additions & 7 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -329,13 +329,14 @@ def _ddb_safely_quote_name(field_name: str) -> str:
329329
return f'"{field_name}"'
330330

331331

332-
# pylint: disable=R0801,R0911
332+
# pylint: disable=R0801,R0911,R0912
333333
def get_duckdb_cast_statement_from_annotation(
334334
element_name: str,
335335
type_annotation: Any,
336336
parent_element: bool = True,
337337
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
338338
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
339+
time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$",
339340
) -> str:
340341
"""Generate casting statements for duckdb relations from type annotations"""
341342
type_origin = get_origin(type_annotation)
@@ -402,11 +403,15 @@ def get_duckdb_cast_statement_from_annotation(
402403
raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")
403404

404405
for type_ in type_annotation.mro():
406+
# datetime is subclass of date, so needs to be handled first
405407
if issubclass(type_, datetime):
406-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301
408+
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301
407409
return stmt
408410
if issubclass(type_, date):
409-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
411+
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
412+
return stmt
413+
if issubclass(type_, time):
414+
stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301
410415
return stmt
411416
duck_type = get_duckdb_type_from_annotation(type_)
412417
if duck_type:

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -461,7 +461,7 @@ def get_spark_cast_statement_from_annotation(
461461
type_annotation: Any,
462462
parent_element: bool = True,
463463
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
464-
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
464+
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
465465
):
466466
"""Generate casting statements for spark dataframes based on type annotations"""
467467
type_origin = get_origin(type_annotation)
@@ -527,11 +527,12 @@ def get_spark_cast_statement_from_annotation(
527527
raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")
528528

529529
for type_ in type_annotation.mro():
530+
# datetime is subclass of date, so needs to be handled first
530531
if issubclass(type_, dt.datetime):
531-
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
532+
stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
532533
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
533534
if issubclass(type_, dt.date):
534-
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
535+
stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
535536
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
536537
spark_type = get_type_from_annotation(type_)
537538
if spark_type:

tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any):
247247
[("str_test", str, "trim(`str_test`)", StringType()),
248248
("int_test", int, "trim(`int_test`)", LongType()),
249249
("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()),
250-
("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()),
250+
("timestamp_test", dt.datetime, r"CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()),
251251
("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)),
252252
("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])),
253253
("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))])

0 commit comments

Comments
 (0)