Skip to content

Commit 78fa448

Browse files
committed
style: address linting issues
1 parent f9f4d4f commit 78fa448

5 files changed

Lines changed: 38 additions & 26 deletions

File tree

src/dve/core_engine/backends/implementations/duckdb/contract.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def apply_data_contract(
178178
for column, mdl_fld in entity_fields.items()
179179
]
180180
casting_statements.append(
181-
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
181+
f"CAST({RECORD_INDEX_COLUMN_NAME} AS {get_duckdb_type_from_annotation(int)}) AS {RECORD_INDEX_COLUMN_NAME}" # pylint: disable=C0301
182182
)
183183
try:
184184
relation = relation.project(", ".join(casting_statements))

src/dve/core_engine/backends/implementations/duckdb/duckdb_helpers.py

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -316,23 +316,26 @@ def duckdb_record_index(cls):
316316

317317

318318
def _cast_as_ddb_type(field_expr: str, type_annotation: Any) -> str:
319+
"""Cast to Duck DB type"""
319320
return f"""try_cast({field_expr} as {get_duckdb_type_from_annotation(type_annotation)})"""
320321

321322

322323
def _ddb_safely_quote_name(field_name: str) -> str:
324+
"""Quote field names in case reserved"""
323325
try:
324326
sep_idx = field_name.index(".")
325327
return f'"{field_name[: sep_idx]}"' + field_name[sep_idx:]
326328
except ValueError:
327329
return f'"{field_name}"'
328330

329-
# pylint: disable=R0911
331+
332+
# pylint: disable=R0801,R0911
330333
def get_duckdb_cast_statement_from_annotation(
331334
element_name: str,
332335
type_annotation: Any,
333336
parent_element: bool = True,
334337
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
335-
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$",
338+
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
336339
) -> str:
337340
"""Generate casting statements for duckdb relations from type annotations"""
338341
type_origin = get_origin(type_annotation)
@@ -343,19 +346,19 @@ def get_duckdb_cast_statement_from_annotation(
343346
if type_origin is Union:
344347
python_type = _get_non_heterogenous_type(get_args(type_annotation))
345348
return get_duckdb_cast_statement_from_annotation(
346-
element_name, python_type, date_regex, timestamp_regex, parent_element
349+
element_name, python_type, parent_element, date_regex, timestamp_regex
347350
)
348351

349352
# Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
350353
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
351354
element_type = _get_non_heterogenous_type(get_args(type_annotation))
352-
stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
355+
stmt = f"list_transform({quoted_name}, x -> {get_duckdb_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
353356
return stmt if not parent_element else _cast_as_ddb_type(stmt, type_annotation)
354357

355358
if type_origin is Annotated:
356359
python_type, *other_args = get_args(type_annotation) # pylint: disable=unused-variable
357360
return get_duckdb_cast_statement_from_annotation(
358-
element_name, python_type, date_regex, timestamp_regex, parent_element
361+
element_name, python_type, parent_element, date_regex, timestamp_regex
359362
) # add other expected params here
360363
# Ensure that we have a concrete type at this point.
361364
if not isinstance(type_annotation, type):
@@ -400,10 +403,10 @@ def get_duckdb_cast_statement_from_annotation(
400403

401404
for type_ in type_annotation.mro():
402405
if issubclass(type_, datetime):
403-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301
406+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{timestamp_regex}') THEN TRY_CAST(TRIM({quoted_name}) as TIMESTAMP) ELSE NULL END" # pylint: disable=C0301
404407
return stmt
405408
if issubclass(type_, date):
406-
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
409+
stmt = f"CASE WHEN REGEXP_MATCHES(TRIM({quoted_name}), '{date_regex}') THEN TRY_CAST(TRIM({quoted_name}) as DATE) ELSE NULL END" # pylint: disable=C0301
407410
return stmt
408411
duck_type = get_duckdb_type_from_annotation(type_)
409412
if duck_type:

src/dve/core_engine/backends/implementations/spark/spark_helpers.py

Lines changed: 25 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -441,34 +441,44 @@ def spark_record_index(cls):
441441
return cls
442442

443443

444-
def _cast_as_spark_type(field_expr: str, field_type: st.DataType) -> Column:
444+
def _cast_as_spark_type(field_expr: str, field_type: Any) -> Column:
445+
"""Cast to spark type"""
445446
return sf.expr(field_expr).cast(get_type_from_annotation(field_type))
446447

448+
447449
def _spark_safely_quote_name(field_name: str) -> str:
450+
"""Quote field names in case reserved"""
448451
try:
449452
sep_idx = field_name.index(".")
450-
return f'`{field_name[: sep_idx]}`' + field_name[sep_idx:]
453+
return f"`{field_name[: sep_idx]}`" + field_name[sep_idx:]
451454
except ValueError:
452-
return f'`{field_name}`'
455+
return f"`{field_name}`"
456+
453457

458+
# pylint: disable=R0801
454459
def get_spark_cast_statement_from_annotation(
455-
element_name: str, type_annotation: Any, parent_element: bool = True,
460+
element_name: str,
461+
type_annotation: Any,
462+
parent_element: bool = True,
456463
date_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}$",
457-
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$"):
464+
timestamp_regex: str = r"^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$", # pylint: disable=C0301
465+
):
458466
"""Generate casting statements for spark dataframes based on type annotations"""
459467
type_origin = get_origin(type_annotation)
460-
468+
461469
quoted_name = _spark_safely_quote_name(element_name)
462470

463471
# An `Optional` or `Union` type, check to ensure non-heterogenity.
464472
if type_origin is Union:
465473
python_type = _get_non_heterogenous_type(get_args(type_annotation))
466-
return get_spark_cast_statement_from_annotation(element_name, python_type, parent_element, date_regex, timestamp_regex)
474+
return get_spark_cast_statement_from_annotation(
475+
element_name, python_type, parent_element, date_regex, timestamp_regex
476+
)
467477

468478
# Type hint is e.g. `List[str]`, check to ensure non-heterogenity.
469479
if type_origin is list or (isinstance(type_origin, type) and issubclass(type_origin, list)):
470480
element_type = _get_non_heterogenous_type(get_args(type_annotation))
471-
stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
481+
stmt = f"transform({quoted_name}, x -> {get_spark_cast_statement_from_annotation('x',element_type, False, date_regex, timestamp_regex)})" # pylint: disable=C0301
472482
return stmt if not parent_element else _cast_as_spark_type(stmt, type_annotation)
473483

474484
if type_origin is Annotated:
@@ -518,14 +528,13 @@ def get_spark_cast_statement_from_annotation(
518528

519529
for type_ in type_annotation.mro():
520530
if issubclass(type_, dt.datetime):
521-
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
531+
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{timestamp_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
522532
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
523-
elif issubclass(type_, dt.date):
524-
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
533+
if issubclass(type_, dt.date):
534+
stmt = f"CASE WHEN REGEXP(TRIM({quoted_name}), '{date_regex}') THEN TRIM({quoted_name}) ELSE NULL END" # pylint: disable=C0301
535+
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
536+
spark_type = get_type_from_annotation(type_)
537+
if spark_type:
538+
stmt = f"trim({quoted_name})"
525539
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
526-
else:
527-
spark_type = get_type_from_annotation(type_)
528-
if spark_type:
529-
stmt = f"trim({quoted_name})"
530-
return _cast_as_spark_type(stmt, type_) if parent_element else stmt
531540
raise ValueError(f"No equivalent Spark type for {type_annotation!r}")

tests/test_core_engine/test_backends/test_implementations/test_duckdb/test_duckdb_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -157,7 +157,7 @@ def test_duckdb_rel_to_dictionaries(temp_ddb_conn: DuckDBPyConnection,
157157
[("str_test", str, "try_cast(trim(\"str_test\") as VARCHAR)"),
158158
("int_test", int, "try_cast(trim(\"int_test\") as BIGINT)"),
159159
("date_test", datetime.date,"CASE WHEN REGEXP_MATCHES(TRIM(\"date_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"date_test\") as DATE) ELSE NULL END"),
160-
("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"),
160+
("timestamp_test", datetime.datetime,"CASE WHEN REGEXP_MATCHES(TRIM(\"timestamp_test\"), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRY_CAST(TRIM(\"timestamp_test\") as TIMESTAMP) ELSE NULL END"),
161161
("list_int_field", list[int], "try_cast(list_transform(\"list_int_field\", x -> trim(\"x\")) as BIGINT[])"),
162162
("basic_model", BasicModel, "try_cast(struct_pack(\"str_field\":= trim(\"basic_model\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"basic_model\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"basic_model\".date_field) as DATE) ELSE NULL END) as STRUCT(str_field VARCHAR, date_field DATE))"),
163163
("another_model", AnotherModel, "try_cast(struct_pack(\"unique_id\":= trim(\"another_model\".unique_id),\"basic_models\":= list_transform(\"another_model\".basic_models, x -> struct_pack(\"str_field\":= trim(\"x\".str_field),\"date_field\":= CASE WHEN REGEXP_MATCHES(TRIM(\"x\".date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRY_CAST(TRIM(\"x\".date_field) as DATE) ELSE NULL END))) as STRUCT(unique_id BIGINT, basic_models STRUCT(str_field VARCHAR, date_field DATE)[]))")])

tests/test_core_engine/test_backends/test_implementations/test_spark/test_spark_helpers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,7 @@ def test_object_to_spark_literal_blocks_some_footguns(obj: Any):
247247
[("str_test", str, "trim(`str_test`)", StringType()),
248248
("int_test", int, "trim(`int_test`)", LongType()),
249249
("date_test", dt.date, "CASE WHEN REGEXP(TRIM(`date_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`date_test`) ELSE NULL END", DateType()),
250-
("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2} [0-9]{2}:[0-9]{2}:[0-9]{2}$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()),
250+
("timestamp_test", dt.datetime, "CASE WHEN REGEXP(TRIM(`timestamp_test`), '^[0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}((\+|\-)[0-9]{2}:[0-9]{2})?$') THEN TRIM(`timestamp_test`) ELSE NULL END", TimestampType()),
251251
("list_int_field", list[int], "transform(`list_int_field`, x -> trim(`x`))", ArrayType(LongType(), True)),
252252
("basic_model", BasicModel, "struct(trim(`basic_model`.str_field) as str_field, CASE WHEN REGEXP(TRIM(`basic_model`.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(`basic_model`.date_field) ELSE NULL END as date_field)", StructType([StructField("str_field", StringType(), True), StructField("date_field", DateType(), True)])),
253253
("another_model", AnotherModel, "struct(trim(`another_model`.unique_id) as unique_id, transform(`another_model`.basic_models, x -> struct(trim(x.str_field) as str_field, CASE WHEN REGEXP(TRIM(x.date_field), '^[0-9]{4}-[0-9]{2}-[0-9]{2}$') THEN TRIM(x.date_field) ELSE NULL END as date_field)) as basic_models)", StructType([StructField("unique_id", LongType(), True), StructField("basic_models", ArrayType(StructType([StructField("str_field", StringType()), StructField("date_field", DateType(), True)])))]))])

0 commit comments

Comments
 (0)