|
1 | | -from typing import Union, Any, Dict |
2 | | - |
3 | | -from polywrap_core import Wrapper, InvokeOptions, Invoker, InvocableResult, GetFileOptions |
4 | | -from polywrap_plugin import PluginModule |
5 | | -from polywrap_result import Result, Ok, Err |
| 1 | +from typing import Any, Dict, Union, cast, Generic |
| 2 | + |
| 3 | +from polywrap_core import ( |
| 4 | + GetFileOptions, |
| 5 | + InvocableResult, |
| 6 | + InvokeOptions, |
| 7 | + Invoker, |
| 8 | + Wrapper |
| 9 | +) |
6 | 10 | from polywrap_manifest import AnyWrapManifest |
7 | 11 | from polywrap_msgpack import msgpack_decode |
| 12 | +from polywrap_result import Err, Ok, Result |
| 13 | + |
| 14 | +from polywrap_plugin import PluginModule, TConfig, TResult |
8 | 15 |
|
9 | | -class PluginWrapper(Wrapper): |
| 16 | +class PluginWrapper(Wrapper, Generic[TConfig, TResult]): |
| 17 | + module: PluginModule[TConfig, TResult] |
10 | 18 | manifest: AnyWrapManifest |
11 | | - module: PluginModule |
12 | 19 |
|
13 | | - def __init__(self, manifest: AnyWrapManifest, module: PluginModule) -> None: |
14 | | - self.manifest = manifest |
| 20 | + def __init__(self, module: PluginModule[TConfig, TResult], manifest: AnyWrapManifest) -> None: |
15 | 21 | self.module = module |
| 22 | + self.manifest = manifest |
16 | 23 |
|
17 | 24 | async def invoke( |
18 | 25 | self, options: InvokeOptions, invoker: Invoker |
19 | 26 | ) -> Result[InvocableResult]: |
20 | | - |
21 | | - method = options.method |
22 | | - if not self.module.get_method(method): |
23 | | - return Err(Exception(f"PluginWrapper: method {method} not found")) |
24 | | - |
25 | 27 | env = options.env if options.env else {} |
26 | 28 | self.module.set_env(env) |
27 | 29 |
|
28 | | - decoded_args: Dict[str, Any] = options.args if options.args else {} |
| 30 | + decoded_args: Union[Dict[str, Any], bytes] = options.args if options.args else {} |
29 | 31 |
|
30 | 32 | if isinstance(decoded_args, bytes): |
31 | 33 | decoded_args = msgpack_decode(decoded_args) |
32 | 34 |
|
33 | | - result = self.module._wrap_invoke(method, decoded_args, invoker) |
| 35 | + result: Result[TResult] = await self.module._wrap_invoke(options.method, decoded_args, invoker) # type: ignore |
| 36 | + |
| 37 | + if result.is_err(): |
| 38 | + return cast(Err, result.err) |
| 39 | + |
| 40 | + return Ok(InvocableResult(result=result,encoded=False)) |
34 | 41 |
|
35 | | - if result.ok: |
36 | | - return Ok(InvocableResult(result=result,encoded=False)) |
37 | 42 |
|
38 | 43 | async def get_file(self, options: GetFileOptions) -> Result[Union[str, bytes]]: |
39 | 44 | return Err(Exception("client.get_file(..) is not implemented for plugins")) |
|
0 commit comments