Skip to content

Commit f1fd8aa

Browse files
committed
style: run black, isort and resolve linting issues
1 parent efe0fa0 commit f1fd8aa

16 files changed

Lines changed: 195 additions & 135 deletions

File tree

src/dve/common/error_utils.py

Lines changed: 47 additions & 34 deletions
Original file line number | Diff line number | Diff line change
@@ -1,27 +1,29 @@
11
"""Utilities to support reporting"""
22

33
import datetime as dt
4-
from itertools import chain
54
import json
65
import logging
6+
from collections.abc import Iterable
7+
from itertools import chain
78
from multiprocessing import Queue
89
from threading import Thread
9-
from typing import Iterable, Iterator, Optional, Union
10+
from typing import Optional, Union
1011

11-
from dve.core_engine.message import UserMessage
12-
from dve.core_engine.loggers import get_logger
1312
import dve.parser.file_handling as fh
1413
from dve.core_engine.exceptions import CriticalProcessingError
14+
from dve.core_engine.loggers import get_logger
15+
from dve.core_engine.message import UserMessage
1516
from dve.core_engine.type_hints import URI, DVEStage, Messages
1617

1718

1819
def get_feedback_errors_uri(working_folder: URI, step_name: DVEStage) -> URI:
1920
"""Determine the location of json lines file containing all errors generated in a step"""
2021
return fh.joinuri(working_folder, "errors", f"{step_name}_errors.jsonl")
2122

23+
2224
def get_processing_errors_uri(working_folder: URI) -> URI:
2325
"""Determine the location of json lines file containing all processing
24-
errors generated from DVE run"""
26+
errors generated from DVE run"""
2527
return fh.joinuri(working_folder, "errors", "processing_errors.jsonl")
2628

2729

@@ -88,74 +90,85 @@ def dump_processing_errors(
8890
)
8991

9092
with fh.open_stream(error_file, "a") as f:
91-
f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n")
92-
93+
f.write("\n".join([json.dumps(rec, default=str) for rec in processed]) + "\n")
94+
9395
return error_file
9496

97+
9598
def load_feedback_messages(feedback_messages_uri: URI) -> Iterable[UserMessage]:
99+
"""Load user messages from jsonl file"""
96100
if not fh.get_resource_exists(feedback_messages_uri):
97101
return
98102
with fh.open_stream(feedback_messages_uri) as errs:
99103
yield from (UserMessage(**json.loads(err)) for err in errs.readlines())
100104

105+
101106
def load_all_error_messages(error_directory_uri: URI) -> Iterable[UserMessage]:
102-
return chain.from_iterable([load_feedback_messages(err_file) for err_file, _ in fh.iter_prefix(error_directory_uri) if err_file.endswith(".jsonl")])
107+
"Load user messages from all jsonl files"
108+
return chain.from_iterable(
109+
[
110+
load_feedback_messages(err_file)
111+
for err_file, _ in fh.iter_prefix(error_directory_uri)
112+
if err_file.endswith(".jsonl")
113+
]
114+
)
115+
103116

104117
class BackgroundMessageWriter:
105-
def __init__(self,
106-
working_directory: URI,
107-
dve_stage: DVEStage,
108-
key_fields: Optional[dict[str, list[str]]] = None,
109-
logger: Optional[logging.Logger] = None):
118+
"""Controls batch writes to error jsonl files"""
119+
120+
def __init__(
121+
self,
122+
working_directory: URI,
123+
dve_stage: DVEStage,
124+
key_fields: Optional[dict[str, list[str]]] = None,
125+
logger: Optional[logging.Logger] = None,
126+
):
110127
self._working_directory = working_directory
111128
self._dve_stage = dve_stage
112-
self._feedback_message_uri = get_feedback_errors_uri(self._working_directory, self._dve_stage)
129+
self._feedback_message_uri = get_feedback_errors_uri(
130+
self._working_directory, self._dve_stage
131+
)
113132
self._key_fields = key_fields
114133
self.logger = logger or get_logger(type(self).__name__)
115134
self._write_thread = None
116135
self._queue = Queue()
117-
136+
118137
@property
119-
def write_queue(self):
138+
def write_queue(self) -> Queue: # type: ignore
139+
"""Queue for storing batches of messages to be written"""
120140
return self._queue
121-
141+
122142
@property
123-
def write_thread(self):
143+
def write_thread(self) -> Thread: # type: ignore
144+
"""Thread to write batches of messages to jsonl file"""
124145
if not self._write_thread:
125146
self._write_thread = Thread(target=self._write_process_wrapper)
126147
return self._write_thread
127-
128-
148+
129149
def _write_process_wrapper(self):
130150
"""Wrapper for dump feedback errors to run in background process"""
131151
while True:
132152
if msgs := self.write_queue.get():
133-
dump_feedback_errors(self._working_directory, self._dve_stage, msgs, self._key_fields)
153+
dump_feedback_errors(
154+
self._working_directory, self._dve_stage, msgs, self._key_fields
155+
)
134156
else:
135157
break
136-
158+
137159
def __enter__(self) -> "BackgroundMessageWriter":
138160
self.write_thread.start()
139161
return self
140-
162+
141163
def __exit__(self, exc_type, exc_value, traceback):
142164
if exc_type:
143165
self.logger.exception(
144166
"Issue occured during background write process:",
145-
exc_info=(exc_type, exc_value, traceback)
167+
exc_info=(exc_type, exc_value, traceback),
146168
)
147169
self.write_queue.put(None)
148170
self.write_thread.join()
149-
150-
151-
152-
def write_process_wrapper(working_directory: URI, *, queue: Queue, key_fields: Optional[dict[str, list[str]]] = None):
153-
"""Wrapper for dump feedback errors to run in background process"""
154-
while True:
155-
if msgs := queue.get():
156-
dump_feedback_errors(fh.joinuri(working_directory, "data_contract"), msgs, key_fields)
157-
else:
158-
break
171+
159172

160173
def conditional_cast(value, primary_keys: list[str], value_separator: str) -> Union[list[str], str]:
161174
"""Determines what to do with a value coming back from the error list"""

src/dve/core_engine/backends/base/backend.py

Lines changed: 12 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -17,13 +17,7 @@
1717
from dve.core_engine.backends.types import Entities, EntityType, StageSuccessful
1818
from dve.core_engine.loggers import get_logger
1919
from dve.core_engine.models import SubmissionInfo
20-
from dve.core_engine.type_hints import (
21-
URI,
22-
EntityLocations,
23-
EntityName,
24-
EntityParquetLocations,
25-
Messages,
26-
)
20+
from dve.core_engine.type_hints import URI, EntityLocations, EntityName, EntityParquetLocations
2721
from dve.parser.file_handling.service import get_parent, joinuri
2822

2923

@@ -162,7 +156,9 @@ def apply(
162156
reference_data = self.load_reference_data(
163157
rule_metadata.reference_data_config, submission_info
164158
)
165-
entities, dc_feedback_errors_uri, successful, processing_errors_uri = self.contract.apply(working_dir, entity_locations, contract_metadata)
159+
entities, dc_feedback_errors_uri, successful, processing_errors_uri = self.contract.apply(
160+
working_dir, entity_locations, contract_metadata
161+
)
166162
if not successful:
167163
return entities, dc_feedback_errors_uri, successful, processing_errors_uri
168164

@@ -172,7 +168,8 @@ def apply(
172168
# TODO: Handle entity manager creation errors.
173169
entity_manager = EntityManager(entities, reference_data)
174170
# TODO: Add stage success to 'apply_rules'
175-
# TODO: In case of large errors in business rules, write messages to jsonl file and return uri to errors
171+
# TODO: In case of large errors in business rules, write messages to jsonl file
172+
# TODO: and return uri to errors
176173
_ = self.step_implementations.apply_rules(entity_manager, rule_metadata)
177174

178175
for entity_name, entity in entity_manager.entities.items():
@@ -196,7 +193,9 @@ def process(
196193
working_dir, entity_locations, contract_metadata, rule_metadata, submission_info
197194
)
198195
if successful:
199-
parquet_locations = self.write_entities_to_parquet(entities, joinuri(working_dir, "outputs"))
196+
parquet_locations = self.write_entities_to_parquet(
197+
entities, joinuri(working_dir, "outputs")
198+
)
200199
else:
201200
parquet_locations = {}
202201
return parquet_locations, feedback_errors_uri, processing_errors_uri
@@ -234,6 +233,8 @@ def process_legacy(
234233
return entities, errors_uri # type: ignore
235234

236235
return (
237-
self.convert_entities_to_spark(entities, joinuri(working_dir, "outputs"), _emit_deprecation_warning=False),
236+
self.convert_entities_to_spark(
237+
entities, joinuri(working_dir, "outputs"), _emit_deprecation_warning=False
238+
),
238239
errors_uri,
239240
)

src/dve/core_engine/backends/base/contract.py

Lines changed: 17 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -9,6 +9,11 @@
99
from pydantic import BaseModel
1010
from typing_extensions import Protocol
1111

12+
from dve.common.error_utils import (
13+
dump_processing_errors,
14+
get_feedback_errors_uri,
15+
get_processing_errors_uri,
16+
)
1217
from dve.core_engine.backends.base.core import get_entity_type
1318
from dve.core_engine.backends.base.reader import BaseFileReader
1419
from dve.core_engine.backends.exceptions import ReaderLacksEntityTypeSupport, render_error
@@ -27,10 +32,8 @@
2732
Messages,
2833
WrapDecorator,
2934
)
30-
from dve.parser.file_handling import get_file_suffix, get_resource_exists, get_parent
31-
from dve.parser.file_handling.service import joinuri
35+
from dve.parser.file_handling import get_file_suffix, get_resource_exists
3236
from dve.parser.type_hints import Extension
33-
from dve.common.error_utils import dump_processing_errors, get_feedback_errors_uri, get_processing_errors_uri
3437

3538
T = TypeVar("T")
3639
ExtensionConfig = dict[Extension, "ReaderConfig"]
@@ -362,7 +365,12 @@ def read_raw_entities(
362365

363366
@abstractmethod
364367
def apply_data_contract(
365-
self, working_dir: URI, entities: Entities, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, key_fields: Optional[dict[str, list[str]]] = None
368+
self,
369+
working_dir: URI,
370+
entities: Entities,
371+
entity_locations: EntityLocations,
372+
contract_metadata: DataContractMetadata,
373+
key_fields: Optional[dict[str, list[str]]] = None,
366374
) -> tuple[Entities, URI, StageSuccessful]:
367375
"""Apply the data contract to the raw entities, returning the validated entities
368376
and any messages.
@@ -373,7 +381,11 @@ def apply_data_contract(
373381
raise NotImplementedError()
374382

375383
def apply(
376-
self, working_dir: URI, entity_locations: EntityLocations, contract_metadata: DataContractMetadata, key_fields: Optional[dict[str, list[str]]] = None
384+
self,
385+
working_dir: URI,
386+
entity_locations: EntityLocations,
387+
contract_metadata: DataContractMetadata,
388+
key_fields: Optional[dict[str, list[str]]] = None,
377389
) -> tuple[Entities, URI, StageSuccessful, URI]:
378390
"""Read the entities from the provided locations according to the data contract,
379391
and return the validated entities and any messages.

src/dve/core_engine/backends/base/utilities.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -136,10 +136,11 @@ def _get_non_heterogenous_type(types: Sequence[type]) -> type:
136136
)
137137
return type_list[0]
138138

139+
139140
def check_if_parquet_file(file_location: URI) -> bool:
140141
"""Check if a file path is valid parquet"""
141142
try:
142143
pq.ParquetFile(file_location)
143144
return True
144145
except (pyarrow.ArrowInvalid, pyarrow.ArrowIOError):
145-
return False
146+
return False

0 commit comments

Comments (0)