Skip to content

Commit fed1030

Browse files
committed
Feat: Consolidate _RpcBase._call and change unnecssary factory functions into class initializers
1 parent 4286aed commit fed1030

1 file changed

Lines changed: 22 additions & 39 deletions

File tree

python/twinleaf/__init__.py

Lines changed: 22 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ def __init__(self, name, device: Device):
150150
self.__name__ = name
151151
self._device = device
152152

153+
def __call__(self) -> dict[str, _rpc_type]:
154+
return self._survey()
155+
153156
def _survey(self) -> dict[str, _rpc_type]:
154157
""" Recursively collect all readable RPC values in this subtree """
155158
results = {}
@@ -179,60 +182,42 @@ def __init__(self, pyrpc: _twinleaf._Rpc, device: Device):
179182
case '' if self._size_bytes == 0: self._data_type = None
180183
case other: self._data_type = bytes
181184

182-
def _call_with_arg(self, arg: _rpc_type=None) -> _rpc_type:
185+
def _call(self, arg: _rpc_type=None) -> _rpc_type:
183186
match self._data_type:
184187
case t if t is int:
185188
return self._device._rpc_int(self.__name__, self._size_bytes, self._signed, arg)
186189
case t if t is float:
187190
return self._device._rpc_float(self.__name__, self._size_bytes, arg)
188191
case t if t is str:
192+
if arg is None: arg = ''
189193
return self._device._rpc(self.__name__, arg.encode()).decode()
190194
case t if t is bytes:
195+
if arg is None: arg = b''
191196
return self._device._rpc(self.__name__, arg)
192197
case None:
193198
return self._device._rpc(self.__name__, b'')
194199
case other:
195200
raise TypeError(f"Invalid RPC type {other}, RPC types must be {_rpc_type}")
196201

197-
def _call(self) -> _rpc_type:
198-
match self._data_type:
199-
case t if t is int:
200-
return self._device._rpc_int(self.__name__, self._size_bytes, self._signed)
201-
case t if t is float:
202-
return self._device._rpc_float(self.__name__, self._size_bytes)
203-
case t if t is str:
204-
return self._device._rpc(self.__name__, b'').decode()
205-
case t if t is bytes or _ is None:
206-
return self._device._rpc(self.__name__, b'')
207-
case other:
208-
raise TypeError(f"Invalid RPC type {other}, RPC types must be {_rpc_type}")
209-
210-
class _RpcSurveyBase(_RpcNode):
211-
""" Internal class for RPC surveys """
212-
def __init__(self, name: str, device: Device):
213-
super().__init__(name, device)
214-
215-
def __call__(self):
216-
return self._survey()
217-
218202
def _Rpc(pyrpc: _twinleaf._Rpc, device: Device) -> _RpcNode:
219203
""" Factory function that creates an RPC with appropriate __call__ signature """
220-
if pyrpc.writable:
221-
def __call__(self, arg: _rpc_type=None) -> _rpc_type:
222-
if arg is None:
223-
return self._call()
224-
else:
225-
return self._call_with_arg(arg)
204+
base_rpc = _RpcBase(pyrpc, device)
205+
if base_rpc._writable and base_rpc._data_type is not None:
206+
def __call__(self, arg=None):
207+
return self._call(arg)
208+
__call__.__annotations__ |= { 'arg': base_rpc._data_type | None }
209+
__call__.__annotations__ |= { 'return': base_rpc._data_type }
226210
else:
227211
def __call__(self) -> _rpc_type:
228212
return self._call()
213+
__call__.__annotations__ |= { 'return': base_rpc._data_type }
229214

230215
cls = type('Rpc', (_RpcBase,), {'__call__': __call__})
231216
return cls(pyrpc, device)
232217

233218
def _RpcSurvey(name: str, device: Device) -> _RpcNode:
234219
""" Factory function that creates an RPC survey """
235-
cls = type('Survey', (_RpcSurveyBase,), {})
220+
cls = type('Survey', (_RpcNode,), {})
236221
return cls(name, device)
237222

238223
# Samples classes
@@ -244,20 +229,18 @@ def __init__(self, device: Device, name: str, stream: str, columns: list[str]):
244229
self._stream = stream
245230
self._columns = columns
246231

247-
class _SamplesDictBase(_SamplesBase):
232+
class _SamplesDict(_SamplesBase):
248233
""" Returns samples as dict keyed by stream_id """
234+
def __init__(self, device: Device, name: str, stream: str="", columns: list[str] | None=None):
235+
super().__init__(device, name, stream, columns if columns is not None else [] )
236+
249237
def __call__(self, n: int=1, **kwargs):
250238
return self._device._samples_dict(n, self._stream, self._columns, **kwargs)
251239

252-
class _SamplesListBase(_SamplesBase):
240+
class _SamplesList(_SamplesBase):
253241
""" Returns samples as list for single stream """
242+
def __init__(self, device: Device, name: str, stream: str="", columns: list[str] | None=None):
243+
super().__init__(device, name, stream, columns if columns is not None else [] )
244+
254245
def __call__(self, n: int=1, **kwargs):
255246
return self._device._samples_list(n, self._stream, self._columns, **kwargs)
256-
257-
def _SamplesDict(device: Device, name: str, stream: str="", columns: list[str] | None=None):
258-
""" Factory function that creates a sample dict """
259-
return _SamplesDictBase(device, name, stream, columns if columns is not None else [])
260-
261-
def _SamplesList(device: Device, name: str, stream: str="", columns: list[str] | None=None):
262-
""" Factory function that creates a sample list """
263-
return _SamplesListBase(device, name, stream, columns if columns is not None else [])

0 commit comments

Comments
 (0)