Skip to content

Commit 661ff10

Browse files
committed
Add pack/unpack encoding arguments and more Python3 support. Thanks to followers in Twitter!
1 parent 82c5fb7 commit 661ff10

5 files changed

Lines changed: 36 additions & 42 deletions

File tree

msgpackrpc/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,9 @@ class Client(session.Session):
77
Client is usaful for MessagePack RPC API.
88
"""
99

10-
def __init__(self, address, timeout=10, loop=None, builder=tcp):
10+
def __init__(self, address, timeout=10, loop=None, builder=tcp, reconnect_limit=5, pack_encoding='utf-8', unpack_encoding=None):
1111
loop = loop or Loop()
12-
session.Session.__init__(self, address, timeout, loop, builder)
12+
session.Session.__init__(self, address, timeout, loop, builder, reconnect_limit, pack_encoding, unpack_encoding)
1313

1414
if timeout:
1515
loop.attach_periodic_callback(self.step_timeout, 1000) # each 1s

msgpackrpc/server.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,14 +12,15 @@ class Server(session.Session):
1212
Server is usaful for MessagePack RPC Server.
1313
"""
1414

15-
def __init__(self, dispatcher, loop=None, builder=tcp):
15+
def __init__(self, dispatcher, loop=None, builder=tcp, pack_encoding='utf-8', unpack_encoding=None):
1616
self._loop = loop or Loop()
1717
self._builder = builder
18+
self._encodings = (pack_encoding, unpack_encoding)
1819
self._listeners = []
1920
self._dispatcher = dispatcher
2021

2122
def listen(self, address):
22-
listener = self._builder.ServerTransport(address)
23+
listener = self._builder.ServerTransport(address, self._encodings)
2324
listener.listen(self)
2425
self._listeners.append(listener)
2526

@@ -41,7 +42,7 @@ def on_notify(self, method, param):
4142

4243
def dispatch(self, method, param, responder):
4344
try:
44-
if inPy3k:
45+
if inPy3k and not isinstance(method, str):
4546
method = method.decode("utf-8")
4647
if not hasattr(self._dispatcher, method):
4748
raise error.NoMethodError("'{0}' method not found".format(method))

msgpackrpc/session.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ class Session(object):
2121
result to the corresponding future.
2222
"""
2323

24-
def __init__(self, address, timeout, loop=None, builder=tcp):
24+
def __init__(self, address, timeout, loop=None, builder=tcp, reconnect_limit=5, pack_encoding='utf-8', unpack_encoding=None):
2525
"""\
2626
:param address: address of the server.
2727
:param loop: context object.
@@ -31,7 +31,7 @@ def __init__(self, address, timeout, loop=None, builder=tcp):
3131
self._loop = loop or Loop()
3232
self._address = address
3333
self._timeout = timeout
34-
self._transport = builder.build_transport(self)
34+
self._transport = builder.ClientTransport(self, self._address, reconnect_limit, encodings=(pack_encoding, unpack_encoding))
3535
self._generator = _NoSyncIDGenerator()
3636
self._request_table = {}
3737

@@ -72,12 +72,6 @@ def on_connect_failed(self, reason):
7272
Called by the transport layer.
7373
"""
7474

75-
def _iteritems(dic): # ugly!!!!!!
76-
if inPy3k:
77-
return dic.items()
78-
else:
79-
return dic.iteritems()
80-
8175
# set error for all requests
8276
#for msgid, future in self._request_table.iteritems():
8377
for msgid, future in _iteritems(self._request_table):
@@ -111,7 +105,7 @@ def on_timeout(self, msgid):
111105

112106
def step_timeout(self):
113107
timeouts = []
114-
for msgid, future in self._request_table.iteritems():
108+
for msgid, future in _iteritems(self._request_table):
115109
if future.step_timeout():
116110
timeouts.append(msgid)
117111

@@ -124,6 +118,11 @@ def step_timeout(self):
124118
future.set_error("Request timed out")
125119
self._loop.start()
126120

121+
def _iteritems(dic): # ugly!!!!!!
122+
if inPy3k:
123+
return dic.items()
124+
else:
125+
return dic.iteritems()
127126

128127
def _NoSyncIDGenerator():
129128
"""\

msgpackrpc/transport/tcp.py

Lines changed: 16 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,15 +8,11 @@
88
from msgpackrpc import error
99

1010

11-
def build_transport(session, reconnect_limit=5):
12-
return ClientTransport(session, session.address, reconnect_limit)
13-
14-
1511
class BaseSocket(object):
16-
def __init__(self, stream):
12+
def __init__(self, stream, encodings):
1713
self._stream = stream
18-
self._packer = msgpack.Packer()
19-
self._unpacker = msgpack.Unpacker()
14+
self._packer = msgpack.Packer(encoding=encodings[0])
15+
self._unpacker = msgpack.Unpacker(encoding=encodings[1])
2016

2117
def close(self):
2218
self._stream.close()
@@ -55,8 +51,8 @@ def on_notify(self, method, param):
5551

5652

5753
class ClientSocket(BaseSocket):
58-
def __init__(self, stream, transport):
59-
BaseSocket.__init__(self, stream)
54+
def __init__(self, stream, transport, encodings):
55+
BaseSocket.__init__(self, stream, encodings)
6056
self._transport = transport
6157
self._stream.set_close_callback(self.on_close)
6258

@@ -78,9 +74,10 @@ def on_response(self, msgid, error, result):
7874

7975

8076
class ClientTransport(object):
81-
def __init__(self, session, address, reconnect_limit):
77+
def __init__(self, session, address, reconnect_limit, encodings=('utf-8', None)):
8278
self._session = session
8379
self._address = address
80+
self._encodings = encodings
8481
self._reconnect_limit = reconnect_limit;
8582

8683
self._connecting = 0
@@ -99,7 +96,7 @@ def send_message(self, message, callback=None):
9996

10097
def connect(self):
10198
stream = IOStream(self._address.socket(), io_loop=self._session._loop._ioloop)
102-
socket = ClientSocket(stream, self)
99+
socket = ClientSocket(stream, self, self._encodings)
103100
socket.connect();
104101

105102
def close(self):
@@ -134,8 +131,8 @@ def on_close(self, sock):
134131

135132

136133
class ServerSocket(BaseSocket):
137-
def __init__(self, stream, transport):
138-
BaseSocket.__init__(self, stream)
134+
def __init__(self, stream, transport, encodings):
135+
BaseSocket.__init__(self, stream, encodings)
139136
self._transport = transport
140137
self._stream.read_until_close(self.on_read, self.on_read)
141138

@@ -150,21 +147,23 @@ def on_notify(self, method, param):
150147

151148

152149
class MessagePackServer(netutil.TCPServer):
153-
def __init__(self, transport, io_loop=None):
150+
def __init__(self, transport, io_loop=None, encodings=None):
154151
self._transport = transport
152+
self._encodings = encodings
155153
netutil.TCPServer.__init__(self, io_loop=io_loop)
156154

157155
def handle_stream(self, stream, address):
158-
ServerSocket(stream, self._transport)
156+
ServerSocket(stream, self._transport, self._encodings)
159157

160158

161159
class ServerTransport(object):
162-
def __init__(self, address):
160+
def __init__(self, address, encodings=('utf-8', None)):
163161
self._address = address;
162+
self._encodings = encodings
164163

165164
def listen(self, server):
166165
self._server = server;
167-
self._mp_server = MessagePackServer(self, io_loop=self._server._loop._ioloop)
166+
self._mp_server = MessagePackServer(self, io_loop=self._server._loop._ioloop, encodings=self._encodings)
168167
self._mp_server.listen(self._address.port)
169168

170169
def close(self):

test/test_msgpackrpc.py

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ def _start_server(server):
2929
self._thread = threading.Thread(target=_start_server, args=(self._server,))
3030
self._thread.start()
3131

32-
self._client = msgpackrpc.Client(self._address)
32+
self._client = msgpackrpc.Client(self._address, unpack_encoding='utf-8')
3333
return self._client;
3434

3535
def tearDown(self):
@@ -40,8 +40,6 @@ def tearDown(self):
4040
def test_call(self):
4141
client = self.setup_env();
4242
result = client.call('hello')
43-
if inPy3k:
44-
result = result.decode("utf-8")
4543
self.assertEqual(result, "world", "'hello' result is incorrect")
4644

4745
result = client.call('sum', 1, 2)
@@ -55,11 +53,7 @@ def test_call_async(self):
5553
feture1.join()
5654
feture2.join()
5755

58-
if inPy3k:
59-
result = feture1.result.decode("utf-8")
60-
else:
61-
result = feture1.result
62-
self.assertEqual(result, "world", "'hello' result is incorrect in call_async")
56+
self.assertEqual(feture1.result, "world", "'hello' result is incorrect in call_async")
6357
self.assertEqual(feture2.result, 3, "'sum' result is incorrect in call_async")
6458

6559
def test_notify(self):
@@ -70,16 +64,17 @@ def test_notify(self):
7064
client.notify('sum', 1, 2)
7165
except:
7266
result = False
73-
self.assert_(result)
67+
self.assertTrue(result)
7468

7569
def test_unknown_method(self):
7670
client = self.setup_env();
7771
self.assertRaises(error.RPCError, lambda: client.call('unknown', True))
7872
try:
7973
client.call('unknown', True)
80-
self.assert_(False)
74+
self.assertTrue(False)
8175
except error.RPCError as e:
82-
self.assertEqual(e.message, "'unknown' method not found", "Error message mismatched")
76+
message = e.args[0]
77+
self.assertEqual(message, "'unknown' method not found", "Error message mismatched")
8378

8479

8580
if __name__ == '__main__':

0 commit comments

Comments
 (0)