Commit b3fc0d8

refactor: Address review comments

Parent: 0191dd3

8 files changed

Lines changed: 109 additions & 84 deletions


src/dve/core_engine/backends/base/reader.py

Lines changed: 9 additions & 7 deletions
@@ -117,7 +117,7 @@ def read_to_entity_type(
         """
         if entity_name == Iterator[dict[str, Any]]:
             return self.read_to_py_iterator(resource, entity_name, schema) # type: ignore
-
+
         self.raise_if_not_sensible_file(resource, entity_name)
 
         try:
@@ -141,12 +141,12 @@ def write_parquet(
 
         """
         raise NotImplementedError(f"write_parquet not implemented in {self.__class__}")
-
+
     @staticmethod
     def _check_likely_text_file(resource: URI) -> bool:
         """Quick sense check of file to see if it looks like text
-        - not 100% full proof, but hopefully enough to weed out most
-        non-text files"""
+        - not 100% full proof, but hopefully enough to weed out most
+        non-text files"""
         with open_stream(resource, "rb") as fle:
             start_chunk = fle.read(4096)
             # check for BOM character - utf-16 can contain NULL bytes
@@ -156,8 +156,10 @@ def _check_likely_text_file(resource: URI) -> bool:
             if b"\x00" in start_chunk:
                 return False
             return True
-
-    def raise_if_not_sensible_file(self, resource: URI, entity_name:str):
+
+    def raise_if_not_sensible_file(self, resource: URI, entity_name: str):
+        """Sense check that the file is a text file. Raise error if doesn't
+        appear to be the case."""
         if not self._check_likely_text_file(resource):
             raise MessageBearingError(
                 "The submitted file doesn't appear to be text",
@@ -168,7 +170,7 @@ def raise_if_not_sensible_file(self, resource: URI, entity_name:str):
                         failure_type="submission",
                         error_location="Whole File",
                         error_code="MalformedFile",
-                        error_message=f"The submitted resource doesn't seem to be a valid text file",
+                        error_message="The resource doesn't seem to be a valid text file",
                     )
                 ],
             )
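
For reference, the text sniff that raise_if_not_sensible_file relies on boils down to the heuristic below. This is a minimal standalone sketch, assuming a local file and plain open() in place of the project's open_stream(); the exact BOM prefixes checked are an assumption.

    def looks_like_text(path: str) -> bool:
        # Read a small prefix of the file rather than the whole thing.
        with open(path, "rb") as fle:
            start_chunk = fle.read(4096)
        # A byte-order mark signals UTF-16 text, which can legitimately
        # contain NULL bytes, so accept it up front.
        if start_chunk.startswith((b"\xff\xfe", b"\xfe\xff")):
            return True
        # Otherwise a NULL byte strongly suggests a binary file.
        return b"\x00" not in start_chunk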

src/dve/core_engine/backends/implementations/duckdb/readers/csv.py

Lines changed: 20 additions & 9 deletions
@@ -16,7 +16,7 @@
     get_duckdb_type_from_annotation,
 )
 from dve.core_engine.backends.implementations.duckdb.types import SQLType
-from dve.core_engine.backends.implementations.duckdb.utilities import check_csv_header_expected
+from dve.core_engine.backends.readers.utilities import check_csv_header_expected
 from dve.core_engine.backends.utilities import get_polars_type_from_annotation
 from dve.core_engine.message import FeedbackMessage
 from dve.core_engine.type_hints import URI, EntityName
@@ -25,7 +25,14 @@
 
 @duckdb_write_parquet
 class DuckDBCSVReader(BaseFileReader):
-    """A reader for CSV files"""
+    """A reader for CSV files including the ability to compare the passed model
+    to the file header, if it exists.
+
+    field_check: flag to compare submitted file header to the accompanying pydantic model
+    field_check_error_code: The error code to provide if the file header doesn't contain
+        the expected fields
+    field_check_error_message: The error message to provide if the file header doesn't contain
+        the expected fields"""
 
     # TODO - the read_to_relation should include the schema and determine whether to
     # TODO - stringify or not
@@ -54,14 +61,11 @@ def __init__(
     def perform_field_check(
         self, resource: URI, entity_name: str, expected_schema: type[BaseModel]
     ):
+        """Check that the header of the CSV aligns with the provided model"""
         if not self.header:
             raise ValueError("Cannot perform field check without a CSV header")
 
-        if missing := check_csv_header_expected(
-            resource,
-            expected_schema,
-            self.delim
-        ):
+        if missing := check_csv_header_expected(resource, expected_schema, self.delim):
             raise MessageBearingError(
                 "The CSV header doesn't match what is expected",
                 messages=[
@@ -71,7 +75,7 @@ def perform_field_check(
                         failure_type="submission",
                         error_location="Whole File",
                         error_code=self.field_check_error_code,
-                        error_message=f"{self.field_check_error_message} - missing fields: {missing}",
+                        error_message=f"{self.field_check_error_message} - missing fields: {missing}", # pylint: disable=line-too-long
                     )
                 ],
             )
@@ -171,9 +175,14 @@ class DuckDBCSVRepeatingHeaderReader(PolarsToDuckDBCSVReader):
     """
 
     def __init__(
-        self, non_unique_header_error_code: Optional[str] = "NonUniqueHeader", *args, **kwargs
+        self,
+        *args,
+        non_unique_header_error_code: Optional[str] = "NonUniqueHeader",
+        non_unique_header_error_message: Optional[str] = None,
+        **kwargs,
     ):
         self._non_unique_header_code = non_unique_header_error_code
+        self._non_unique_header_message = non_unique_header_error_message
         super().__init__(*args, **kwargs)
 
     @read_function(DuckDBPyRelation)
@@ -200,6 +209,8 @@ def read_to_relation( # pylint: disable=unused-argument
                     failure_type="submission",
                     error_message=(
                         f"Found {no_records} distinct combination of header values."
+                        if not self._non_unique_header_message
+                        else self._non_unique_header_message
                     ),
                     error_location=entity_name,
                     category="Bad file",
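
Because the two header-error parameters now come after *args, they are keyword-only, and positional arguments flow through *args to the parent reader instead of being captured as the error code. A hypothetical construction (any other constructor arguments omitted) would look like:

    from dve.core_engine.backends.implementations.duckdb.readers.csv import (
        DuckDBCSVRepeatingHeaderReader,
    )

    # Both header-error parameters must now be passed by keyword.
    reader = DuckDBCSVRepeatingHeaderReader(
        non_unique_header_error_code="NonUniqueHeader",
        non_unique_header_error_message="Header values repeat within the file",
    )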

src/dve/core_engine/backends/implementations/duckdb/utilities.py

Lines changed: 0 additions & 19 deletions
@@ -1,15 +1,8 @@
 """Utility objects for use with duckdb backend"""
 
 import itertools
-from typing import Optional
-
-from pydantic import BaseModel
 
 from dve.core_engine.backends.base.utilities import _split_multiexpr_string
-from dve.core_engine.backends.exceptions import MessageBearingError
-from dve.core_engine.message import FeedbackMessage
-from dve.core_engine.type_hints import URI
-from dve.parser.file_handling import open_stream
 
 
 def parse_multiple_expressions(expressions) -> list[str]:
@@ -46,15 +39,3 @@ def multiexpr_string_to_columns(expressions: str) -> list[str]:
     """
     expression_list = _split_multiexpr_string(expressions)
     return expr_array_to_columns(expression_list)
-
-def check_csv_header_expected(
-    resource: URI,
-    expected_schema: type[BaseModel],
-    delimiter: Optional[str] = ",",
-    quote_char: str = '"') -> set[str]:
-    """Check the header of a CSV matches the expected fields"""
-    with open_stream(resource) as fle:
-        header_fields = fle.readline().rstrip().replace(quote_char,"").split(delimiter)
-        expected_fields = expected_schema.__fields__.keys()
-        return set(expected_fields).difference(header_fields)
-

src/dve/core_engine/backends/readers/utilities.py

Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+"""General utilities for file readers"""
+
+from typing import Optional
+
+from pydantic import BaseModel
+
+from dve.core_engine.type_hints import URI
+from dve.parser.file_handling.service import open_stream
+
+
+def check_csv_header_expected(
+    resource: URI,
+    expected_schema: type[BaseModel],
+    delimiter: Optional[str] = ",",
+    quote_char: str = '"',
+) -> set[str]:
+    """Check the header of a CSV matches the expected fields"""
+    with open_stream(resource) as fle:
+        header_fields = fle.readline().rstrip().replace(quote_char, "").split(delimiter)
+        expected_fields = expected_schema.__fields__.keys()
+        return set(expected_fields).difference(header_fields)
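
Standalone, the relocated helper can be exercised much as the new test later in this commit does. A small usage sketch, assuming a local CSV path (the model and file name are illustrative):

    from pydantic import BaseModel

    from dve.core_engine.backends.readers.utilities import check_csv_header_expected


    class Submission(BaseModel):
        field1: str
        field2: int


    # Returns the model fields missing from the file's header row;
    # an empty set means the header covers the model.
    missing = check_csv_header_expected("submission.csv", Submission, delimiter=",")
    if missing:
        print(f"Header is missing fields: {missing}")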

src/dve/pipeline/foundry_ddb_pipeline.py

Lines changed: 1 addition & 0 deletions
@@ -109,6 +109,7 @@ def error_report(
             self._logger.exception(exc)
             sub_stats = None
             report_uri = None
+            submission_status = submission_status if submission_status else SubmissionStatus()
             submission_status.processing_failed = True
             dump_processing_errors(
                 fh.joinuri(self.processed_files_path, submission_info.submission_id),
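
The added line guards the exception path, where submission_status may still be None when its processing_failed flag is set. Assuming SubmissionStatus is default-constructible, the same guard reads slightly tighter as:

    submission_status = submission_status or SubmissionStatus()
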
Lines changed: 0 additions & 46 deletions
@@ -1,14 +1,8 @@
-import tempfile
-import datetime as dt
-from pathlib import Path
-from uuid import uuid4
-from pydantic import BaseModel, create_model
 import pytest
 
 from dve.core_engine.backends.implementations.duckdb.utilities import (
     expr_mapping_to_columns,
     expr_array_to_columns,
-    check_csv_header_expected,
 )
 
 
@@ -60,43 +54,3 @@ def test_expr_array_to_columns(expressions: dict[str, str], expected: list[str])
     observed = expr_array_to_columns(expressions)
     assert observed == expected
 
-
-@pytest.mark.parametrize(
-    ["header_row", "delim", "schema", "expected"],
-    [
-        (
-            "field1,field2,field3",
-            ",",
-            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
-            set(),
-        ),
-        (
-            "field2,field3,field1",
-            ",",
-            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
-            set(),
-        ),
-        (
-            "str_field|int_field|date_field|",
-            ",",
-            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
-            {"str_field","int_field","date_field"},
-        ),
-        (
-            '"str_field"|"int_field"|"date_field"',
-            "|",
-            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
-            set(),
-        ),
-
-    ],
-)
-def test_check_csv_header_expected(
-    header_row: str, delim: str, schema: type[BaseModel], expected: set[str]
-):
-    mdl = create_model("TestModel", **schema)
-    with tempfile.TemporaryDirectory() as tmpdir:
-        fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv")
-        fle.open("w+").write(header_row)
-        res = check_csv_header_expected(fle.as_posix(), mdl, delim)
-        assert res == expected

tests/test_core_engine/test_backends/test_readers/test_ddb_json.py

Lines changed: 3 additions & 3 deletions
@@ -57,7 +57,7 @@ def test_ddb_json_reader_all_str(temp_json_file):
     expected_fields = [fld for fld in mdl.__fields__]
     reader = DuckDBJSONReader()
     rel: DuckDBPyRelation = reader.read_to_entity_type(
-        DuckDBPyRelation, uri, "test", stringify_model(mdl)
+        DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
     )
     assert rel.columns == expected_fields
     assert dict(zip(rel.columns, rel.dtypes)) == {fld: "VARCHAR" for fld in expected_fields}
@@ -68,7 +68,7 @@ def test_ddb_json_reader_cast(temp_json_file):
     uri, data, mdl = temp_json_file
     expected_fields = [fld for fld in mdl.__fields__]
     reader = DuckDBJSONReader()
-    rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri, "test", mdl)
+    rel: DuckDBPyRelation = reader.read_to_entity_type(DuckDBPyRelation, uri.as_posix(), "test", mdl)
 
     assert rel.columns == expected_fields
     assert dict(zip(rel.columns, rel.dtypes)) == {
@@ -82,7 +82,7 @@ def test_ddb_csv_write_parquet(temp_json_file):
     uri, _, mdl = temp_json_file
     reader = DuckDBJSONReader()
     rel: DuckDBPyRelation = reader.read_to_entity_type(
-        DuckDBPyRelation, uri, "test", stringify_model(mdl)
+        DuckDBPyRelation, uri.as_posix(), "test", stringify_model(mdl)
     )
     target_loc: Path = uri.parent.joinpath("test_parquet.parquet").as_posix()
     reader.write_parquet(rel, target_loc)
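
The fixture's uri is evidently a pathlib.Path, and the reader now receives a plain posix string instead. The conversion in isolation (the path value is hypothetical):

    from pathlib import Path

    uri = Path("/tmp/data/test.json")
    assert uri.as_posix() == "/tmp/data/test.json"  # plain string for the reader
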
Lines changed: 55 additions & 0 deletions
@@ -0,0 +1,55 @@
+import datetime as dt
+from pathlib import Path
+import tempfile
+from uuid import uuid4
+
+import pytest
+from pydantic import BaseModel, create_model
+
+from dve.core_engine.backends.readers.utilities import check_csv_header_expected
+
+@pytest.mark.parametrize(
+    ["header_row", "delim", "schema", "expected"],
+    [
+        (
+            "field1,field2,field3",
+            ",",
+            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
+            set(),
+        ),
+        (
+            "field2,field3,field1",
+            ",",
+            {"field1": (str, ...), "field2": (int, ...), "field3": (float, 1.2)},
+            set(),
+        ),
+        (
+            "str_field|int_field|date_field|",
+            ",",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            {"str_field","int_field","date_field"},
+        ),
+        (
+            '"str_field"|"int_field"|"date_field"',
+            "|",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            set(),
+        ),
+        (
+            'str_field,int_field,date_field\n',
+            ",",
+            {"str_field": (str, ...), "int_field": (int, ...), "date_field": (dt.date, dt.date.today())},
+            set(),
+        ),
+
+    ],
+)
+def test_check_csv_header_expected(
+    header_row: str, delim: str, schema: type[BaseModel], expected: set[str]
+):
+    mdl = create_model("TestModel", **schema)
+    with tempfile.TemporaryDirectory() as tmpdir:
+        fle = Path(tmpdir).joinpath(f"test_file_{uuid4().hex}.csv")
+        fle.open("w+").write(header_row)
+        res = check_csv_header_expected(fle.as_posix(), mdl, delim)
+        assert res == expected
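
The fifth parametrize case, new relative to the deleted test, pins down the trailing-newline behaviour: readline() keeps the "\n" and the helper's rstrip() removes it before splitting. In isolation:

    line = "str_field,int_field,date_field\n"
    assert line.rstrip().split(",") == ["str_field", "int_field", "date_field"]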
