Skip to content

Commit c683fdf

Browse files
committed
Write server module
1 parent 2ac4f71 commit c683fdf

2 files changed

Lines changed: 193 additions & 0 deletions

File tree

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
# -*- coding: utf-8 -*-
2+
3+
__all__ = ["server"]
Lines changed: 190 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,190 @@
1+
import zmq
2+
import numpy as np
3+
import logging
4+
5+
HELLO = "Hello"
6+
ACK = "Acknowledge"
7+
DENY = "Deny"
8+
CANCEL = "Cancel"
9+
10+
class ForeignToolError(Exception):
11+
pass
12+
13+
def _receive_ack(socket, logger, subject=None):
14+
ack = socket.recv_string()
15+
if ack == ACK:
16+
if subject:
17+
logger.debug(f"received ack for {subject}")
18+
else:
19+
logger.debug("received ack")
20+
return True
21+
elif ack == DENY:
22+
raise ForeignToolError("denied, aborting")
23+
elif ack == CANCEL:
24+
raise ForeignToolError("canceled, aborting")
25+
else:
26+
raise ForeignToolError("unexpected response", ack)
27+
28+
def _send_ack(socket, logger, subject):
29+
if subject:
30+
logger.debug(f"sending ack for {subject}")
31+
else:
32+
logger.debug("sending ack")
33+
socket.send_string(ACK)
34+
35+
class ForeignToolServer(object):
36+
def __init__(self, port, domain='*', protocol='tcp', wait_for_handshake=True):
37+
"""
38+
Launch a server on the given port.
39+
"""
40+
self._logger = logging.getLogger(f"{__name__} [server]")
41+
42+
self._context = zmq.Context()
43+
self._server_socket = self._context.socket(zmq.PAIR)
44+
self._server_socket.bind(f"{protocol}://{domain}:{port}")
45+
self._logger.info("launched on", self._server_socket.getsockopt(zmq.LAST_ENDPOINT))
46+
47+
if wait_for_handshake:
48+
self.wait_for_handshake()
49+
50+
def wait_for_handshake(self):
51+
self._logger.debug("waiting for handshake from cllient")
52+
client_hello = self._server_socket.recv_string()
53+
if client_hello == HELLO:
54+
self._logger.debug("received correct handshake")
55+
_send_ack(self._server_socket, self._logger, subject="handshake")
56+
else:
57+
self._logger.debug("received incorrect handshake", client_hello)
58+
self._logger.debug("sending deny")
59+
self._server_socket.send_string(DENY)
60+
raise ForeignToolError("server received incorrect handshake")
61+
62+
def _serve_image(self, image_data):
63+
"""
64+
Serve an image to the client.
65+
"""
66+
header = np.lib.format.header_data_from_array_1_0(image_data)
67+
68+
self._logger.debug("sending header", header, "waiting for acknowledgement")
69+
self._server_socket.send_json(header)
70+
71+
ack = _receive_ack(self._server_socket, self._logger, subject="header")
72+
73+
self._logger.debug("sending image data", image_data.shape, "waiting for acknowledgement")
74+
75+
self._server_socket.send(image_data, copy=False)
76+
77+
ack = _receive_ack(self._server_socket, self._logger, subject="image data")
78+
79+
labels_header = self._server_socket.recv_json()
80+
81+
ack = _send_ack(self._server_socket, self._logger, subject="return header")
82+
83+
label_bytes = self._server_socket.recv(copy=False)
84+
85+
self._logger.debug("received label byte data")
86+
87+
self._logger.debug("parsing label data")
88+
labels = np.frombuffer(label_bytes, dtype=labels_header['descr'])
89+
labels.shape = labels_header['shape']
90+
self._logger.debug("parse label data", labels.shape)
91+
92+
_send_ack(self._server_socket, self._logger, subject="return data")
93+
94+
return labels
95+
96+
def serve_one_image(self, image_data):
97+
"""
98+
Serve an image to the client.
99+
"""
100+
return self._serve_image(image_data)
101+
102+
class ForeignToolClient(object):
103+
def __init__(self, port, domain='*', protocol='tcp', do_handshake=True, cb=None):
104+
"""
105+
Connect to a server on the given port.
106+
"""
107+
self._logger = logging.getLogger(f"{__name__} [client]")
108+
109+
self._context = zmq.Context()
110+
self._client_socket = self._context.socket(zmq.PAIR)
111+
self._client_socket.connect(f"{protocol}://{domain}:{port}")
112+
self._logger.info("connected to", self._client_socket.getsockopt(zmq.LAST_ENDPOINT))
113+
114+
if cb:
115+
self.register_cb(cb)
116+
117+
if do_handshake:
118+
self.do_handshake()
119+
120+
def do_handshake(self):
121+
"""
122+
Handshake with the server.
123+
"""
124+
self.client_socket.send_string(HELLO)
125+
response = _receive_ack(self.client_socket, self.logger, subject="handshake")
126+
127+
def register_cb(self, cb):
128+
"""
129+
Register a callback to be executed on the server.
130+
Must be run before receeive_image
131+
"""
132+
self._cb = cb
133+
134+
def _execute_cb(self, im, header):
135+
"""
136+
Execute the callback on the server.
137+
"""
138+
return self._cb(im, header)
139+
140+
def _receive_image(self):
141+
"""
142+
Receive an image from the server.
143+
"""
144+
header = self._client_socket.recv_json()
145+
self._logger.debug("received header", header)
146+
147+
_send_ack(self._client_socket, self._logger, subject="header")
148+
149+
im_bytes = self._client_socket.recv(copy=False)
150+
self._logger.debug("received image bytes")
151+
152+
self._logger.debug("parsing image data")
153+
buf = memoryview(im_bytes)
154+
im = np.frombuffer(buf, dtype=header['descr'])
155+
im = (im * 255).astype(np.uint8)
156+
im.shape = header['shape']
157+
self._logger.debug("parsed image data", im.shape)
158+
159+
_send_ack(self._client_socket, self._logger, subject="image data")
160+
161+
self._logger.debug("executing callback")
162+
return_data = self._execute_cb(im, header)
163+
self._logger.debug("executed callback")
164+
165+
return_header = np.lib.format.header_data_from_array_1_0(return_data)
166+
self._logger.debug("returning header", return_header)
167+
self._client_socket.send_json(return_header)
168+
169+
ack = _receive_ack(self._client_socket, self._logger, subject="return header")
170+
171+
self._logger.debug("returning data")
172+
self._client_socket.send(return_data, copy=False)
173+
174+
ack = _receive_ack(self._client_socket, self._logger, subject="return data")
175+
176+
def receive_one_image(self):
177+
"""
178+
Receive a single image from the server.
179+
"""
180+
self._receive_image()
181+
182+
def receive_images(self):
183+
"""
184+
Receive images from the server.
185+
"""
186+
while True:
187+
try:
188+
self._receive_image()
189+
except ForeignToolError:
190+
break

0 commit comments

Comments
 (0)