Skip to content

Commit f9f4d4f

Browse files
committed
refactor: added further spark cast work with tests, small fixes to duckdb casting
1 parent 414d3ac commit f9f4d4f

6 files changed

Lines changed: 147 additions & 59 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 7 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -28,10 +28,10 @@
2828
generate_error_casting_entity_message,
2929
)
3030
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
31-
get_duckdb_cast_statement_from_annotation,
3231
duckdb_read_parquet,
3332
duckdb_record_index,
3433
duckdb_write_parquet,
34+
get_duckdb_cast_statement_from_annotation,
3535
get_duckdb_type_from_annotation,
3636
relation_is_empty,
3737
)
@@ -102,7 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
102102
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
103103
return self._connection.sql("select * from _lazy_df")
104104

105-
# pylint: disable=R0914
105+
# pylint: disable=R0914,R0915
106106
def apply_data_contract(
107107
self,
108108
working_dir: URI,
@@ -170,13 +170,16 @@ def apply_data_contract(
170170

171171
casting_statements = [
172172
(
173-
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + f""" AS "{column}" """
173+
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation)
174+
+ f""" AS "{column}" """
174175
if column in relation.columns
175176
else f"CAST(NULL AS {ddb_schema[column]}) AS {column}"
176177
)
177178
for column, mdl_fld in entity_fields.items()
178179
]
179-
casting_statements.append(f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}")
180+
casting_statements.append(
181+
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
182+
)
180183
try:
181184
relation = relation.project(", ".join(casting_statements))
182185
except Exception as err: # pylint: disable=broad-except

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 32 additions & 26 deletions
Original file line number | Diff line number | Diff line change
@@ -314,40 +314,49 @@ def duckdb_record_index(cls):
314314
setattr(cls, "drop_record_index", _drop_duckdb_record_index)
315315
return cls
316316

317-
def _cast_as_ddb_type(field_expr:str, type_annotation:Any) -> str:
317+
318+
def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
318319
return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})"""
319320

320-
def _ddb_safely_quote_name(field_name:str) -> str:
321+
322+
def _ddb_safely_quote_name(field_name: str) -> str:
321323
try:
322-
sep_idx = field_name.rindex(".")
323-
return field_name[:sep_idx + 1] + f"\"{field_name[sep_idx + 1:]}\""
324+
sep_idx = field_name.index(".")
325+
return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:]
324326
except ValueError:
325-
return f"\"{field_name}\""
326-
327-
328-
def get_duckdb_cast_statement_from_annotation(element_name:str,
329-
type_annotation: Any,
330-
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
331-
timestamp_regex:str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$",
332-
parent_element: bool = True) -> str:
327+
return f'"{field_name}"'
328+
329+
# pylint: disable=R0911
330+
def get_duckdb_cast_statement_from_annotation(
331+
element_name: str,
332+
type_annotation: Any,
333+
parent_element: bool = True,
334+
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
335+
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$",
336+
) -> str:
337+
"""Generate casting statements for duckdb relations from type annotations"""
333338
type_origin = get_origin(type_annotation)
334-
339+
335340
quoted_name = _ddb_safely_quote_name(element_name)
336341

337342
# An `Optional` or `Union` type, check to ensure non-heterogenity.
338343
if type_origin is Union:
339344
python_type = _get_non_heterogenous_type(get_args(type_annotation))
340-
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element)
345+
return get_duckdb_cast_statement_from_annotation(
346+
element_name, python_type, date_regex, timestamp_regex, parent_element
347+
)
341348

342349
# Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
343350
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
344351
element_type = _get_non_heterogenous_type(get_args(type_annotation))
345-
stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})"
352+
stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
346353
return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)
347354

348355
if type_origin is Annotated:
349356
python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable
350-
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) # add other expected params here
357+
return get_duckdb_cast_statement_from_annotation(
358+
element_name, python_type, date_regex, timestamp_regex, parent_element
359+
) # add other expected params here
351360
# Ensure that we have a concrete type at this point.
352361
if not isinstance(type_annotation, type):
353362
raise ValueError(f"Unsupported type annotation {type_annotation!r}")
@@ -371,17 +380,14 @@ def get_duckdb_cast_statement_from_annotation(element_name:str,
371380
continue
372381

373382
fields[field_name] = get_duckdb_cast_statement_from_annotation(
374-
f"{element_name}.{field_name}",
375-
field_annotation,
376-
date_regex,
377-
timestamp_regex,
378-
False)
383+
f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex
384+
)
379385

380386
if not fields:
381387
raise ValueError(
382388
f"No type annotations in dict/dataclass type (got {type_annotation!r})"
383389
)
384-
cast_exprs = ",".join([f"\"{nme}\":= {stmt}" for nme, stmt in fields.items()])
390+
cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()])
385391
stmt = f"struct_pack({cast_exprs})"
386392
return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)
387393

@@ -394,13 +400,13 @@ def get_duckdb_cast_statement_from_annotation(element_name:str,
394400

395401
for type_ in type_annotation.mro():
396402
if issubclass(type_, datetime):
397-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END"
403+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301
398404
return stmt
399405
if issubclass(type_, date):
400-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"
406+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
401407
return stmt
402-
duck_type = get_duckdb_type_from_annotation(type_)
408+
duck_type = get_duckdb_type_from_annotation(type_)
403409
if duck_type:
404-
stmt = f"trim({quoted_name})"
410+
stmt = f"trim({quoted_name})"
405411
return _cast_as_ddb_type(stmt, type_) if parent_element else stmt
406412
raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 40 additions & 23 deletions
Original file line number | Diff line number | Diff line change
@@ -441,29 +441,41 @@ def spark_record_index(cls):
441441
return cls
442442

443443

444-
def _cast_as_spark_type(field_expr:str, field_type: st.DataType) -> Column:
445-
return sf.expr(field_expr).cast(field_type)
444+
def _cast_as_spark_type(field_expr: str, field_type: st.DataType) -> Column:
445+
return sf.expr(field_expr).cast(get_type_from_annotation(field_type))
446446

447-
448-
def get_spark_cast_statement_from_annotation(element_name:str,
449-
type_annotation: Any,
450-
include_cast: bool = True) -> st.DataType:
447+
def _spark_safely_quote_name(field_name: str) -> str:
448+
try:
449+
sep_idx = field_name.index(".")
450+
return f'`{field_name[: sep_idx]}`' + field_name[sep_idx:]
451+
except ValueError:
452+
return f'`{field_name}`'
453+
454+
def get_spark_cast_statement_from_annotation(
455+
element_name: str, type_annotation: Any, parent_element: bool = True,
456+
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
457+
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$"):
458+
"""Generate casting statements for spark dataframes based on type annotations"""
451459
type_origin = get_origin(type_annotation)
460+
461+
quoted_name = _spark_safely_quote_name(element_name)
452462

453463
# An `Optional` or `Union` type, check to ensure non-heterogenity.
454464
if type_origin is Union:
455465
python_type = _get_non_heterogenous_type(get_args(type_annotation))
456-
return get_spark_cast_statement_from_annotation(element_name, include_cast)
466+
return get_spark_cast_statement_from_annotation(element_name, python_type, parent_element, date_regex, timestamp_regex)
457467

458468
# Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
459469
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
460470
element_type = _get_non_heterogenous_type(get_args(type_annotation))
461-
stmt = f"transform({element_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False)})"
462-
return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation)
471+
stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
472+
return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)
463473

464474
if type_origin is Annotated:
465-
python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable
466-
return get_spark_cast_statement_from_annotation(element_name, python_type, include_cast) # add other expected params here
475+
python_type, *_ = get_args(type_annotation) # pylint: disable=unused-variable
476+
return get_spark_cast_statement_from_annotation(
477+
element_name, python_type, parent_element, date_regex, timestamp_regex
478+
) # add other expected params here
467479
# Ensure that we have a concrete type at this point.
468480
if not isinstance(type_annotation, type):
469481
raise ValueError(f"Unsupported type annotation {type_annotation!r}")
@@ -487,18 +499,16 @@ def get_spark_cast_statement_from_annotation(element_name:str,
487499
continue
488500

489501
fields[field_name] = get_spark_cast_statement_from_annotation(
490-
f"{element_name}.{field_name}",
491-
field_annotation,
492-
False)
502+
f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex
503+
)
493504

494505
if not fields:
495506
raise ValueError(
496507
f"No type annotations in dict/dataclass type (got {type_annotation!r})"
497508
)
498-
cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()])
499-
stmt = f"struct_pack({cast_exprs})"
500-
return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation)
501-
509+
cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()])
510+
stmt = f"struct({cast_exprs})"
511+
return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)
502512
if type_annotation is list:
503513
raise ValueError(
504514
f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
@@ -507,8 +517,15 @@ def get_spark_cast_statement_from_annotation(element_name:str,
507517
raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")
508518

509519
for type_ in type_annotation.mro():
510-
duck_type = get_type_from_annotation(type_)
511-
if duck_type:
512-
stmt = f"trim({element_name})"
513-
return _cast_as_spark_type(stmt, type_) if include_cast else stmt
514-
raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")
520+
if issubclass(type_, dt.datetime):
521+
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
522+
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
523+
elif issubclass(type_, dt.date):
524+
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
525+
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
526+
else:
527+
spark_type = get_type_from_annotation(type_)
528+
if spark_type:
529+
stmt = f"trim({quoted_name})"
530+
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
531+
raise ValueError(f"No equivalent Spark type for {type_annotation!r}")

src/dve/pipeline/foundry_ddb_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI:
4242
write_to.parent.mkdir(parents=True, exist_ok=True)
4343
write_to = write_to.as_posix()
4444
self.write_parquet( # type: ignore # pylint: disable=E1101
45-
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
45+
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
4646
f"submission_id = '{submission_info.submission_id}'"
4747
),
4848
fh.joinuri(write_to, "processing_status.parquet"),
4949
)
5050
self.write_parquet( # type: ignore # pylint: disable=E1101
51-
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
51+
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
5252
f"submission_id = '{submission_info.submission_id}'"
5353
),
5454
fh.joinuri(write_to, "submission_statistics.parquet"),

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -159,8 +159,8 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
159159
("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"),
160160
("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"),
161161
("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"),
162-
("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(basic_model.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.\"date_field\") as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"),
163-
("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(another_model.\"unique_id\"),\"basic_models\":= list_transform(another_model.\"basic_models\", x -> struct_pack(\"str_field\":= trim(x.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(x.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.\"date_field\") as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")])
162+
("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"),
163+
("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")])
164164
def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement):
165165
assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement
166166

0 commit comments

Comments (0)