Skip to content

Commit a399af9

Browse files
committed
chore: better types for Type.rpc and Connection.rpc
1 parent 1a23a60 commit a399af9

3 files changed

Lines changed: 35 additions & 11 deletions

File tree

juju/client/connection.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,14 @@
1616
import macaroonbakery.httpbakery as httpbakery
1717
import websockets
1818
from dateutil.parser import parse
19-
from typing_extensions import Self
19+
from typing_extensions import Self, overload
2020

2121
from juju import errors, jasyncio, tag, utils
2222
from juju.client import client
2323
from juju.utils import IdQueue
2424
from juju.version import CLIENT_VERSION
2525

26+
from .facade import _JSON, _RICH_JSON, TypeEncoder
2627
from .facade_versions import client_facade_versions, known_unsupported_facades
2728

2829
LEVELS = ["TRACE", "DEBUG", "INFO", "WARNING", "ERROR"]
@@ -536,7 +537,19 @@ async def _do_ping():
536537
log.debug("ping failed because of closed connection")
537538
pass
538539

539-
async def rpc(self, msg: dict, encoder=None) -> dict:
540+
@overload
541+
async def rpc(
542+
self, msg: dict[str, _JSON], encoder: None = None
543+
) -> dict[str, _JSON]: ...
544+
545+
@overload
546+
async def rpc(
547+
self, msg: dict[str, _RICH_JSON], encoder: TypeEncoder
548+
) -> dict[str, _JSON]: ...
549+
550+
async def rpc(
551+
self, msg: dict[str, Any], encoder: json.JSONEncoder | None = None
552+
) -> dict[str, Any]:
540553
"""Make an RPC to the API. The message is encoded as JSON
541554
using the given encoder if any.
542555
:param msg: Parameters for the call (will be encoded as JSON).

juju/client/facade.py

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# Copyright 2023 Canonical Ltd.
22
# Licensed under the Apache V2, see LICENCE file for details.
3+
from __future__ import annotations
34

45
import argparse
56
import builtins
@@ -13,13 +14,22 @@
1314
from collections import defaultdict
1415
from glob import glob
1516
from pathlib import Path
16-
from typing import Any, Dict, List, Mapping, Sequence
17+
from typing import Any, Mapping, Sequence
1718

1819
import packaging.version
1920
import typing_inspect
21+
from typing_extensions import TypeAlias
2022

2123
from . import codegen
2224

25+
# Plain JSON, what is received from Juju
26+
_JSON_LEAF: TypeAlias = None | bool | int | float | str
27+
_JSON: TypeAlias = "_JSON_LEAF|list[_JSON]|dict[str, _JSON]"
28+
29+
# Type-enriched JSON, what can be sent to Juju
30+
_RICH_LEAF: TypeAlias = "_JSON_LEAF|Type"
31+
_RICH_JSON: TypeAlias = "_RICH_LEAF|list[_RICH_JSON]|dict[str, _RICH_JSON]"
32+
2333
_marker = object()
2434

2535
JUJU_VERSION = re.compile(r"[0-9]+\.[0-9-]+[\.\-][0-9a-z]+(\.[0-9]+)?")
@@ -634,7 +644,7 @@ class {name}Facade(Type):
634644

635645

636646
class TypeEncoder(json.JSONEncoder):
637-
def default(self, obj):
647+
def default(self, obj: _RICH_JSON) -> _JSON:
638648
if isinstance(obj, Type):
639649
return obj.serialize()
640650
return json.JSONEncoder.default(self, obj)
@@ -653,7 +663,7 @@ def __eq__(self, other):
653663

654664
return self.__dict__ == other.__dict__
655665

656-
async def rpc(self, msg):
666+
async def rpc(self, msg: dict[str, _RICH_JSON]) -> _JSON:
657667
result = await self.connection.rpc(msg, encoder=TypeEncoder)
658668
return result
659669

@@ -704,13 +714,13 @@ def _parse_nested_list_entry(expr, result_dict):
704714
return cls(**d)
705715
return None
706716

707-
def serialize(self):
717+
def serialize(self) -> dict[str, _JSON]:
708718
d = {}
709719
for attr, tgt in self._toSchema.items():
710720
d[tgt] = getattr(self, attr)
711721
return d
712722

713-
def to_json(self):
723+
def to_json(self) -> str:
714724
return json.dumps(self.serialize(), cls=TypeEncoder, sort_keys=True)
715725

716726
def __contains__(self, key):
@@ -917,8 +927,8 @@ def generate_definitions(schemas):
917927

918928

919929
def generate_facades(
920-
schemas: Dict[str, List[Schema]],
921-
) -> Dict[str, Dict[int, codegen.Capture]]:
930+
schemas: dict[str, list[Schema]],
931+
) -> dict[str, dict[int, codegen.Capture]]:
922932
captures = defaultdict(codegen.Capture)
923933

924934
# Build the Facade classes

pyproject.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ ignore = [
212212
[tool.pyright]
213213
# These are tentative
214214
# include = ["**/*.py"]
215-
pythonVersion = "3.8" # check no python > 3.8 features are used
216-
pythonPlatform = "All"
215+
pythonVersion = "3.10"
217216
typeCheckingMode = "strict"
217+
useLibraryCodeForTypes = true
218+
reportGeneralTypeIssues = true

0 commit comments

Comments
 (0)