Skip to content

Commit 414d3ac

Browse files
committed
feat: integrated duckdb casting into data contract and added initial spark casting
1 parent 9ce1084 commit 414d3ac

4 files changed

Lines changed: 110 additions & 34 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 5 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
generate_error_casting_entity_message,
2929
)
3030
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
31+
get_duckdb_cast_statement_from_annotation,
3132
duckdb_read_parquet,
3233
duckdb_record_index,
3334
duckdb_write_parquet,
@@ -101,17 +102,6 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
101102
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
102103
return self._connection.sql("select * from _lazy_df")
103104

104-
@staticmethod
105-
def generate_ddb_cast_statement(
106-
column_name: str, dtype: DuckDBPyType, null_flag: bool = False
107-
) -> str:
108-
"""Helper method to generate sql statements for casting datatypes (permissively).
109-
Current duckdb python API doesn't play well with this currently.
110-
"""
111-
if not null_flag:
112-
return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"'
113-
return f'cast(NULL AS {dtype}) AS "{column_name}"'
114-
115105
# pylint: disable=R0914
116106
def apply_data_contract(
117107
self,
@@ -180,12 +170,13 @@ def apply_data_contract(
180170

181171
casting_statements = [
182172
(
183-
self.generate_ddb_cast_statement(column, dtype)
173+
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + f""" AS "{column}" """
184174
if column in relation.columns
185-
else self.generate_ddb_cast_statement(column, dtype, null_flag=True)
175+
else f"CAST(NULL AS {ddb_schema[column]}) AS {column}"
186176
)
187-
for column, dtype in ddb_schema.items()
177+
for column, mdl_fld in entity_fields.items()
188178
]
179+
casting_statements.append(f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}")
189180
try:
190181
relation = relation.project(", ".join(casting_statements))
191182
except Exception as err: # pylint: disable=broad-except

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -315,30 +315,39 @@ def duckdb_record_index(cls):
315315
return cls
316316

317317
def _cast_as_ddb_type(field_expr:str, type_annotation:Any) -> str:
318-
return f"try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})"
318+
return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})"""
319+
320+
def _ddb_safely_quote_name(field_name:str) -> str:
321+
try:
322+
sep_idx = field_name.rindex(".")
323+
return field_name[:sep_idx + 1] + f"\"{field_name[sep_idx + 1:]}\""
324+
except ValueError:
325+
return f"\"{field_name}\""
319326

320327

321328
def get_duckdb_cast_statement_from_annotation(element_name:str,
322329
type_annotation: Any,
323330
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
324331
timestamp_regex:str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$",
325-
include_cast: bool = True) -> DuckDBPyType:
332+
parent_element: bool = True) -> str:
326333
type_origin = get_origin(type_annotation)
334+
335+
quoted_name = _ddb_safely_quote_name(element_name)
327336

328337
# An `Optional` or `Union` type, check to ensure non-heterogeneity.
329338
if type_origin is Union:
330339
python_type = _get_non_heterogenous_type(get_args(type_annotation))
331-
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast)
340+
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element)
332341

333342
# Type hint is e.g. `List[str]`, check to ensure non-heterogeneity.
334343
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
335344
element_type = _get_non_heterogenous_type(get_args(type_annotation))
336-
stmt = f"list_transform({element_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})"
337-
return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation)
345+
stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})"
346+
return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)
338347

339348
if type_origin is Annotated:
340349
python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable
341-
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast) # add other expected params here
350+
return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) # add other expected params here
342351
# Ensure that we have a concrete type at this point.
343352
if not isinstance(type_annotation, type):
344353
raise ValueError(f"Unsupported type annotation {type_annotation!r}")
@@ -372,9 +381,9 @@ def get_duckdb_cast_statement_from_annotation(element_name:str,
372381
raise ValueError(
373382
f"No type annotations in dict/dataclass type (got {type_annotation!r})"
374383
)
375-
cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()])
384+
cast_exprs = ",".join([f"\"{nme}\":= {stmt}" for nme, stmt in fields.items()])
376385
stmt = f"struct_pack({cast_exprs})"
377-
return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation)
386+
return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)
378387

379388
if type_annotation is list:
380389
raise ValueError(
@@ -385,13 +394,13 @@ def get_duckdb_cast_statement_from_annotation(element_name:str,
385394

386395
for type_ in type_annotation.mro():
387396
if issubclass(type_, datetime):
388-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({element_name}) as TIMESTAMP) ELSE NULL END"
397+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END"
389398
return stmt
390399
if issubclass(type_, date):
391-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{date_regex}') THEN TRY_CAST(TRIM({element_name}) as DATE) ELSE NULL END"
400+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"
392401
return stmt
393402
duck_type = get_duckdb_type_from_annotation(type_)
394403
if duck_type:
395-
stmt = f"trim({element_name})"
396-
return _cast_as_ddb_type(stmt, type_) if include_cast else stmt
404+
stmt = f"trim({quoted_name})"
405+
return _cast_as_ddb_type(stmt, type_) if parent_element else stmt
397406
raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,76 @@ def spark_record_index(cls):
439439
setattr(cls, "add_record_index", _add_spark_record_index)
440440
setattr(cls, "drop_record_index", _drop_spark_record_index)
441441
return cls
442+
443+
444+
def _cast_as_spark_type(field_expr:str, field_type: st.DataType) -> Column:
445+
return sf.expr(field_expr).cast(field_type)
446+
447+
448+
def get_spark_cast_statement_from_annotation(element_name:str,
449+
type_annotation: Any,
450+
include_cast: bool = True) -> st.DataType:
451+
type_origin = get_origin(type_annotation)
452+
453+
# An `Optional` or `Union` type, check to ensure non-heterogeneity.
454+
if type_origin is Union:
455+
python_type = _get_non_heterogenous_type(get_args(type_annotation))
456+
return get_spark_cast_statement_from_annotation(element_name, python_type, include_cast)
457+
458+
# Type hint is e.g. `List[str]`, check to ensure non-heterogeneity.
459+
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
460+
element_type = _get_non_heterogenous_type(get_args(type_annotation))
461+
stmt = f"transform({element_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False)})"
462+
return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation)
463+
464+
if type_origin is Annotated:
465+
python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable
466+
return get_spark_cast_statement_from_annotation(element_name, python_type, include_cast) # add other expected params here
467+
# Ensure that we have a concrete type at this point.
468+
if not isinstance(type_annotation, type):
469+
raise ValueError(f"Unsupported type annotation {type_annotation!r}")
470+
471+
if (
472+
# Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
473+
(issubclass(type_annotation, dict) and type_annotation is not dict)
474+
# Type hint is a dataclass.
475+
or is_dataclass(type_annotation)
476+
# Type hint is a `pydantic` model.
477+
or (type_origin is None and issubclass(type_annotation, BaseModel))
478+
):
479+
fields: dict[str, str] = {}
480+
for field_name, field_annotation in get_type_hints(type_annotation).items():
481+
# Technically non-string keys are disallowed, but people are bad.
482+
if not isinstance(field_name, str):
483+
raise ValueError(
484+
f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
485+
) # pragma: no cover
486+
if get_origin(field_annotation) is ClassVar:
487+
continue
488+
489+
fields[field_name] = get_spark_cast_statement_from_annotation(
490+
f"{element_name}.{field_name}",
491+
field_annotation,
492+
False)
493+
494+
if not fields:
495+
raise ValueError(
496+
f"No type annotations in dict/dataclass type (got {type_annotation!r})"
497+
)
498+
cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()])
499+
stmt = f"struct_pack({cast_exprs})"
500+
return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation)
501+
502+
if type_annotation is list:
503+
raise ValueError(
504+
f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
505+
)
506+
if type_annotation is dict or type_origin is dict:
507+
raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")
508+
509+
for type_ in type_annotation.mro():
510+
duck_type = get_type_from_annotation(type_)
511+
if duck_type:
512+
stmt = f"trim({element_name})"
513+
return _cast_as_spark_type(stmt, type_) if include_cast else stmt
514+
raise ValueError(f"No equivalent Spark type for {type_annotation!r}")

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py

Lines changed: 11 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -154,13 +154,13 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
154154

155155
# add decimal check
156156
@pytest.mark.parametrize("field_name,field_type,cast_statement",
157-
[("str_test", str, "try_cast(trim(str_test) as VARCHAR)"),
158-
("int_test", int, "try_cast(trim(int_test) as BIGINT)"),
159-
("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(date_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(date_test) as DATE) ELSE NULL END"),
160-
("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(timestamp_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(timestamp_test) as TIMESTAMP) ELSE NULL END"),
161-
("list_int_field", list[int], "try_cast(list_transform(list_int_field, x -> trim(x)) as BIGINT[])"),
162-
("basic_model", BasicModel, "try_cast(struct_pack(str_field:= trim(basic_model.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"),
163-
("another_model", AnotherModel, "try_cast(struct_pack(unique_id:= trim(another_model.unique_id),basic_models:= list_transform(another_model.basic_models, x -> struct_pack(str_field:= trim(x.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")])
157+
[("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"),
158+
("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"),
159+
("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"),
160+
("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"),
161+
("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"),
162+
("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(basic_model.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.\"date_field\") as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"),
163+
("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(another_model.\"unique_id\"),\"basic_models\":= list_transform(another_model.\"basic_models\", x -> struct_pack(\"str_field\":= trim(x.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(x.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.\"date_field\") as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")])
164164
def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement):
165165
assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement
166166

@@ -172,5 +172,8 @@ def test_use_cast_statements(casting_test_table):
172172
test_rel = test_rel.project(",".join(casting_statements))
173173
assert dict(zip(test_rel.columns, test_rel.dtypes)) == {fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()}
174174
dodgy_date_rec = test_rel.pl()[1].to_dicts()[0]
175-
assert not dodgy_date_rec.get("date_test") and not dodgy_date_rec.get("basic_model",{}).get("date_field") and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[]))
175+
assert (not dodgy_date_rec.get("date_test") and
176+
not dodgy_date_rec.get("basic_model",{}).get("date_field")
177+
and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[]))
178+
)
176179

0 commit comments

Comments
 (0)