|
14 | 14 | from collections import defaultdict |
15 | 15 | from glob import glob |
16 | 16 | from pathlib import Path |
17 | | -from typing import Any, Mapping, Sequence |
| 17 | +from typing import Any, Mapping, Sequence, TypeVar, overload |
18 | 18 |
|
19 | 19 | import packaging.version |
20 | 20 | import typing_inspect |
@@ -183,7 +183,7 @@ def ref_type(self, obj): |
183 | 183 | return self.get_ref_type(obj["$ref"]) |
184 | 184 |
|
185 | 185 |
|
186 | | -CLASSES = {} |
| 186 | +CLASSES: dict[str, type[Type]] = {} |
187 | 187 | factories = codegen.Capture() |
188 | 188 |
|
189 | 189 |
|
@@ -479,37 +479,48 @@ def ReturnMapping(cls): # noqa: N802 |
479 | 479 | def decorator(f): |
480 | 480 | @functools.wraps(f) |
481 | 481 | async def wrapper(*args, **kwargs): |
482 | | - nonlocal cls |
483 | 482 | reply = await f(*args, **kwargs) |
484 | | - if cls is None: |
485 | | - return reply |
486 | | - if "error" in reply: |
487 | | - cls = CLASSES["Error"] |
488 | | - if typing_inspect.is_generic_type(cls) and issubclass( |
489 | | - typing_inspect.get_origin(cls), Sequence |
490 | | - ): |
491 | | - parameters = typing_inspect.get_parameters(cls) |
492 | | - result = [] |
493 | | - item_cls = parameters[0] |
494 | | - for item in reply: |
495 | | - result.append(item_cls.from_json(item)) |
496 | | - """ |
497 | | - if 'error' in item: |
498 | | - cls = CLASSES['Error'] |
499 | | - else: |
500 | | - cls = item_cls |
501 | | - result.append(cls.from_json(item)) |
502 | | - """ |
503 | | - else: |
504 | | - result = cls.from_json(reply["response"]) |
505 | | - |
506 | | - return result |
| 483 | + return _convert_response(reply, cls=cls) |
507 | 484 |
|
508 | 485 | return wrapper |
509 | 486 |
|
510 | 487 | return decorator |
511 | 488 |
|
512 | 489 |
|
| 490 | +@overload |
| 491 | +def _convert_response(response: dict[str, Any], *, cls: type[SomeType]) -> SomeType: ... |
| 492 | + |
| 493 | + |
| 494 | +@overload |
| 495 | +def _convert_response(response: dict[str, Any], *, cls: None) -> dict[str, Any]: ... |
| 496 | + |
| 497 | + |
| 498 | +def _convert_response(response: dict[str, Any], *, cls: type[Type] | None) -> Any: |
| 499 | + if cls is None: |
| 500 | + return response |
| 501 | + if "error" in response: |
| 502 | + cls = CLASSES["Error"] |
| 503 | + if typing_inspect.is_generic_type(cls) and issubclass( |
| 504 | + typing_inspect.get_origin(cls), Sequence |
| 505 | + ): |
| 506 | + parameters = typing_inspect.get_parameters(cls) |
| 507 | + result = [] |
| 508 | + item_cls = parameters[0] |
| 509 | + for item in response: |
| 510 | + result.append(item_cls.from_json(item)) |
| 511 | + """ |
| 512 | + if 'error' in item: |
| 513 | + cls = CLASSES['Error'] |
| 514 | + else: |
| 515 | + cls = item_cls |
| 516 | + result.append(cls.from_json(item)) |
| 517 | + """ |
| 518 | + else: |
| 519 | + result = cls.from_json(response["response"]) |
| 520 | + |
| 521 | + return result |
| 522 | + |
| 523 | + |
513 | 524 | def make_func(cls, name, description, params, result, _async=True): |
514 | 525 | indent = " " |
515 | 526 | args = Args(cls.schema, params) |
@@ -663,7 +674,7 @@ async def rpc(self, msg: dict[str, _RichJson]) -> _Json: |
663 | 674 | return result |
664 | 675 |
|
665 | 676 | @classmethod |
666 | | - def from_json(cls, data): |
| 677 | + def from_json(cls, data: Type | str | dict[str, Any] | list[Any]) -> Type | None: |
667 | 678 | def _parse_nested_list_entry(expr, result_dict): |
668 | 679 | if isinstance(expr, str): |
669 | 680 | if ">" in expr or ">=" in expr: |
@@ -742,6 +753,9 @@ def get(self, key, default=None): |
742 | 753 | return getattr(self, attr, default) |
743 | 754 |
|
744 | 755 |
|
| 756 | +SomeType = TypeVar("SomeType", bound=Type) |
| 757 | + |
| 758 | + |
745 | 759 | class Schema(dict): |
746 | 760 | def __init__(self, schema): |
747 | 761 | self.name = schema["Name"] |
|
0 commit comments