From 9ce10842328b8ebf81f6007bde4dd227b947e15e Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 8 Apr 2026 22:59:40 +0100 Subject: [PATCH 1/5] fix: enhance duckdb casting to be less permissive of poorly formatted dates and trim whitespace --- .../implementations/duckdb/duckdb_helpers.py | 82 ++++++++++++++++++ .../test_duckdb/test_duckdb_helpers.py | 83 ++++++++++++++++++- 2 files changed, 163 insertions(+), 2 deletions(-) diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index f5b0fe9..79de393 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -313,3 +313,85 @@ def duckdb_record_index(cls): setattr(cls, "add_record_index", _add_duckdb_record_index) setattr(cls, "drop_record_index", _drop_duckdb_record_index) return cls + +def _cast_as_ddb_type(field_expr:str, type_annotation:Any) -> str: + return f"try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})" + + +def get_duckdb_cast_statement_from_annotation(element_name:str, + type_annotation: Any, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex:str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$", + include_cast: bool = True) -> DuckDBPyType: + type_origin = get_origin(type_annotation) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"list_transform({element_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})" + return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable + return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_duckdb_cast_statement_from_annotation( + f"{element_name}.{field_name}", + field_annotation, + date_regex, + timestamp_regex, + False) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()]) + stmt = f"struct_pack({cast_exprs})" + return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation) + + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + if issubclass(type_, datetime): + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({element_name}) as TIMESTAMP) ELSE NULL END" + return stmt + if issubclass(type_, date): + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{date_regex}') THEN TRY_CAST(TRIM({element_name}) as DATE) ELSE NULL END" + return stmt + duck_type = get_duckdb_type_from_annotation(type_) + if duck_type: + stmt = f"trim({element_name})" + return _cast_as_ddb_type(stmt, type_) if include_cast else stmt + raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index 5c39e36..dba0757 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -3,17 +3,73 @@ import datetime import tempfile from pathlib import Path -from typing import Any +from typing import Any, List import pytest import pyspark.sql.types as pst from duckdb import DuckDBPyRelation, DuckDBPyConnection +from pydantic import BaseModel from pyspark.sql import Row, SparkSession from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( _ddb_read_parquet, - duckdb_rel_to_dictionaries) + duckdb_rel_to_dictionaries, + get_duckdb_cast_statement_from_annotation, + get_duckdb_type_from_annotation) +@pytest.fixture +def casting_test_table(temp_ddb_conn): + _, conn = temp_ddb_conn + conn.sql("""CREATE TABLE test_casting ( + str_test VARCHAR, + int_test VARCHAR, + date_test VARCHAR, + timestamp_test VARCHAR, + list_int_field VARCHAR[], + basic_model STRUCT(str_field VARCHAR, date_field VARCHAR), + another_model STRUCT(unique_id VARCHAR, basic_models STRUCT(str_field VARCHAR, date_field VARCHAR)[]))""") + + conn.sql("""INSERT INTO test_casting + VALUES( + 'good_one', + '1', + '2024-11-13', + '2024-04-15 12:25:36', + ['1', '2', '3'], + {'str_field': 'test', 'date_field': '2024-12-11'}, + {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}), + ( + 'dodgy_dates', + '2', + '24-11-13', + '2024-4-15 12:25:36', + ['4', '5', '6'], + {'str_field': 'test', 'date_field': '202-1-11'}, + {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-5'}]})""") + + + yield temp_ddb_conn + + conn.sql("DROP TABLE IF EXISTS test_casting") + + + +class BasicModel(BaseModel): + str_field: str + date_field: datetime.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: datetime.date + timestamp_test: datetime.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel class TempConnection: """ @@ -25,6 +81,7 @@ def __init__(self, connection: DuckDBPyConnection) -> None: self._connection = connection + @pytest.mark.parametrize( "outpath", [ @@ -94,4 +151,26 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, res.append(chunk) assert res == data + +# add decimal check +@pytest.mark.parametrize("field_name,field_type,cast_statement", + [("str_test", str, "try_cast(trim(str_test) as VARCHAR)"), + ("int_test", int, "try_cast(trim(int_test) as BIGINT)"), + ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(date_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(date_test) as DATE) ELSE NULL END"), + ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(timestamp_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(timestamp_test) as TIMESTAMP) ELSE NULL END"), + ("list_int_field", list[int], "try_cast(list_transform(list_int_field, x -> trim(x)) as BIGINT[])"), + ("basic_model", BasicModel, "try_cast(struct_pack(str_field:= trim(basic_model.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), + ("another_model", AnotherModel, "try_cast(struct_pack(unique_id:= trim(another_model.unique_id),basic_models:= list_transform(another_model.basic_models, x -> struct_pack(str_field:= trim(x.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) +def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement): + assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement + + +def test_use_cast_statements(casting_test_table): + _, conn = casting_test_table + test_rel = conn.sql("SELECT * from test_casting") + casting_statements = [ f"{get_duckdb_cast_statement_from_annotation(fld.name, fld.annotation)} as {fld.name}" for fld in CastingRecord.__fields__.values()] + test_rel = test_rel.project(",".join(casting_statements)) + assert dict(zip(test_rel.columns, test_rel.dtypes)) == {fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = test_rel.pl()[1].to_dicts()[0] + assert not dodgy_date_rec.get("date_test") and not dodgy_date_rec.get("basic_model",{}).get("date_field") and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) From 414d3acfe172153e80bd26592feec0a9c07b2a7f Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Tue, 14 Apr 2026 12:41:02 +0100 Subject: [PATCH 2/5] feat: integrated duckdb casting into data contract and added initial spark casting --- .../implementations/duckdb/contract.py | 19 ++--- .../implementations/duckdb/duckdb_helpers.py | 33 ++++++--- .../implementations/spark/spark_helpers.py | 73 +++++++++++++++++++ .../test_duckdb/test_duckdb_helpers.py | 19 +++-- 4 files changed, 110 insertions(+), 34 deletions(-) diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 25fb8a7..4ba173a 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -28,6 +28,7 @@ generate_error_casting_entity_message, ) from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( + get_duckdb_cast_statement_from_annotation, duckdb_read_parquet, duckdb_record_index, duckdb_write_parquet, @@ -101,17 +102,6 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument _lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable return self._connection.sql("select * from _lazy_df") - @staticmethod - def generate_ddb_cast_statement( - column_name: str, dtype: DuckDBPyType, null_flag: bool = False - ) -> str: - """Helper method to generate sql statements for casting datatypes (permissively). - Current duckdb python API doesn't play well with this currently. - """ - if not null_flag: - return f'try_cast("{column_name}" AS {dtype}) AS "{column_name}"' - return f'cast(NULL AS {dtype}) AS "{column_name}"' - # pylint: disable=R0914 def apply_data_contract( self, @@ -180,12 +170,13 @@ def apply_data_contract( casting_statements = [ ( - self.generate_ddb_cast_statement(column, dtype) + get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + f""" AS "{column}" """ if column in relation.columns - else self.generate_ddb_cast_statement(column, dtype, null_flag=True) + else f"CAST(NULL AS {ddb_schema[column]}) AS {column}" ) - for column, dtype in ddb_schema.items() + for column, mdl_fld in entity_fields.items() ] + casting_statements.append(f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}") try: relation = relation.project(", ".join(casting_statements)) except Exception as err: # pylint: disable=broad-except diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 79de393..4a4eb45 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -315,30 +315,39 @@ def duckdb_record_index(cls): return cls def _cast_as_ddb_type(field_expr:str, type_annotation:Any) -> str: - return f"try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})" + return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})""" + +def _ddb_safely_quote_name(field_name:str) -> str: + try: + sep_idx = field_name.rindex(".") + return field_name[:sep_idx + 1] + f"\"{field_name[sep_idx + 1:]}\"" + except ValueError: + return f"\"{field_name}\"" def get_duckdb_cast_statement_from_annotation(element_name:str, type_annotation: Any, date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", timestamp_regex:str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$", - include_cast: bool = True) -> DuckDBPyType: + parent_element: bool = True) -> str: type_origin = get_origin(type_annotation) + + quoted_name = _ddb_safely_quote_name(element_name) # An `Optional` or `Union` type, check to ensure non-heterogenity. if type_origin is Union: python_type = _get_non_heterogenous_type(get_args(type_annotation)) - return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast) + return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) - stmt = f"list_transform({element_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})" - return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation) + stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})" + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) if type_origin is Annotated: python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable - return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, include_cast) # add other expected params here + return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) # add other expected params here # Ensure that we have a concrete type at this point. if not isinstance(type_annotation, type): raise ValueError(f"Unsupported type annotation {type_annotation!r}") @@ -372,9 +381,9 @@ def get_duckdb_cast_statement_from_annotation(element_name:str, raise ValueError( f"No type annotations in dict/dataclass type (got {type_annotation!r})" ) - cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()]) + cast_exprs = ",".join([f"\"{nme}\":= {stmt}" for nme, stmt in fields.items()]) stmt = f"struct_pack({cast_exprs})" - return stmt if not include_cast else _cast_as_ddb_type(stmt, type_annotation) + return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) if type_annotation is list: raise ValueError( @@ -385,13 +394,13 @@ def get_duckdb_cast_statement_from_annotation(element_name:str, for type_ in type_annotation.mro(): if issubclass(type_, datetime): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({element_name}) as TIMESTAMP) ELSE NULL END" + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" return stmt if issubclass(type_, date): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({element_name}), '{date_regex}') THEN TRY_CAST(TRIM({element_name}) as DATE) ELSE NULL END" + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" return stmt duck_type = get_duckdb_type_from_annotation(type_) if duck_type: - stmt = f"trim({element_name})" - return _cast_as_ddb_type(stmt, type_) if include_cast else stmt + stmt = f"trim({quoted_name})" + return _cast_as_ddb_type(stmt, type_) if parent_element else stmt raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 07a4a04..5fc0e5d 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -439,3 +439,76 @@ def spark_record_index(cls): setattr(cls, "add_record_index", _add_spark_record_index) setattr(cls, "drop_record_index", _drop_spark_record_index) return cls + + +def _cast_as_spark_type(field_expr:str, field_type: st.DataType) -> Column: + return sf.expr(field_expr).cast(field_type) + + +def get_spark_cast_statement_from_annotation(element_name:str, + type_annotation: Any, + include_cast: bool = True) -> st.DataType: + type_origin = get_origin(type_annotation) + + # An `Optional` or `Union` type, check to ensure non-heterogenity. + if type_origin is Union: + python_type = _get_non_heterogenous_type(get_args(type_annotation)) + return get_spark_cast_statement_from_annotation(element_name, include_cast) + + # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. + if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): + element_type = _get_non_heterogenous_type(get_args(type_annotation)) + stmt = f"transform({element_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False)})" + return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation) + + if type_origin is Annotated: + python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable + return get_spark_cast_statement_from_annotation(element_name, python_type, include_cast) # add other expected params here + # Ensure that we have a concrete type at this point. + if not isinstance(type_annotation, type): + raise ValueError(f"Unsupported type annotation {type_annotation!r}") + + if ( + # Type hint is a dict subclass, but not dict. Possibly a `TypedDict`. + (issubclass(type_annotation, dict) and type_annotation is not dict) + # Type hint is a dataclass. + or is_dataclass(type_annotation) + # Type hint is a `pydantic` model. + or (type_origin is None and issubclass(type_annotation, BaseModel)) + ): + fields: dict[str, str] = {} + for field_name, field_annotation in get_type_hints(type_annotation).items(): + # Technically non-string keys are disallowed, but people are bad. + if not isinstance(field_name, str): + raise ValueError( + f"Dictionary/Dataclass keys must be strings, got {type_annotation!r}" + ) # pragma: no cover + if get_origin(field_annotation) is ClassVar: + continue + + fields[field_name] = get_spark_cast_statement_from_annotation( + f"{element_name}.{field_name}", + field_annotation, + False) + + if not fields: + raise ValueError( + f"No type annotations in dict/dataclass type (got {type_annotation!r})" + ) + cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()]) + stmt = f"struct_pack({cast_exprs})" + return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation) + + if type_annotation is list: + raise ValueError( + f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" + ) + if type_annotation is dict or type_origin is dict: + raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") + + for type_ in type_annotation.mro(): + duck_type = get_type_from_annotation(type_) + if duck_type: + stmt = f"trim({element_name})" + return _cast_as_spark_type(stmt, type_) if include_cast else stmt + raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index dba0757..268e161 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -154,13 +154,13 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, # add decimal check @pytest.mark.parametrize("field_name,field_type,cast_statement", - [("str_test", str, "try_cast(trim(str_test) as VARCHAR)"), - ("int_test", int, "try_cast(trim(int_test) as BIGINT)"), - ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(date_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(date_test) as DATE) ELSE NULL END"), - ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(timestamp_test), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(timestamp_test) as TIMESTAMP) ELSE NULL END"), - ("list_int_field", list[int], "try_cast(list_transform(list_int_field, x -> trim(x)) as BIGINT[])"), - ("basic_model", BasicModel, "try_cast(struct_pack(str_field:= trim(basic_model.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), - ("another_model", AnotherModel, "try_cast(struct_pack(unique_id:= trim(another_model.unique_id),basic_models:= list_transform(another_model.basic_models, x -> struct_pack(str_field:= trim(x.str_field),date_field:= CASE WHEN REGEXP_MATCHES(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) + [("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"), + ("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"), + ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"), + ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), + ("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"), + ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(basic_model.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.\"date_field\") as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), + ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(another_model.\"unique_id\"),\"basic_models\":= list_transform(another_model.\"basic_models\", x -> struct_pack(\"str_field\":= trim(x.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(x.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.\"date_field\") as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement): assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement @@ -172,5 +172,8 @@ def test_use_cast_statements(casting_test_table): test_rel = test_rel.project(",".join(casting_statements)) assert dict(zip(test_rel.columns, test_rel.dtypes)) == {fld.name: get_duckdb_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} dodgy_date_rec = test_rel.pl()[1].to_dicts()[0] - assert not dodgy_date_rec.get("date_test") and not dodgy_date_rec.get("basic_model",{}).get("date_field") and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) From f9f4d4ff5e2e03ef38db6d759bf27871e1d602db Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 15 Apr 2026 14:24:40 +0100 Subject: [PATCH 3/5] refactor: added further spark cast work with tests, small fixes to duckdb casting --- .../implementations/duckdb/contract.py | 11 ++-- .../implementations/duckdb/duckdb_helpers.py | 58 ++++++++-------- .../implementations/spark/spark_helpers.py | 63 +++++++++++------- src/dve/pipeline/foundry_ddb_pipeline.py | 4 +- .../test_duckdb/test_duckdb_helpers.py | 4 +- .../test_spark}/test_spark_helpers.py | 66 ++++++++++++++++++- 6 files changed, 147 insertions(+), 59 deletions(-) rename tests/test_core_engine/{ => test_backends/test_implementations/test_spark}/test_spark_helpers.py (54%) diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 4ba173a..930e945 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -28,10 +28,10 @@ generate_error_casting_entity_message, ) from dve.core_engine.backends.implementations.duckdb.duckdb_helpers import ( - get_duckdb_cast_statement_from_annotation, duckdb_read_parquet, duckdb_record_index, duckdb_write_parquet, + get_duckdb_cast_statement_from_annotation, get_duckdb_type_from_annotation, relation_is_empty, ) @@ -102,7 +102,7 @@ def create_entity_from_py_iterator( # pylint: disable=unused-argument _lazy_df = pl.LazyFrame(records, polars_schema) # type: ignore # pylint: disable=unused-variable return self._connection.sql("select * from _lazy_df") - # pylint: disable=R0914 + # pylint: disable=R0914,R0915 def apply_data_contract( self, working_dir: URI, @@ -170,13 +170,16 @@ def apply_data_contract( casting_statements = [ ( - get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + f""" AS "{column}" """ + get_duckdb_cast_statement_from_annotation(column, mdl_fld.annotation) + + f""" AS "{column}" """ if column in relation.columns else f"CAST(NULL AS {ddb_schema[column]}) AS {column}" ) for column, mdl_fld in entity_fields.items() ] - casting_statements.append(f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}") + casting_statements.append( + f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301 + ) try: relation = relation.project(", ".join(casting_statements)) except Exception as err: # pylint: disable=broad-except diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 4a4eb45..6fbe2cb 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -314,40 +314,49 @@ def duckdb_record_index(cls): setattr(cls, "drop_record_index", _drop_duckdb_record_index) return cls -def _cast_as_ddb_type(field_expr:str, type_annotation:Any) -> str: + +def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str: return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})""" -def _ddb_safely_quote_name(field_name:str) -> str: + +def _ddb_safely_quote_name(field_name: str) -> str: try: - sep_idx = field_name.rindex(".") - return field_name[:sep_idx + 1] + f"\"{field_name[sep_idx + 1:]}\"" + sep_idx = field_name.index(".") + return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:] except ValueError: - return f"\"{field_name}\"" - - -def get_duckdb_cast_statement_from_annotation(element_name:str, - type_annotation: Any, - date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", - timestamp_regex:str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$", - parent_element: bool = True) -> str: + return f'"{field_name}"' + +# pylint: disable=R0911 +def get_duckdb_cast_statement_from_annotation( + element_name: str, + type_annotation: Any, + parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$", +) -> str: + """Generate casting statements for duckdb relations from type annotations""" type_origin = get_origin(type_annotation) - + quoted_name = _ddb_safely_quote_name(element_name) # An `Optional` or `Union` type, check to ensure non-heterogenity. if type_origin is Union: python_type = _get_non_heterogenous_type(get_args(type_annotation)) - return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, date_regex, timestamp_regex, parent_element + ) # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) - stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, date_regex, timestamp_regex, False)})" + stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) if type_origin is Annotated: python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable - return get_duckdb_cast_statement_from_annotation(element_name, python_type, date_regex, timestamp_regex, parent_element) # add other expected params here + return get_duckdb_cast_statement_from_annotation( + element_name, python_type, date_regex, timestamp_regex, parent_element + ) # add other expected params here # Ensure that we have a concrete type at this point. if not isinstance(type_annotation, type): raise ValueError(f"Unsupported type annotation {type_annotation!r}") @@ -371,17 +380,14 @@ def get_duckdb_cast_statement_from_annotation(element_name:str, continue fields[field_name] = get_duckdb_cast_statement_from_annotation( - f"{element_name}.{field_name}", - field_annotation, - date_regex, - timestamp_regex, - False) + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) if not fields: raise ValueError( f"No type annotations in dict/dataclass type (got {type_annotation!r})" ) - cast_exprs = ",".join([f"\"{nme}\":= {stmt}" for nme, stmt in fields.items()]) + cast_exprs = ",".join([f'"{nme}":= {stmt}' for nme, stmt in fields.items()]) stmt = f"struct_pack({cast_exprs})" return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) @@ -394,13 +400,13 @@ def get_duckdb_cast_statement_from_annotation(element_name:str, for type_ in type_annotation.mro(): if issubclass(type_, datetime): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 return stmt if issubclass(type_, date): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 return stmt - duck_type = get_duckdb_type_from_annotation(type_) + duck_type = get_duckdb_type_from_annotation(type_) if duck_type: - stmt = f"trim({quoted_name})" + stmt = f"trim({quoted_name})" return _cast_as_ddb_type(stmt, type_) if parent_element else stmt raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index 5fc0e5d..e5169a2 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -441,29 +441,41 @@ def spark_record_index(cls): return cls -def _cast_as_spark_type(field_expr:str, field_type: st.DataType) -> Column: - return sf.expr(field_expr).cast(field_type) +def _cast_as_spark_type(field_expr: str, field_type: st.DataType) -> Column: + return sf.expr(field_expr).cast(get_type_from_annotation(field_type)) - -def get_spark_cast_statement_from_annotation(element_name:str, - type_annotation: Any, - include_cast: bool = True) -> st.DataType: +def _spark_safely_quote_name(field_name: str) -> str: + try: + sep_idx = field_name.index(".") + return f'`{field_name[: sep_idx]}`' + field_name[sep_idx:] + except ValueError: + return f'`{field_name}`' + +def get_spark_cast_statement_from_annotation( + element_name: str, type_annotation: Any, parent_element: bool = True, + date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$"): + """Generate casting statements for spark dataframes based on type annotations""" type_origin = get_origin(type_annotation) + + quoted_name = _spark_safely_quote_name(element_name) # An `Optional` or `Union` type, check to ensure non-heterogenity. if type_origin is Union: python_type = _get_non_heterogenous_type(get_args(type_annotation)) - return get_spark_cast_statement_from_annotation(element_name, include_cast) + return get_spark_cast_statement_from_annotation(element_name, python_type, parent_element, date_regex, timestamp_regex) # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) - stmt = f"transform({element_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False)})" - return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation) + stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) if type_origin is Annotated: - python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable - return get_spark_cast_statement_from_annotation(element_name, python_type, include_cast) # add other expected params here + python_type, *_ = get_args(type_annotation) # pylint: disable=unused-variable + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # add other expected params here # Ensure that we have a concrete type at this point. if not isinstance(type_annotation, type): raise ValueError(f"Unsupported type annotation {type_annotation!r}") @@ -487,18 +499,16 @@ def get_spark_cast_statement_from_annotation(element_name:str, continue fields[field_name] = get_spark_cast_statement_from_annotation( - f"{element_name}.{field_name}", - field_annotation, - False) + f"{element_name}.{field_name}", field_annotation, False, date_regex, timestamp_regex + ) if not fields: raise ValueError( f"No type annotations in dict/dataclass type (got {type_annotation!r})" ) - cast_exprs = ",".join([f'{nme}:= {stmt}' for nme, stmt in fields.items()]) - stmt = f"struct_pack({cast_exprs})" - return stmt if not include_cast else _cast_as_spark_type(stmt, type_annotation) - + cast_exprs = ",".join([f"{stmt} AS `{nme}`" for nme, stmt in fields.items()]) + stmt = f"struct({cast_exprs})" + return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) if type_annotation is list: raise ValueError( f"List must have type annotation (e.g. `List[str]`), got {type_annotation!r}" @@ -507,8 +517,15 @@ def get_spark_cast_statement_from_annotation(element_name:str, raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") for type_ in type_annotation.mro(): - duck_type = get_type_from_annotation(type_) - if duck_type: - stmt = f"trim({element_name})" - return _cast_as_spark_type(stmt, type_) if include_cast else stmt - raise ValueError(f"No equivalent DuckDB type for {type_annotation!r}") + if issubclass(type_, dt.datetime): + stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + elif issubclass(type_, dt.date): + stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + else: + spark_type = get_type_from_annotation(type_) + if spark_type: + stmt = f"trim({quoted_name})" + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + raise ValueError(f"No equivalent Spark type for {type_annotation!r}") diff --git a/src/dve/pipeline/foundry_ddb_pipeline.py b/src/dve/pipeline/foundry_ddb_pipeline.py index 3b0e55f..21cac56 100644 --- a/src/dve/pipeline/foundry_ddb_pipeline.py +++ b/src/dve/pipeline/foundry_ddb_pipeline.py @@ -42,13 +42,13 @@ def persist_audit_records(self, submission_info: SubmissionInfo) -> URI: write_to.parent.mkdir(parents=True, exist_ok=True) write_to = write_to.as_posix() self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._processing_status.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "processing_status.parquet"), ) self.write_parquet( # type: ignore # pylint: disable=E1101 - self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 + self._audit_tables._submission_statistics.get_relation().filter( # pylint: disable=W0212 f"submission_id = '{submission_info.submission_id}'" ), fh.joinuri(write_to, "submission_statistics.parquet"), diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index 268e161..bf77d98 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -159,8 +159,8 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"), ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), ("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"), - ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(basic_model.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(basic_model.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(basic_model.\"date_field\") as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), - ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(another_model.\"unique_id\"),\"basic_models\":= list_transform(another_model.\"basic_models\", x -> struct_pack(\"str_field\":= trim(x.\"str_field\"),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(x.\"date_field\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(x.\"date_field\") as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) + ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), + ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) def test_get_duckdb_cast_statement_from_annotation(field_name, field_type, cast_statement): assert get_duckdb_cast_statement_from_annotation(field_name, field_type) == cast_statement diff --git a/tests/test_core_engine/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py similarity index 54% rename from tests/test_core_engine/test_spark_helpers.py rename to tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index a3f167d..dee0525 100644 --- a/tests/test_core_engine/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -12,17 +12,56 @@ from pydantic.types import condecimal from pyspark.sql import DataFrame, SparkSession from pyspark.sql import types as st -from pyspark.sql.functions import col +from pyspark.sql.functions import col, expr +from pyspark.sql.types import ArrayType, DateType, LongType, StringType, StructField, StructType, TimestampType from typing_extensions import Annotated, TypedDict from dve.core_engine.backends.implementations.spark.spark_helpers import ( DecimalConfig, create_udf, + get_spark_cast_statement_from_annotation, get_type_from_annotation, object_to_spark_literal, ) -from ..fixtures import spark # pylint: disable=unused-import +from .....fixtures import spark # pylint: disable=unused-import + +@pytest.fixture +def casting_dataframe(spark): + data = [{"str_test": "good_one", "int_test": "1", "date_test": "2024-11-13", "timestamp_test": "2024-04-15 12:25:36", + "list_int_field":['1', '2', '3'], "basic_model": {'str_field': 'test', 'date_field': '2024-12-11'}, + "another_model": {'unique_id': '1', "basic_models": [{'str_field': 'test_nest', 'date_field': '2020-01-04'}, {'str_field': 'test_nest2', 'date_field': '2020-01-05'}]}}, + {"str_test": "dodgy_dates", "int_test": "2", "date_test": "24-11-13", "timestamp_test": "2024-4-15 12:25:36", + "list_int_field":['4', '5', '6'], "basic_model": {'str_field': 'test', 'date_field': '202-12-11'}, + "another_model": {'unique_id': '2', "basic_models": [{'str_field': 'test_dd', 'date_field': '20-01-04'}, {'str_field': 'test_dd2', 'date_field': '2020-1-05'}]}}] + + bm_schema = StructType([StructField("str_field", StringType()), StructField("date_field", StringType())]) + + schema = StructType([StructField("str_test", StringType()), StructField("int_test", StringType()), StructField("date_test", StringType()), + StructField("timestamp_test", StringType()), StructField("list_int_field", ArrayType(StringType())), + StructField("basic_model", bm_schema), + StructField("another_model", StructType([StructField("unique_id", StringType()), StructField("basic_models", ArrayType(bm_schema))]))]) + yield spark.createDataFrame(data, schema=schema) + + + + +class BasicModel(BaseModel): + str_field: str + date_field: dt.date + +class AnotherModel(BaseModel): + unique_id: int + basic_models: List[BasicModel] + +class CastingRecord(BaseModel): + str_test: str + int_test: int + date_test: dt.date + timestamp_test: dt.datetime + list_int_field: list[int] + basic_model: BasicModel + another_model: AnotherModel EXPECTED_STRUCT = st.StructType( [ @@ -203,3 +242,26 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any): """ with pytest.raises(ValueError): object_to_spark_literal(obj) + +@pytest.mark.parametrize("field_name,field_type,expression,spark_type", + [("str_test", str, "trim(`str_test`)", StringType()), + ("int_test", int, "trim(`int_test`)", LongType()), + ("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()), + ("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), + ("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)), + ("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])), + ("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))]) +def test_get_spark_cast_statement_from_annotation(field_name, field_type, expression, spark_type): + assert str(get_spark_cast_statement_from_annotation(field_name, field_type)) == str(expr(expression).cast(spark_type)) + + +def test_use_cast_statements(spark, casting_dataframe): + casting_statements = [ get_spark_cast_statement_from_annotation(fld.name, fld.annotation).alias(fld.name) for fld in CastingRecord.__fields__.values()] + cast_df = casting_dataframe.select(*casting_statements) + assert {fld.name: fld.dataType for fld in cast_df.schema} == {fld.name: get_type_from_annotation(fld.annotation) for fld in CastingRecord.__fields__.values()} + dodgy_date_rec = [rw.asDict(True) for rw in cast_df.collect()][1] + assert (not dodgy_date_rec.get("date_test") and + not dodgy_date_rec.get("basic_model",{}).get("date_field") + and all(not val.get("date_field") for val in dodgy_date_rec.get("another_model",{}).get("basic_models",[])) + ) + assert cast_df \ No newline at end of file From 78fa448a7df298aa5b1e66539db9dc7b4b54e2e5 Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Wed, 15 Apr 2026 21:58:02 +0100 Subject: [PATCH 4/5] style: address linting issues --- .../implementations/duckdb/contract.py | 2 +- .../implementations/duckdb/duckdb_helpers.py | 17 ++++---- .../implementations/spark/spark_helpers.py | 41 +++++++++++-------- .../test_duckdb/test_duckdb_helpers.py | 2 +- .../test_spark/test_spark_helpers.py | 2 +- 5 files changed, 38 insertions(+), 26 deletions(-) diff --git a/src/dve/core_engine/backends/implementations/duckdb/contract.py b/src/dve/core_engine/backends/implementations/duckdb/contract.py index 930e945..3595716 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/contract.py +++ b/src/dve/core_engine/backends/implementations/duckdb/contract.py @@ -178,7 +178,7 @@ def apply_data_contract( for column, mdl_fld in entity_fields.items() ] casting_statements.append( - f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301 + f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301 ) try: relation = relation.project(", ".join(casting_statements)) diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 6fbe2cb..0ff708e 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -316,23 +316,26 @@ def duckdb_record_index(cls): def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str: + """Cast to Duck DB type""" return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})""" def _ddb_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" try: sep_idx = field_name.index(".") return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:] except ValueError: return f'"{field_name}"' -# pylint: disable=R0911 + +# pylint: disable=R0801,R0911 def get_duckdb_cast_statement_from_annotation( element_name: str, type_annotation: Any, parent_element: bool = True, date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", - timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$", + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 ) -> str: """Generate casting statements for duckdb relations from type annotations""" type_origin = get_origin(type_annotation) @@ -343,19 +346,19 @@ def get_duckdb_cast_statement_from_annotation( if type_origin is Union: python_type = _get_non_heterogenous_type(get_args(type_annotation)) return get_duckdb_cast_statement_from_annotation( - element_name, python_type, date_regex, timestamp_regex, parent_element + element_name, python_type, parent_element, date_regex, timestamp_regex ) # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) - stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation) if type_origin is Annotated: python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable return get_duckdb_cast_statement_from_annotation( - element_name, python_type, date_regex, timestamp_regex, parent_element + element_name, python_type, parent_element, date_regex, timestamp_regex ) # add other expected params here # Ensure that we have a concrete type at this point. if not isinstance(type_annotation, type): @@ -400,10 +403,10 @@ def get_duckdb_cast_statement_from_annotation( for type_ in type_annotation.mro(): if issubclass(type_, datetime): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 return stmt if issubclass(type_, date): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 + stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 return stmt duck_type = get_duckdb_type_from_annotation(type_) if duck_type: diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index e5169a2..c91cde2 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -441,34 +441,44 @@ def spark_record_index(cls): return cls -def _cast_as_spark_type(field_expr: str, field_type: st.DataType) -> Column: +def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column: + """Cast to spark type""" return sf.expr(field_expr).cast(get_type_from_annotation(field_type)) + def _spark_safely_quote_name(field_name: str) -> str: + """Quote field names in case reserved""" try: sep_idx = field_name.index(".") - return f'`{field_name[: sep_idx]}`' + field_name[sep_idx:] + return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:] except ValueError: - return f'`{field_name}`' + return f"`{field_name}`" + +# pylint: disable=R0801 def get_spark_cast_statement_from_annotation( - element_name: str, type_annotation: Any, parent_element: bool = True, + element_name: str, + type_annotation: Any, + parent_element: bool = True, date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", - timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$"): + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 +): """Generate casting statements for spark dataframes based on type annotations""" type_origin = get_origin(type_annotation) - + quoted_name = _spark_safely_quote_name(element_name) # An `Optional` or `Union` type, check to ensure non-heterogenity. if type_origin is Union: python_type = _get_non_heterogenous_type(get_args(type_annotation)) - return get_spark_cast_statement_from_annotation(element_name, python_type, parent_element, date_regex, timestamp_regex) + return get_spark_cast_statement_from_annotation( + element_name, python_type, parent_element, date_regex, timestamp_regex + ) # Type hint is e.g. `List[str]`, check to ensure non-heterogenity. if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)): element_type = _get_non_heterogenous_type(get_args(type_annotation)) - stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 + stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301 return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation) if type_origin is Annotated: @@ -518,14 +528,13 @@ def get_spark_cast_statement_from_annotation( for type_ in type_annotation.mro(): if issubclass(type_, dt.datetime): - stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 return _cast_as_spark_type(stmt, type_) if parent_element else stmt - elif issubclass(type_, dt.date): - stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + if issubclass(type_, dt.date): + stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + return _cast_as_spark_type(stmt, type_) if parent_element else stmt + spark_type = get_type_from_annotation(type_) + if spark_type: + stmt = f"trim({quoted_name})" return _cast_as_spark_type(stmt, type_) if parent_element else stmt - else: - spark_type = get_type_from_annotation(type_) - if spark_type: - stmt = f"trim({quoted_name})" - return _cast_as_spark_type(stmt, type_) if parent_element else stmt raise ValueError(f"No equivalent Spark type for {type_annotation!r}") diff --git a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py index bf77d98..19e96e2 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py @@ -157,7 +157,7 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection, [("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"), ("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"), ("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"), - ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), + ("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"), ("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"), ("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"), ("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")]) diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index dee0525..97285a9 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -247,7 +247,7 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any): [("str_test", str, "trim(`str_test`)", StringType()), ("int_test", int, "trim(`int_test`)", LongType()), ("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()), - ("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), + ("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), ("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)), ("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])), ("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))]) From 7d61282922dfd4793b0358bc0b18f63121ad7742 Mon Sep 17 00:00:00 2001 From: stevenhsd <56357022+stevenhsd@users.noreply.github.com> Date: Thu, 16 Apr 2026 16:05:09 +0100 Subject: [PATCH 5/5] refactor: add in time type for duckdb casting in contract and fix spark test involving regexp for spark casting in contract --- .../backends/implementations/duckdb/duckdb_helpers.py | 11 ++++++++--- .../backends/implementations/spark/spark_helpers.py | 7 ++++--- .../test_spark/test_spark_helpers.py | 2 +- 3 files changed, 13 insertions(+), 7 deletions(-) diff --git a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py index 0ff708e..394cd01 100644 --- a/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py +++ b/src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py @@ -329,13 +329,14 @@ def _ddb_safely_quote_name(field_name: str) -> str: return f'"{field_name}"' -# pylint: disable=R0801,R0911 +# pylint: disable=R0801,R0911,R0912 def get_duckdb_cast_statement_from_annotation( element_name: str, type_annotation: Any, parent_element: bool = True, date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 + time_regex: str = r"^[0-9]{2}:[0-9]{2}:[0-9]{2}$", ) -> str: """Generate casting statements for duckdb relations from type annotations""" type_origin = get_origin(type_annotation) @@ -402,11 +403,15 @@ def get_duckdb_cast_statement_from_annotation( raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first if issubclass(type_, datetime): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301 return stmt if issubclass(type_, date): - stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301 + return stmt + if issubclass(type_, time): + stmt = rf"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{time_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIME) ELSE NULL END" # pylint: disable=C0301 return stmt duck_type = get_duckdb_type_from_annotation(type_) if duck_type: diff --git a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py index c91cde2..ced985a 100644 --- a/src/dve/core_engine/backends/implementations/spark/spark_helpers.py +++ b/src/dve/core_engine/backends/implementations/spark/spark_helpers.py @@ -461,7 +461,7 @@ def get_spark_cast_statement_from_annotation( type_annotation: Any, parent_element: bool = True, date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$", - timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 + timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301 ): """Generate casting statements for spark dataframes based on type annotations""" type_origin = get_origin(type_annotation) @@ -527,11 +527,12 @@ def get_spark_cast_statement_from_annotation( raise ValueError(f"dict must be `typing.TypedDict` subclass, got {type_annotation!r}") for type_ in type_annotation.mro(): + # datetime is subclass of date, so needs to be handled first if issubclass(type_, dt.datetime): - stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 return _cast_as_spark_type(stmt, type_) if parent_element else stmt if issubclass(type_, dt.date): - stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 + stmt = rf"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301 return _cast_as_spark_type(stmt, type_) if parent_element else stmt spark_type = get_type_from_annotation(type_) if spark_type: diff --git a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py index 97285a9..7502673 100644 --- a/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py +++ b/tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py @@ -247,7 +247,7 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any): [("str_test", str, "trim(`str_test`)", StringType()), ("int_test", int, "trim(`int_test`)", LongType()), ("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()), - ("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), + ("timestamp_test", dt.datetime, r"CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\\+|\\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()), ("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)), ("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])), ("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))])