Skip to content

Commit 3587a02

Browse files
authored
Feature/ndit 1146 improve contract casting statements (#90)
* fix: enhance duckdb casting to be less permissive of poorly formatted dates and trim whitespace
* feat: integrated duckdb casting into data contract and added initial spark casting
* refactor: added further spark cast work with tests, small fixes to duckdb casting
* style: address linting issues
* refactor: add in time type for duckdb casting in contract and fix spark test involving regexp for spark casting in contract
1 parent 2df211f commit 3587a02

6 files changed

Lines changed: 364 additions & 21 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
duckdb_read_parquet,
3232
duckdb_record_index,
3333
duckdb_write_parquet,
34+
get_duckdb_cast_statement_from_annotation,
3435
get_duckdb_type_from_annotation,
3536
relation_is_empty,
3637
)
@@ -101,18 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument
101102
_lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable
102103
return self._connection.sql("select * from _lazy_df")
103104

104-
@staticmethod
105-
def generate_ddb_cast_statement(
106-
column_name: str, dtype: DuckDBPyType, null_flag: bool = False
107-
) -> str:
108-
"""Helper method to generate sql statements for casting datatypes (permissively).
109-
Current duckdb python API doesn't play well with this currently.
110-
"""
111-
if not null_flag:
112-
return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"'
113-
return f'cast(NULL AS {dtype}) AS "{column_name}"'
114-
115-
# pylint: disable=R0914
105+
# pylint: disable=R0914,R0915
116106
def apply_data_contract(
117107
self,
118108
working_dir: URI,
@@ -180,12 +170,16 @@ def apply_data_contract(
180170

181171
casting_statements = [
182172
(
183-
self.generate_ddb_cast_statement(column, dtype)
173+
get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation)
174+
+ f""" AS "{column}" """
184175
if column in relation.columns
185-
else self.generate_ddb_cast_statement(column, dtype, null_flag=True)
176+
else f"CAST(NULL AS {ddb_schema[column]}) AS {column}"
186177
)
187-
for column, dtype in ddb_schema.items()
178+
for column, mdl_fld in entity_fields.items()
188179
]
180+
casting_statements.append(
181+
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
182+
)
189183
try:
190184
relation = relation.project(", ".join(casting_statements))
191185
except Exception as err: # pylint: disable=broad-except

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,108 @@ def duckdb_record_index(cls):
313313
setattr(cls, "add_record_index", _add_duckdb_record_index)
314314
setattr(cls, "drop_record_index", _drop_duckdb_record_index)
315315
return cls
316+
317+
318+
def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
    """Wrap *field_expr* in a permissive ``try_cast`` to the DuckDB type for *type_annotation*."""
    return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})"""


def _ddb_safely_quote_name(field_name: str) -> str:
    """Double-quote the leading identifier of *field_name* in case it is reserved.

    Only the first dotted segment is quoted: for ``model.field`` the ``model``
    part is the relation column, while the remainder is struct-field access.
    """
    try:
        sep_idx = field_name.index(".")
        return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:]
    except ValueError:
        return f'"{field_name}"'


# pylint: disable=R0801,R0911,R0912
def get_duckdb_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$",  # pylint: disable=C0301
    time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$",
) -> str:
    """Generate a DuckDB SQL casting expression for *element_name* from a type annotation.

    Args:
        element_name: Column (or struct field / lambda variable) being cast.
        type_annotation: Python type annotation driving the target DuckDB type.
        parent_element: True for the top-level column, in which case the
            generated expression is wrapped in a final ``try_cast``.
        date_regex: Pattern a trimmed value must match before DATE casting.
        timestamp_regex: Pattern a trimmed value must match before TIMESTAMP casting.
        time_regex: Pattern a trimmed value must match before TIME casting.

    Returns:
        A SQL expression string (without an ``AS`` alias).

    Raises:
        ValueError: For unsupported, un-parameterized, or heterogeneous annotations.
    """
    type_origin = get_origin(type_annotation)

    quoted_name = _ddb_safely_quote_name(element_name)

    # An `Optional` or `Union` type, check to ensure non-heterogeneity.
    # BUGFIX: propagate `time_regex` through the recursion — previously it was
    # dropped, so nested elements silently fell back to the default pattern.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogeneity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x', element_type, False, date_regex, timestamp_regex, time_regex)})"  # pylint: disable=C0301
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_origin is Annotated:
        python_type, *other_args = get_args(type_annotation)  # pylint: disable=unused-variable
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, parent_element, date_regex, timestamp_regex, time_regex
        )  # add other expected params here
    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_duckdb_cast_statement_from_annotation(
                f"{element_name}.{field_name}",
                field_annotation,
                False,
                date_regex,
                timestamp_regex,
                time_regex,
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()])
        stmt = f"struct_pack({cast_exprs})"
        return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)

    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # datetime is subclass of date, so needs to be handled first
        if issubclass(type_, datetime):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, date):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END"  # pylint: disable=C0301
        if issubclass(type_, time):
            return rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END"  # pylint: disable=C0301
        duck_type = get_duckdb_type_from_annotation(type_)
        if duck_type:
            stmt = f"trim({quoted_name})"
            return _cast_as_ddb_type(stmt, type_) if parent_element else stmt
    raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,103 @@ def spark_record_index(cls):
439439
setattr(cls, "add_record_index", _add_spark_record_index)
440440
setattr(cls, "drop_record_index", _drop_spark_record_index)
441441
return cls
442+
443+
444+
def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column:
    """Build a Column that casts *field_expr* to the Spark type for *field_type*."""
    return sf.expr(field_expr).cast(get_type_from_annotation(field_type))


def _spark_safely_quote_name(field_name: str) -> str:
    """Backtick-quote the leading identifier of *field_name* in case it is reserved.

    Only the first dotted segment is quoted; the remainder is struct-field access.
    """
    head, sep, tail = field_name.partition(".")
    return f"`{head}`.{tail}" if sep else f"`{field_name}`"


# pylint: disable=R0801
def get_spark_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    parent_element: bool = True,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$",  # pylint: disable=C0301
):
    """Generate casting statements for spark dataframes based on type annotations.

    Returns a SQL fragment (str) for nested elements, or a Column once the
    top-level (*parent_element*) cast has been applied.

    NOTE(review): the timestamp default doubles its backslashes, presumably so
    the regex survives Spark SQL string-literal unescaping — confirm against the
    `spark.sql.parser.escapedStringLiterals` setting in use.
    """
    origin = get_origin(type_annotation)
    quoted = _spark_safely_quote_name(element_name)

    # `Optional`/`Union`: collapse to a single concrete member type
    # (heterogeneous unions are rejected by the helper).
    if origin is Union:
        inner = _get_non_heterogenous_type(get_args(type_annotation))
        return get_spark_cast_statement_from_annotation(
            element_name, inner, parent_element, date_regex, timestamp_regex
        )

    # Parameterized list, e.g. `List[str]`: cast each element via `transform`.
    if origin is list or (isinstance(origin, type) and issubclass(origin, list)):
        item_type = _get_non_heterogenous_type(get_args(type_annotation))
        item_stmt = get_spark_cast_statement_from_annotation(
            "x", item_type, False, date_regex, timestamp_regex
        )
        stmt = f"transform({quoted}, x -> {item_stmt})"
        return _cast_as_spark_type(stmt, type_annotation) if parent_element else stmt

    # `Annotated[...]`: unwrap and recurse on the underlying type.
    if origin is Annotated:
        inner, *_ = get_args(type_annotation)
        return get_spark_cast_statement_from_annotation(
            element_name, inner, parent_element, date_regex, timestamp_regex
        )  # add other expected params here

    # Anything that is still not a concrete class is unsupported.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    struct_like = (
        # dict subclass but not dict itself — possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # A dataclass.
        or is_dataclass(type_annotation)
        # A `pydantic` model.
        or (origin is None and issubclass(type_annotation, BaseModel))
    )
    if struct_like:
        field_stmts: dict[str, str] = {}
        for fld_name, fld_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(fld_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            if get_origin(fld_annotation) is ClassVar:
                continue
            field_stmts[fld_name] = get_spark_cast_statement_from_annotation(
                f"{element_name}.{fld_name}", fld_annotation, False, date_regex, timestamp_regex
            )

        if not field_stmts:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        packed = ",".join(f"{stmt} AS `{nme}`" for nme, stmt in field_stmts.items())
        stmt = f"struct({packed})"
        return _cast_as_spark_type(stmt, type_annotation) if parent_element else stmt

    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for candidate in type_annotation.mro():
        # datetime is a subclass of date, so it must be matched first.
        if issubclass(candidate, dt.datetime):
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted}), '{timestamp_regex}') THEN TRIM({quoted}) ELSE NULL END"  # pylint: disable=C0301
            return _cast_as_spark_type(stmt, candidate) if parent_element else stmt
        if issubclass(candidate, dt.date):
            stmt = rf"CASE WHEN REGEXP(TRIM({quoted}), '{date_regex}') THEN TRIM({quoted}) ELSE NULL END"  # pylint: disable=C0301
            return _cast_as_spark_type(stmt, candidate) if parent_element else stmt
        if get_type_from_annotation(candidate):
            stmt = f"trim({quoted})"
            return _cast_as_spark_type(stmt, candidate) if parent_element else stmt
    raise ValueError(f"No equivalent Spark type for {type_annotation!r}")

src/dve/pipeline/foundry_ddb_pipeline.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI:
4242
write_to.parent.mkdir(parents=True, exist_ok=True)
4343
write_to = write_to.as_posix()
4444
self.write_parquet( # type: ignore # pylint: disable=E1101
45-
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
45+
self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212
4646
f"submission_id = '{submission_info.submission_id}'"
4747
),
4848
fh.joinuri(write_to, "processing_status.parquet"),
4949
)
5050
self.write_parquet( # type: ignore # pylint: disable=E1101
51-
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
51+
self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212
5252
f"submission_id = '{submission_info.submission_id}'"
5353
),
5454
fh.joinuri(write_to, "submission_statistics.parquet"),

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py

Lines changed: 84 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,73 @@
33
import datetime
44
import tempfile
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, List
77

88
import pytest
99
import pyspark.sql.types as pst
1010
from duckdb import DuckDBPyRelation, DuckDBPyConnection
11+
from pydantic import BaseModel
1112
from pyspark.sql import Row, SparkSession
1213

1314
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
1415
_ddb_read_parquet,
15-
duckdb_rel_to_dictionaries)
16+
duckdb_rel_to_dictionaries,
17+
get_duckdb_cast_statement_from_annotation,
18+
get_duckdb_type_from_annotation)
1619

20+
@pytest.fixture
def casting_test_table(temp_ddb_conn):
    """Create and populate a `test_casting` table with one well-formed row and
    one row of badly formatted dates; drop the table on teardown."""
    _, conn = temp_ddb_conn
    conn.sql(
        """CREATE TABLE test_casting (
               str_test VARCHAR,
               int_test VARCHAR,
               date_test VARCHAR,
               timestamp_test VARCHAR,
               list_int_field VARCHAR[],
               basic_model STRUCT(str_field VARCHAR, date_field VARCHAR),
               another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))"""
    )

    conn.sql(
        """INSERT INTO test_casting
           VALUES(
               'good_one',
               '1',
               '2024-11-13',
               '2024-04-15 12:25:36',
               ['1', '2', '3'],
               {'str_field': 'test', 'date_field': '2024-12-11'},
               {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}),
               (
               'dodgy_dates',
               '2',
               '24-11-13',
               '2024-4-15 12:25:36',
               ['4', '5', '6'],
               {'str_field': 'test', 'date_field': '202-1-11'},
               {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})"""
    )

    yield temp_ddb_conn

    conn.sql("DROP TABLE IF EXISTS test_casting")


class BasicModel(BaseModel):
    # Leaf model used as a struct column in the casting tests.
    str_field: str
    date_field: datetime.date


class AnotherModel(BaseModel):
    # Model nesting a list of `BasicModel` structs.
    unique_id: int
    basic_models: List[BasicModel]


class CastingRecord(BaseModel):
    # Full record mirroring the `test_casting` table schema.
    str_test: str
    int_test: int
    date_test: datetime.date
    timestamp_test: datetime.datetime
    list_int_field: list[int]
    basic_model: BasicModel
    another_model: AnotherModel
1773

1874
class TempConnection:
1975
"""
@@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None:
2581
self._connection = connection
2682

2783

84+
2885
@pytest.mark.parametrize(
2986
"outpath",
3087
[
@@ -94,4 +151,29 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
94151
res.append(chunk)
95152

96153
assert res == data
154+
155+
# add decimal check
@pytest.mark.parametrize(
    "field_name,field_type,cast_statement",
    [
        ("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"),
        ("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"),
        (
            "date_test",
            datetime.date,
            "CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END",
        ),
        (
            "timestamp_test",
            datetime.datetime,
            # BUGFIX: raw literal — the regex's \+ and \- are not valid Python
            # escape sequences and trigger SyntaxWarning in a plain string.
            r"""CASE WHEN REGEXP_MATCHES(TRIM("timestamp_test"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM("timestamp_test") as TIMESTAMP) ELSE NULL END""",
        ),
        (
            "list_int_field",
            list[int],
            "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])",
        ),
        (
            "basic_model",
            BasicModel,
            "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))",
        ),
        (
            "another_model",
            AnotherModel,
            "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))",
        ),
    ],
)
def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement):
    """Each annotation should yield the exact expected DuckDB casting SQL."""
    assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement


def test_use_cast_statements(casting_test_table):
    """Project the fixture table through generated cast statements: resulting
    dtypes should match the contract and badly formatted dates become NULL."""
    _, conn = casting_test_table
    test_rel = conn.sql("SELECT * from test_casting")
    casting_statements = [
        f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}"
        for fld in CastingRecord.__fields__.values()
    ]
    test_rel = test_rel.project(",".join(casting_statements))
    assert dict(zip(test_rel.columns, test_rel.dtypes)) == {
        fld.name: get_duckdb_type_from_annotation(fld.annotation)
        for fld in CastingRecord.__fields__.values()
    }
    # Second row holds the deliberately malformed dates.
    dodgy_date_rec = test_rel.pl()[1].to_dicts()[0]
    assert (
        not dodgy_date_rec.get("date_test")
        and not dodgy_date_rec.get("basic_model", {}).get("date_field")
        and all(
            not val.get("date_field")
            for val in dodgy_date_rec.get("another_model", {}).get("basic_models", [])
        )
    )
97179

0 commit comments

Comments
 (0)