Skip to content

Commit 808bc5b

Browse files
committed
updated the test
1 parent dc081ca commit 808bc5b

1 file changed

Lines changed: 21 additions & 4 deletions

File tree

openml/_api/resources/base/versions.py

Lines changed: 21 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from collections.abc import Mapping
44
from typing import Any, cast
5+
from xml.parsers.expat import ExpatError
56

67
import xmltodict
78

@@ -71,7 +72,7 @@ def publish(self, path: str, files: Mapping[str, Any] | None) -> int:
7172
If the server returns an error during upload.
7273
"""
7374
response = self._http.post(path, files=files)
74-
parsed_response = xmltodict.parse(response.content)
75+
parsed_response = self._parse_xml_response(response.content)
7576
return self._extract_id_from_upload(parsed_response)
7677

7778
def delete(self, resource_id: int) -> bool:
@@ -106,7 +107,7 @@ def delete(self, resource_id: int) -> bool:
106107
path = f"{endpoint_name}/{resource_id}"
107108
try:
108109
response = self._http.delete(path)
109-
result = xmltodict.parse(response.content)
110+
result = self._parse_xml_response(response.content)
110111
return f"oml:{endpoint_name}_delete" in result
111112
except OpenMLServerException as e:
112113
self._handle_delete_exception(endpoint_name, e)
@@ -143,7 +144,7 @@ def tag(self, resource_id: int, tag: str) -> list[str]:
143144
data = {f"{endpoint_name}_id": resource_id, "tag": tag}
144145
response = self._http.post(path, data=data)
145146

146-
parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"})
147+
parsed_response = self._parse_xml_response(response.content, force_list={"oml:tag"})
147148
result = parsed_response[f"oml:{endpoint_name}_tag"]
148149
tags: list[str] = result.get("oml:tag", [])
149150

@@ -180,12 +181,28 @@ def untag(self, resource_id: int, tag: str) -> list[str]:
180181
data = {f"{endpoint_name}_id": resource_id, "tag": tag}
181182
response = self._http.post(path, data=data)
182183

183-
parsed_response = xmltodict.parse(response.content, force_list={"oml:tag"})
184+
parsed_response = self._parse_xml_response(response.content, force_list={"oml:tag"})
184185
result = parsed_response[f"oml:{endpoint_name}_untag"]
185186
tags: list[str] = result.get("oml:tag", [])
186187

187188
return tags
188189

190+
def _parse_xml_response(self, payload: bytes | str, **kwargs: Any) -> Mapping[str, Any]:
    """Parse an XML payload, retrying once with any leading non-XML text removed.

    The payload is first handed to ``xmltodict.parse`` unchanged. If that
    raises ``ExpatError``, the payload is searched for the first ``<?xml``
    marker, or, when that is absent, the first ``<oml:`` marker, and
    everything before the marker is discarded before parsing again.

    Parameters
    ----------
    payload : bytes | str
        Raw response body. On the fallback path, bytes are decoded as UTF-8
        and undecodable bytes are silently dropped.
    **kwargs : Any
        Forwarded unchanged to ``xmltodict.parse`` on both attempts
        (for example ``force_list``).

    Returns
    -------
    Mapping[str, Any]
        The parsed document as returned by ``xmltodict.parse``.

    Raises
    ------
    ExpatError
        Re-raised from the first attempt when neither marker is present, or
        raised by the second attempt when the trimmed text is still not
        parseable.
    """
    try:
        document = xmltodict.parse(payload, **kwargs)
    except ExpatError:
        if isinstance(payload, bytes):
            text = payload.decode("utf-8", errors="ignore")
        else:
            text = payload

        # Prefer the XML declaration; fall back to the first "<oml:" tag.
        for marker in ("<?xml", "<oml:"):
            offset = text.find(marker)
            if offset != -1:
                break
        else:
            # No recognizable start of a document: surface the original error.
            raise

        # Retry inside the handler so a second failure keeps the first
        # error as its context.
        return cast("Mapping[str, Any]", xmltodict.parse(text[offset:], **kwargs))
    else:
        return cast("Mapping[str, Any]", document)
205+
189206
def _get_endpoint_name(self) -> str:
190207
if self.resource_type == ResourceType.DATASET:
191208
return "data"

0 commit comments

Comments
 (0)