-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathxml.py
More file actions
401 lines (331 loc) · 15.1 KB
/
xml.py
File metadata and controls
401 lines (331 loc) · 15.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
# mypy: disable-error-code="attr-defined"
"""XML parsers for the Data Validation Engine."""
import re
from collections.abc import Collection, Iterator
from typing import IO, Any, GenericAlias, Optional, Union, overload # type: ignore
import polars as pl
from lxml import etree # type: ignore
from pydantic import BaseModel, create_model
from typing_extensions import Annotated, Protocol, get_args, get_origin
from dve.core_engine.backends.base.reader import BaseFileReader
from dve.core_engine.backends.exceptions import EmptyFileError
from dve.core_engine.backends.utilities import get_polars_type_from_annotation, stringify_model
from dve.core_engine.loggers import get_logger
from dve.core_engine.type_hints import URI, EntityName
from dve.parser.file_handling import NonClosingTextIOWrapper, get_content_length, open_stream
from dve.parser.file_handling.implementations.file import (
LocalFilesystemImplementation,
file_uri_to_local_path,
)
from dve.parser.file_handling.service import _get_implementation
XMLType = Union[str, None, list["XMLType"], dict[str, "XMLType"]]  # type: ignore
"""A value parsed from XML: text (or null), a list of values, or a mapping of tag to value."""
XMLRecord = dict[str, XMLType]  # type: ignore
"""A single parsed record, mapping each child tag to its parsed value."""
TemplateElement = Union[None, list["TemplateElement"], dict[str, "TemplateElement"]]  # type: ignore
"""One node of a template row: `None` (leaf), a one-item list (repeated), or a nested template."""
TemplateRow = dict[str, "TemplateElement"]  # type: ignore
"""A template row, mapping each field name to its template element."""
def _strip_annotated(annotation: Any) -> Any:
"""Strip the 'Annotated' wrapper from a type."""
origin = get_origin(annotation)
if origin is None or origin is not Annotated:
return annotation
return get_args(annotation)[0]
def create_template_row(schema: type[BaseModel]) -> dict[str, Any]:
    """Build a template row describing the shape of records for `schema`.

    The template mirrors what the reader populates for each record:
    - simple (non-complex) fields map to `None`;
    - nested `BaseModel` fields map to their own, recursively built, template row;
    - `list[...]` fields map to a one-item list holding the item's template
      (`None` for simple items), hinting to the reader that the element repeats.

    Raises:
    - `TypeError`: if a complex field is neither a `BaseModel` nor a list of
      simple values / `BaseModel`s.
    """
    template_row: dict[str, Any] = {}
    for field_name, model_field_def in schema.__fields__.items():
        field_type = _strip_annotated(model_field_def.annotation)
        if not model_field_def.is_complex():
            template_row[field_name] = None
        elif isinstance(field_type, type) and not isinstance(field_type, GenericAlias):
            # On older Python versions `isinstance(list[int], type)` is True, hence the
            # explicit GenericAlias exclusion before treating this as a plain class.
            if not issubclass(field_type, BaseModel):
                raise TypeError(
                    f"Cannot read arbitrary complex type from XML, got {field_type!r}"
                )
            template_row[field_name] = create_template_row(field_type)
        elif get_origin(field_type) is list:
            item_type = _strip_annotated(get_args(field_type)[0])
            # Let pydantic decide whether the item type is complex by wrapping it in a
            # throwaway single-field model.
            item_field = create_model("", lt=(item_type, ...)).__fields__["lt"]
            if not item_field.is_complex():
                template_row[field_name] = [None]
            elif isinstance(item_type, type) and issubclass(item_type, BaseModel):
                template_row[field_name] = [create_template_row(item_type)]
            else:
                raise TypeError(
                    f"Cannot read arbitrary complex type from XML, got {field_type!r}"
                )
        else:
            raise TypeError(f"Cannot read arbitrary complex type from XML, got {field_type!r}")
    return template_row
class XMLElement(Protocol):
    """Structural interface for one element of a parsed XML document.

    Declared as a protocol so that `lxml` elements and standard-library
    `ElementTree` elements, which share the same structure, can be used
    interchangeably without depending on `lxml`'s incomplete type hints.
    """

    # The element's tag name.
    tag: Optional[str]
    # The text directly inside the element's tags, if any.
    text: Optional[str]

    def __iter__(self) -> Iterator["XMLElement"]:
        """Iterate over the element's direct children."""
        ...

    def clear(self) -> None:
        """Clear the element, removing children/attrs/etc."""
class BasicXMLFileReader(BaseFileReader):
    """A reader for XML files built atop LXML.

    The whole document is parsed into an in-memory tree (`etree.parse`) before
    any records are extracted from it.
    """

    def __init__(
        self,
        *,
        record_tag: str,
        root_tag: Optional[str] = None,
        trim_cells: bool = True,
        null_values: Collection[str] = frozenset({"NULL", "null", ""}),
        sanitise_multiline: bool = True,
        encoding: str = "utf-8-sig",
        n_records_to_read: Optional[int] = None,
        **_,
    ):
        """Init function for the base XML reader.

        Args:
        - `record_tag`: a required string indicating the tag of each 'record'
          in the XML document.
        - `root_tag`: a string indicating the tag to find the records in within
          the XML document. If `None`, records are looked for anywhere in the
          document.
        - `trim_cells`: whether to strip leading/trailing whitespace from element
          text. Default: `True`
        - `null_values`: a container of values to replace with null if encountered
          in an element (compared after trimming/sanitising).
          Default: `{'', 'null', 'NULL'}`
        - `sanitise_multiline`: whether to collapse each newline, together with
          the whitespace around it, into a single space. Default: `True`
        - `encoding`: encoding of the XML file. Default: `utf-8-sig`
        - `n_records_to_read`: the maximum number of records to read from a
          document. `None` reads every record.
        - Any other keyword arguments are accepted and ignored.
        """
        self.record_tag = record_tag
        """The name of a 'record' tag in the XML document."""
        self.root_tag = root_tag
        """The name of the 'root' tag in the XML document."""
        self.trim_cells = trim_cells
        """A boolean value indicating whether to strip whitespace from fields."""
        self.null_values = null_values
        """A container of values to replace with null if encountered in an element."""
        self.sanitise_multiline = sanitise_multiline
        """A boolean value indicating whether to sanitise multiline fields."""
        self.encoding = encoding
        """Encoding of the XML file."""
        self.n_records_to_read = n_records_to_read
        """The maximum number of records to read from a document."""
        super().__init__()
        self._logger = get_logger(__name__)

    def _strip_namespace(self, element: XMLElement) -> None:
        """Mutate an element in place, removing any namespace from its tag.

        Both Clark notation (`{uri}tag`) and prefixed (`prefix:tag`) tags are
        handled. Nodes whose tag is not a string (e.g. lxml entity-reference
        nodes, which can appear because entities are not resolved and whose
        `tag` is a factory function) are left untouched instead of making
        `re.sub` raise `TypeError`.
        """
        if isinstance(element.tag, str):
            element.tag = re.sub(r"^(\{.+\}|.+:)", "", element.tag)

    def _strip_namespaces_recursively(self, element: XMLElement) -> None:
        """Mutate an element and all of its descendants, stripping their namespaces."""
        self._strip_namespace(element)
        for child in element:
            self._strip_namespaces_recursively(child)

    def _sanitise_field(self, value: Optional[str]) -> Optional[str]:
        """Sanitise a field value from an XML document.

        Applies, in order: whitespace trimming (if `trim_cells`), newline
        collapsing (if `sanitise_multiline`), then maps any value found in
        `null_values` to `None`. `None` input is returned unchanged.
        """
        if value is None:
            return None
        if self.trim_cells:
            value = value.strip()
        if self.sanitise_multiline:
            # Replace each newline and the whitespace surrounding it with one space.
            value = re.sub(r"\s*\n\s*", " ", value)
        if value in self.null_values:
            return None
        return value

    @overload
    def _parse_element(self, element: XMLElement, template: TemplateRow) -> XMLRecord: ...

    @overload
    def _parse_element(self, element: XMLElement, template: TemplateElement) -> XMLType: ...

    def _parse_element(self, element: XMLElement, template: Union[TemplateElement, TemplateRow]):
        """Parse an XML element according to a template.

        - A `None` template yields the element's sanitised text.
        - A list template yields a one-item list: the element parsed with the
          list's item template.
        - A dict template yields a record with one entry per template key.
          Children whose tag is not in the template are ignored. Keys with a
          list template collect every matching child, and other keys keep the
          last matching child. Template keys with no matching child are set
          to `None`.
        """
        if template is None:
            return self._sanitise_field(element.text)
        if isinstance(template, list):
            return [self._parse_element(element, template[0])]
        record: XMLRecord = {}
        for child in element:
            tag = child.tag
            if tag is None or tag not in template:
                continue
            template_element = template[tag]
            if isinstance(template_element, list):
                # Only this branch ever writes list-typed tags, so the stored value
                # is always the list started on the first occurrence.
                occurrences = record.setdefault(tag, [])
                occurrences.append(  # type: ignore[union-attr]
                    self._parse_element(child, template_element[0])
                )
            else:
                record[tag] = self._parse_element(child, template_element)
        for missing_key in template.keys() - record.keys():
            record[missing_key] = None
        return record

    def _get_elements_from_stream(self, stream: IO[bytes]) -> Iterator[XMLElement]:
        """Parse the whole stream and yield its record elements.

        Records are matched on local name, so namespaces are ignored. When
        `root_tag` is set, only records that are direct children of a
        `root_tag` element are returned. Namespaces are stripped from each
        yielded element and its descendants. When `n_records_to_read` is truthy,
        at most that many elements are yielded (so `0` means no limit).
        """
        # NOTE(review): presumably libxml2 does not recognise Python's "utf-8-sig"
        # codec name, so plain UTF-8 is passed instead.
        encoding = self.encoding if self.encoding != "utf-8-sig" else "utf-8"
        parser = etree.XMLParser(
            encoding=encoding,
            remove_pis=True,
            remove_comments=True,
            dtd_validation=False,
            resolve_entities=False,
        )
        tree: etree._ElementTree = etree.parse(stream, parser)
        root: etree._Element = tree.getroot()
        elements: list[XMLElement]
        # Tag names are bound as XPath variables rather than interpolated, so a
        # tag containing a quote cannot break (or alter) the expression.
        if self.root_tag:
            elements = root.xpath(
                "//*[local-name()=$root_tag]/*[local-name()=$record_tag]",
                root_tag=self.root_tag,
                record_tag=self.record_tag,
            )
        else:
            elements = root.xpath(
                "//*[local-name()=$record_tag]", record_tag=self.record_tag
            )
        element_count = 0
        for element in elements:
            self._strip_namespaces_recursively(element)
            yield element
            element_count += 1
            if self.n_records_to_read and element_count == self.n_records_to_read:
                break

    def _parse_xml(
        self, stream: IO[bytes], schema: type[BaseModel]
    ) -> Iterator[dict[str, XMLType]]:
        """Yield each record in the stream parsed into a dict shaped like the
        template row derived from `schema`. Values are sanitised, and configured
        null values become `None`.
        """
        elements = self._get_elements_from_stream(stream)
        template_row = create_template_row(schema)
        for element in elements:
            yield self._parse_element(element, template_row)

    def read_to_py_iterator(
        self,
        resource: URI,
        entity_name: EntityName,
        schema: type[BaseModel],
    ) -> Iterator[dict[str, Any]]:
        """Iterate through the contents of the file at URI, yielding rows
        containing the data.

        The shape of each row comes from `schema`: nested `BaseModel` fields
        become nested dicts, and `list[...]` fields collect every repeated child
        element with that tag.

        Raises:
        - `EmptyFileError`: if the resource has zero content length.
        """
        if get_content_length(resource) == 0:
            raise EmptyFileError(f"File at {resource!r} is empty")
        with open_stream(resource, "rb") as stream:
            yield from self._parse_xml(stream, schema)

    def write_parquet(  # type: ignore
        self,
        entity: Iterator[dict[str, Any]],
        target_location: URI,
        schema: Optional[type[BaseModel]] = None,
        **kwargs,
    ) -> URI:
        """Writes the data of the given entity out to a parquet file.

        When `schema` is given, column types are derived from the stringified
        model. Compression defaults to `snappy` but may be overridden via a
        `compression` keyword. All other `kwargs` are forwarded to
        `LazyFrame.sink_parquet`. Returns the location written to, which is
        converted to a local POSIX path for local filesystem URIs.
        """
        # polars misinterprets local file schemes and creates a file: folder.
        # parse it as Path and uri to resolve
        if isinstance(_get_implementation(target_location), LocalFilesystemImplementation):
            target_location = file_uri_to_local_path(target_location).as_posix()
        polars_schema: Optional[dict[str, pl.DataType]] = None
        if schema:
            polars_schema = {  # type: ignore
                fld.name: get_polars_type_from_annotation(fld.type_)
                for fld in stringify_model(schema).__fields__.values()
            }
        # setdefault (rather than a fixed keyword) lets callers pick a codec without
        # triggering a duplicate-keyword TypeError.
        kwargs.setdefault("compression", "snappy")
        pl.LazyFrame(data=entity, schema=polars_schema).sink_parquet(
            path=target_location, **kwargs
        )
        return target_location
class XMLStreamReader(BasicXMLFileReader):
    """An XML parser which 'streams' the file, parsing only specific records.

    This means it requires more configuration, but should be much kinder on memory.
    The byte stream is decoded with `self.encoding` and fed to an incremental
    (pull) parser in chunks. Each record element is yielded as soon as its end
    tag has been parsed. Elements that are no longer needed are then cleared and
    detached from the partially built tree to keep memory use down.
    """

    # Number of characters read from the decoded text stream per parser feed.
    _READ_CHUNK_SIZE = 500_000

    def _iter_parse_events(
        self, parser: Any, stream: IO[bytes]
    ) -> Iterator[tuple[str, XMLElement]]:
        """Decode `stream` in chunks, feed each chunk to `parser` and yield the
        resulting `(action, element)` events in document order.
        """
        # NOTE(review): per its name, NonClosingTextIOWrapper presumably leaves the
        # caller's byte stream open when this `with` block exits -- confirm.
        with NonClosingTextIOWrapper(stream, encoding=self.encoding) as text_stream:
            self._logger.debug(
                f"Starting text stream, state is {'closed' if text_stream.closed else 'open'}"
            )
            while not text_stream.closed:
                text_block = text_stream.read(self._READ_CHUNK_SIZE)
                if not text_block:
                    break
                parser.feed(text_block)
                yield from parser.read_events()

    @staticmethod
    def _release_element(element: XMLElement) -> None:
        """Free a fully parsed element and its already-processed preceding siblings.

        `clear()` empties the element. Deleting the preceding siblings stops the
        partially built tree from keeping one (empty) node per record. Only
        call this for elements whose end tag has already been parsed.
        """
        element.clear()
        parent = element.getparent()
        if parent is None:
            return
        while element.getprevious() is not None:
            del parent[0]

    def _get_elements_from_stream(self, stream: IO[bytes]) -> Iterator[XMLElement]:
        """Incrementally parse `stream`, yielding each record element once its
        end tag has been parsed.

        When `root_tag` is set, only records nested (at any depth) inside a
        `root_tag` element are yielded. Otherwise records anywhere in the
        document are yielded. Namespaces are stripped from tags as events
        arrive, and from each yielded record's descendants.

        A yielded element is only valid until the next item is requested. After
        that it is cleared, unless it is nested inside another record. When
        `n_records_to_read` is not `None`, parsing stops once that many records
        have been yielded. Errors raised while closing the parser (e.g. for
        an incomplete document) are logged at debug level and ignored.
        """
        parser = etree.XMLPullParser(
            events=(
                "start",
                "end",
            ),
            remove_pis=True,
            remove_comments=True,
            dtd_validation=False,
            resolve_entities=False,
        )
        # Without a root tag the whole document counts as the container.
        container_contexts = 1 if not self.root_tag else 0
        record_contexts = 0
        emitted_element_count = 0
        self._logger.debug(
            f"Starting to parse XML stream, state is {'closed' if stream.closed else 'open'}"
        )
        events = self._iter_parse_events(parser, stream)
        try:
            for action, element in events:
                self._strip_namespace(element)
                is_container = bool(self.root_tag) and element.tag == self.root_tag
                is_record = element.tag == self.record_tag
                if action == "start":
                    if is_container:
                        container_contexts += 1
                    # Only count records inside the container: end events outside it
                    # are skipped below, so counting those records would leave the
                    # counter permanently non-zero and stop elements being released.
                    if is_record and container_contexts:
                        record_contexts += 1
                    continue
                if action != "end" or not container_contexts:
                    continue
                if is_container:
                    container_contexts -= 1
                if is_record:
                    self._strip_namespaces_recursively(element)
                    yield element
                    emitted_element_count += 1
                    record_contexts -= 1
                # Outside any record nothing will need this element again.
                if not record_contexts:
                    self._release_element(element)
                if (
                    self.n_records_to_read is not None
                    and emitted_element_count == self.n_records_to_read
                ):
                    break
        finally:
            # Deterministically exit the text wrapper's `with` block.
            events.close()
        try:
            parser.close()
        # Deliberately best-effort: we don't care if the XML is incomplete (e.g. when
        # stopping early after `n_records_to_read` records).
        except Exception:  # pylint: disable=broad-except
            self._logger.debug("Ignoring error raised while closing the XML parser")

    def write_parquet(  # type: ignore
        self,
        entity: Iterator[dict[str, Any]],
        target_location: URI,
        schema: Optional[type[BaseModel]] = None,
        **kwargs,
    ) -> URI:
        """Writes the given entity data out to a parquet file by delegating to
        the base reader's implementation.
        """
        return super().write_parquet(
            entity=entity, target_location=target_location, schema=schema, **kwargs
        )