88custom extension types defined by wrap standard
99"""
1010from enum import Enum
11- from typing import Any , Dict , List , Set , cast
11+ from typing import Any , Dict , List , Set , Tuple , cast
1212
1313import msgpack
14+ from msgpack .ext import ExtType
1415from msgpack .exceptions import UnpackValueError
1516
17+ from .generic_map import GenericMap
18+
1619
1720class ExtensionTypes (Enum ):
1821 """Wrap msgpack extension types."""
1922
2023 GENERIC_MAP = 1
2124
2225
23- def ext_hook (code : int , data : bytes ) -> Any :
26+ def encode_ext_hook (obj : Any ) -> ExtType :
27+ """Extension hook for extending the msgpack supported types.
28+
29+ Args:
30+ obj (Any): object to be encoded
31+
32+ Raises:
33+ TypeError: when given object is not supported
34+
35+ Returns:
36+ Tuple[int, bytes]: extension type code and payload
37+ """
38+ if isinstance (obj , GenericMap ):
39+ return ExtType (ExtensionTypes .GENERIC_MAP .value , msgpack_encode (obj ._map )) # type: ignore
40+ raise TypeError (f"Object of type { type (obj )} is not supported" )
41+
42+
43+ def decode_ext_hook (code : int , data : bytes ) -> Any :
2444 """Extension hook for extending the msgpack supported types.
2545
2646 Args:
@@ -34,7 +54,7 @@ def ext_hook(code: int, data: bytes) -> Any:
3454 Any: decoded object
3555 """
3656 if code == ExtensionTypes .GENERIC_MAP .value :
37- return msgpack_decode (data )
57+ return GenericMap ( msgpack_decode (data ) )
3858 raise UnpackValueError ("Invalid Extention type" )
3959
4060
@@ -50,6 +70,8 @@ def sanitize(value: Any) -> Any:
5070 Returns:
5171 Any: msgpack compatible sanitized value
5272 """
73+ if isinstance (value , GenericMap ):
74+ return cast (Any , value )
5375 if isinstance (value , dict ):
5476 dictionary : Dict [Any , Any ] = value
5577 for key , val in dictionary .items ():
@@ -59,7 +81,7 @@ def sanitize(value: Any) -> Any:
5981 array : List [Any ] = value
6082 return [sanitize (a ) for a in array ]
6183 if isinstance (value , tuple ):
62- array : List [Any ] = list (value ) # type: ignore partially unknown
84+ array : List [Any ] = list (cast ( Tuple [ Any ], value ))
6385 return sanitize (array )
6486 if isinstance (value , set ):
6587 set_val : Set [Any ] = value
@@ -87,7 +109,7 @@ def msgpack_encode(value: Any) -> bytes:
87109 bytes: encoded msgpack value
88110 """
89111 sanitized = sanitize (value )
90- return msgpack .packb (sanitized )
112+ return msgpack .packb (sanitized , default = encode_ext_hook , use_bin_type = True )
91113
92114
93115def msgpack_decode (val : bytes ) -> Any :
@@ -99,4 +121,4 @@ def msgpack_decode(val: bytes) -> Any:
99121 Returns:
100122 Any: python object
101123 """
102- return msgpack .unpackb (val , ext_hook = ext_hook )
124+ return msgpack .unpackb (val , ext_hook = decode_ext_hook )
0 commit comments