Skip to content

Commit 9ce1084

Browse files
committed
fix: enhance duckdb casting to be less permissive of poorly formatted dates and trim whitespace
1 parent db0d300 commit 9ce1084

2 files changed

Lines changed: 163 additions & 2 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -313,3 +313,85 @@ def duckdb_record_index(cls):
313313
setattr(cls, "add_record_index", _add_duckdb_record_index)
314314
setattr(cls, "drop_record_index", _drop_duckdb_record_index)
315315
return cls
316+
317+
def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
    """Wrap a SQL expression in a `try_cast` to the annotation's DuckDB type.

    Args:
        field_expr: SQL expression to be cast.
        type_annotation: Python type annotation; it is resolved to a DuckDB
            type through `get_duckdb_type_from_annotation`.

    Returns:
        A `try_cast(<field_expr> as <duckdb type>)` SQL expression.
    """
    target_type = get_duckdb_type_from_annotation(type_annotation)
    return f"try_cast({field_expr} as {target_type})"
319+
320+
321+
def get_duckdb_cast_statement_from_annotation(
    element_name: str,
    type_annotation: Any,
    date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
    timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$",
    include_cast: bool = True,
) -> str:
    """Build the DuckDB SQL expression that casts `element_name` to a type.

    The annotation is walked recursively:

    * `Union`/`Optional` and `Annotated` hints are unwrapped to their
      underlying type.
    * List hints become a `list_transform` over the element expression.
    * `dict` subclasses, dataclasses and pydantic models become a
      `struct_pack` with one entry per type-hinted field.
    * `datetime`/`date` values are trimmed and only cast when they match
      `timestamp_regex`/`date_regex`; otherwise the expression yields NULL.
    * Any other type with a DuckDB equivalent is trimmed before casting.

    Args:
        element_name: Column name (or SQL expression) holding the raw value.
        type_annotation: Python type annotation describing the target type.
        date_regex: Pattern a trimmed value must match to be cast to DATE.
        timestamp_regex: Pattern a trimmed value must match to be cast to
            TIMESTAMP.
        include_cast: When true, wrap the expression in a `try_cast` to the
            DuckDB type of the annotation. Nested list elements and struct
            fields are generated without their own cast and rely on the cast
            of the enclosing expression. Date and timestamp expressions
            always carry their own `TRY_CAST`, whatever this flag says.

    Returns:
        The SQL expression as a string.

    Raises:
        ValueError: If the annotation is not a concrete type, is a bare
            `list` or `dict`, is a struct-like type without type hints, or
            has no DuckDB equivalent.
    """
    type_origin = get_origin(type_annotation)

    # An `Optional` or `Union` type, check to ensure non-heterogeneity.
    if type_origin is Union:
        python_type = _get_non_heterogenous_type(get_args(type_annotation))
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, date_regex, timestamp_regex, include_cast
        )

    # Type hint is e.g. `List[str]`, check to ensure non-heterogeneity.
    if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
        element_type = _get_non_heterogenous_type(get_args(type_annotation))
        # The lambda parameter `x` stands in for each element of the list.
        element_stmt = get_duckdb_cast_statement_from_annotation(
            "x", element_type, date_regex, timestamp_regex, False
        )
        stmt = f"list_transform({element_name}, x -> {element_stmt})"
        return _cast_as_ddb_type(stmt, type_annotation) if include_cast else stmt

    if type_origin is Annotated:
        # Only the underlying type is used; the extra metadata is ignored.
        python_type = get_args(type_annotation)[0]
        return get_duckdb_cast_statement_from_annotation(
            element_name, python_type, date_regex, timestamp_regex, include_cast
        )

    # Ensure that we have a concrete type at this point.
    if not isinstance(type_annotation, type):
        raise ValueError(f"Unsupported type annotation {type_annotation!r}")

    if (
        # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`.
        (issubclass(type_annotation, dict) and type_annotation is not dict)
        # Type hint is a dataclass.
        or is_dataclass(type_annotation)
        # Type hint is a `pydantic` model.
        or (type_origin is None and issubclass(type_annotation, BaseModel))
    ):
        fields: dict[str, str] = {}
        for field_name, field_annotation in get_type_hints(type_annotation).items():
            # Technically non-string keys are disallowed, but people are bad.
            if not isinstance(field_name, str):
                raise ValueError(
                    f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}"
                )  # pragma: no cover
            if get_origin(field_annotation) is ClassVar:
                continue

            fields[field_name] = get_duckdb_cast_statement_from_annotation(
                f"{element_name}.{field_name}",
                field_annotation,
                date_regex,
                timestamp_regex,
                False,
            )

        if not fields:
            raise ValueError(
                f"No type annotations in dict/dataclass type (got {type_annotation!r})"
            )
        cast_exprs = ",".join(f"{nme}:= {stmt}" for nme, stmt in fields.items())
        stmt = f"struct_pack({cast_exprs})"
        return _cast_as_ddb_type(stmt, type_annotation) if include_cast else stmt

    if type_annotation is list:
        raise ValueError(
            f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}"
        )
    if type_annotation is dict or type_origin is dict:
        raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}")

    for type_ in type_annotation.mro():
        # `datetime` must be tested first because it is a subclass of `date`.
        if issubclass(type_, datetime):
            # Double any single quote so the pattern stays one SQL string literal.
            pattern = timestamp_regex.replace("'", "''")
            return (
                f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{pattern}') "
                f"THEN TRY_CAST(TRIM({element_name}) as TIMESTAMP) ELSE NULL END"
            )
        if issubclass(type_, date):
            pattern = date_regex.replace("'", "''")
            return (
                f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{pattern}') "
                f"THEN TRY_CAST(TRIM({element_name}) as DATE) ELSE NULL END"
            )
        duck_type = get_duckdb_type_from_annotation(type_)
        if duck_type:
            stmt = f"trim({element_name})"
            return _cast_as_ddb_type(stmt, type_) if include_cast else stmt
    raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}")

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py

Lines changed: 81 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,17 +3,73 @@
33
import datetime
44
import tempfile
55
from pathlib import Path
6-
from typing import Any
6+
from typing import Any, List
77

88
import pytest
99
import pyspark.sql.types as pst
1010
from duckdb import DuckDBPyRelation, DuckDBPyConnection
11+
from pydantic import BaseModel
1112
from pyspark.sql import Row, SparkSession
1213

1314
from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import (
1415
_ddb_read_parquet,
15-
duckdb_rel_to_dictionaries)
16+
duckdb_rel_to_dictionaries,
17+
get_duckdb_cast_statement_from_annotation,
18+
get_duckdb_type_from_annotation)
1619

20+
@pytest.fixture
def casting_test_table(temp_ddb_conn):
    """Create and populate the `test_casting` table for the casting tests.

    Every column (including nested list and struct members) is stored as
    VARCHAR. Two rows are inserted: `good_one`, whose date and timestamp
    strings are zero-padded, and `dodgy_dates`, whose date and timestamp
    strings are not. The `temp_ddb_conn` value is yielded unchanged and the
    table is dropped once the test finishes.
    """
    _unused, connection = temp_ddb_conn

    # SQL text is kept verbatim, including its line layout.
    create_table_sql = """CREATE TABLE test_casting (
str_test VARCHAR,
int_test VARCHAR,
date_test VARCHAR,
timestamp_test VARCHAR,
list_int_field VARCHAR[],
basic_model STRUCT(str_field VARCHAR, date_field VARCHAR),
another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))"""
    insert_rows_sql = """INSERT INTO test_casting
VALUES(
'good_one',
'1',
'2024-11-13',
'2024-04-15 12:25:36',
['1', '2', '3'],
{'str_field': 'test', 'date_field': '2024-12-11'},
{'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}),
(
'dodgy_dates',
'2',
'24-11-13',
'2024-4-15 12:25:36',
['4', '5', '6'],
{'str_field': 'test', 'date_field': '202-1-11'},
{'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})"""

    connection.sql(create_table_sql)
    connection.sql(insert_rows_sql)

    yield temp_ddb_conn

    # Teardown: remove the table so later tests start from a clean state.
    connection.sql("DROP TABLE IF EXISTS test_casting")
54+
55+
56+
57+
class BasicModel(BaseModel):
    """Flat test model pairing a string field with a date field."""

    str_field: str
    date_field: datetime.date
60+
61+
class AnotherModel(BaseModel):
    """Nested test model: an integer id plus a list of `BasicModel` entries."""

    unique_id: int
    basic_models: List[BasicModel]
64+
65+
class CastingRecord(BaseModel):
    """Target types for the casting tests.

    Field names match the column names of the `test_casting` table created
    by the `casting_test_table` fixture.
    """

    str_test: str
    int_test: int
    date_test: datetime.date
    timestamp_test: datetime.datetime
    list_int_field: list[int]
    basic_model: BasicModel
    another_model: AnotherModel
1773

1874
class TempConnection:
1975
"""
@@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None:
2581
self._connection = connection
2682

2783

84+
2885
@pytest.mark.parametrize(
2986
"outpath",
3087
[
@@ -94,4 +151,26 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
94151
res.append(chunk)
95152

96153
assert res == data
154+
155+
# TODO: add a Decimal case to the parametrized cast-statement checks below.
156+
@pytest.mark.parametrize(
    "field_name,field_type,cast_statement",
    [
        ("str_test", str, "try_cast(trim(str_test) as VARCHAR)"),
        ("int_test", int, "try_cast(trim(int_test) as BIGINT)"),
        (
            "date_test",
            datetime.date,
            "CASE WHEN REGEXP_MATCHES(TRIM(date_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') "
            "THEN TRY_CAST(TRIM(date_test) as DATE) ELSE NULL END",
        ),
        (
            "timestamp_test",
            datetime.datetime,
            "CASE WHEN REGEXP_MATCHES(TRIM(timestamp_test), "
            "'^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') "
            "THEN TRY_CAST(TRIM(timestamp_test) as TIMESTAMP) ELSE NULL END",
        ),
        (
            "list_int_field",
            list[int],
            "try_cast(list_transform(list_int_field, x -> trim(x)) as BIGINT[])",
        ),
        (
            "basic_model",
            BasicModel,
            "try_cast(struct_pack("
            "str_field:= trim(basic_model.str_field),"
            "date_field:= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.date_field), "
            "'^[0-9]{4}-[0-9]{2}-[0-9]{2}$') "
            "THEN TRY_CAST(TRIM(basic_model.date_field) as DATE) ELSE NULL END"
            ") as STRUCT(str_field VARCHAR, date_field DATE))",
        ),
        (
            "another_model",
            AnotherModel,
            "try_cast(struct_pack("
            "unique_id:= trim(another_model.unique_id),"
            "basic_models:= list_transform(another_model.basic_models, x -> struct_pack("
            "str_field:= trim(x.str_field),"
            "date_field:= CASE WHEN REGEXP_MATCHES(TRIM(x.date_field), "
            "'^[0-9]{4}-[0-9]{2}-[0-9]{2}$') "
            "THEN TRY_CAST(TRIM(x.date_field) as DATE) ELSE NULL END"
            "))) as STRUCT(unique_id BIGINT, "
            "basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))",
        ),
    ],
)
def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement):
    """The generated SQL for each annotation matches the expected text exactly."""
    generated = get_duckdb_cast_statement_from_annotation(field_name, field_type)
    assert generated == cast_statement
166+
167+
168+
def test_use_cast_statements(casting_test_table):
    """Run the generated cast statements against the `test_casting` table.

    Checks that the projected columns have the DuckDB types derived from
    `CastingRecord`, that every malformed date/timestamp in the
    `dodgy_dates` row comes back empty, and that the well-formed values in
    the `good_one` row are still parsed.
    """
    _, conn = casting_test_table
    model_fields = list(CastingRecord.__fields__.values())

    test_rel = conn.sql("SELECT * from test_casting")
    casting_statements = [
        f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}"
        for fld in model_fields
    ]
    test_rel = test_rel.project(",".join(casting_statements))

    expected_dtypes = {
        fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in model_fields
    }
    assert dict(zip(test_rel.columns, test_rel.dtypes)) == expected_dtypes

    # Look rows up by `str_test` rather than by position: the query has no
    # ORDER BY, so positional access would depend on scan order.
    records = {rec["str_test"]: rec for rec in test_rel.pl().to_dicts()}

    dodgy_date_rec = records["dodgy_dates"]
    assert not dodgy_date_rec.get("date_test")
    assert not dodgy_date_rec.get("timestamp_test")
    assert not dodgy_date_rec.get("basic_model", {}).get("date_field")
    nested_models = dodgy_date_rec.get("another_model", {}).get("basic_models", [])
    # Guard against the `all(...)` below passing vacuously on an empty list.
    assert nested_models
    assert all(not val.get("date_field") for val in nested_models)

    # Stricter parsing must not discard values that are correctly formatted.
    good_rec = records["good_one"]
    assert good_rec["date_test"] == datetime.date(2024, 11, 13)
    assert good_rec["timestamp_test"] == datetime.datetime(2024, 4, 15, 12, 25, 36)
97176

0 commit comments

Comments
 (0)