|
2 | 2 |
|
3 | 3 | from collections.abc import Mapping |
4 | 4 | from typing import Any, cast |
| 5 | +from xml.parsers.expat import ExpatError |
5 | 6 |
|
6 | 7 | import xmltodict |
7 | 8 |
|
@@ -71,7 +72,7 @@ def publish(self, path: str, files: Mapping[str, Any] | None) -> int: |
71 | 72 | If the server returns an error during upload. |
72 | 73 | """ |
73 | 74 | response = self._http.post(path, files=files) |
74 | | - parsed_response = xmltodict.parse(response.content) |
| 75 | + parsed_response = self._parse_xml_response(response.content) |
75 | 76 | return self._extract_id_from_upload(parsed_response) |
76 | 77 |
|
77 | 78 | def delete(self, resource_id: int) -> bool: |
@@ -106,7 +107,7 @@ def delete(self, resource_id: int) -> bool: |
106 | 107 | path = f"{endpoint_name}/{resource_id}" |
107 | 108 | try: |
108 | 109 | response = self._http.delete(path) |
109 | | - result = xmltodict.parse(response.content) |
| 110 | + result = self._parse_xml_response(response.content) |
110 | 111 | return f"oml:{endpoint_name}_delete" in result |
111 | 112 | except OpenMLServerException as e: |
112 | 113 | self._handle_delete_exception(endpoint_name, e) |
@@ -143,7 +144,7 @@ def tag(self, resource_id: int, tag: str) -> list[str]: |
143 | 144 | data = {f"{endpoint_name}_id": resource_id, "tag": tag} |
144 | 145 | response = self._http.post(path, data=data) |
145 | 146 |
|
146 | | - parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) |
| 147 | + parsed_response = self._parse_xml_response(response.content, force_list={"oml:tag"}) |
147 | 148 | result = parsed_response[f"oml:{endpoint_name}_tag"] |
148 | 149 | tags: list[str] = result.get("oml:tag", []) |
149 | 150 |
|
@@ -180,12 +181,28 @@ def untag(self, resource_id: int, tag: str) -> list[str]: |
180 | 181 | data = {f"{endpoint_name}_id": resource_id, "tag": tag} |
181 | 182 | response = self._http.post(path, data=data) |
182 | 183 |
|
183 | | - parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"}) |
| 184 | + parsed_response = self._parse_xml_response(response.content, force_list={"oml:tag"}) |
184 | 185 | result = parsed_response[f"oml:{endpoint_name}_untag"] |
185 | 186 | tags: list[str] = result.get("oml:tag", []) |
186 | 187 |
|
187 | 188 | return tags |
188 | 189 |
|
| 190 | + def _parse_xml_response(self, payload: bytes | str, **kwargs: Any) -> Mapping[str, Any]: |
| 191 | + try: |
| 192 | + return cast("Mapping[str, Any]", xmltodict.parse(payload, **kwargs)) |
| 193 | + except ExpatError: |
| 194 | + payload_text = ( |
| 195 | + payload.decode("utf-8", errors="ignore") if isinstance(payload, bytes) else payload |
| 196 | + ) |
| 197 | + xml_start = payload_text.find("<?xml") |
| 198 | + if xml_start == -1: |
| 199 | + xml_start = payload_text.find("<oml:") |
| 200 | + if xml_start == -1: |
| 201 | + raise |
| 202 | + |
| 203 | + xml_text = payload_text[xml_start:] |
| 204 | + return cast("Mapping[str, Any]", xmltodict.parse(xml_text, **kwargs)) |
| 205 | + |
189 | 206 | def _get_endpoint_name(self) -> str: |
190 | 207 | if self.resource_type == ResourceType.DATASET: |
191 | 208 | return "data" |
|
0 commit comments