Skip to content

Commit bd881d2

Browse files
committed
Add foreign env runner
1 parent c683fdf commit bd881d2

4 files changed

Lines changed: 282 additions & 63 deletions

File tree

CP5/active_plugins/cpforeign/server.py

Lines changed: 17 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -42,19 +42,19 @@ def __init__(self, port, domain='*', protocol='tcp', wait_for_handshake=True):
4242
self._context = zmq.Context()
4343
self._server_socket = self._context.socket(zmq.PAIR)
4444
self._server_socket.bind(f"{protocol}://{domain}:{port}")
45-
self._logger.info("launched on", self._server_socket.getsockopt(zmq.LAST_ENDPOINT))
45+
self._logger.info(f"launched on {self._server_socket.getsockopt(zmq.LAST_ENDPOINT)}")
4646

4747
if wait_for_handshake:
4848
self.wait_for_handshake()
4949

5050
def wait_for_handshake(self):
51-
self._logger.debug("waiting for handshake from cllient")
51+
self._logger.debug("waiting for handshake from client")
5252
client_hello = self._server_socket.recv_string()
5353
if client_hello == HELLO:
5454
self._logger.debug("received correct handshake")
5555
_send_ack(self._server_socket, self._logger, subject="handshake")
5656
else:
57-
self._logger.debug("received incorrect handshake", client_hello)
57+
self._logger.debug(f"received incorrect handshake {client_hello}")
5858
self._logger.debug("sending deny")
5959
self._server_socket.send_string(DENY)
6060
raise ForeignToolError("server received incorrect handshake")
@@ -65,12 +65,12 @@ def _serve_image(self, image_data):
6565
"""
6666
header = np.lib.format.header_data_from_array_1_0(image_data)
6767

68-
self._logger.debug("sending header", header, "waiting for acknowledgement")
68+
self._logger.debug(f"sending header {header} waiting for acknowledgement")
6969
self._server_socket.send_json(header)
7070

7171
ack = _receive_ack(self._server_socket, self._logger, subject="header")
7272

73-
self._logger.debug("sending image data", image_data.shape, "waiting for acknowledgement")
73+
self._logger.debug(f"sending image data {image_data.shape} waiting for acknowledgement")
7474

7575
self._server_socket.send(image_data, copy=False)
7676

@@ -87,7 +87,7 @@ def _serve_image(self, image_data):
8787
self._logger.debug("parsing label data")
8888
labels = np.frombuffer(label_bytes, dtype=labels_header['descr'])
8989
labels.shape = labels_header['shape']
90-
self._logger.debug("parse label data", labels.shape)
90+
self._logger.debug(f"parse label data of shape {labels.shape}")
9191

9292
_send_ack(self._server_socket, self._logger, subject="return data")
9393

@@ -100,7 +100,7 @@ def serve_one_image(self, image_data):
100100
return self._serve_image(image_data)
101101

102102
class ForeignToolClient(object):
103-
def __init__(self, port, domain='*', protocol='tcp', do_handshake=True, cb=None):
103+
def __init__(self, port, domain='localhost', protocol='tcp', do_handshake=True, cb=None):
104104
"""
105105
Connect to a server on the given port.
106106
"""
@@ -109,20 +109,20 @@ def __init__(self, port, domain='*', protocol='tcp', do_handshake=True, cb=None)
109109
self._context = zmq.Context()
110110
self._client_socket = self._context.socket(zmq.PAIR)
111111
self._client_socket.connect(f"{protocol}://{domain}:{port}")
112-
self._logger.info("connected to", self._client_socket.getsockopt(zmq.LAST_ENDPOINT))
112+
self._logger.info(f"connected to {self._client_socket.getsockopt(zmq.LAST_ENDPOINT)}")
113113

114114
if cb:
115115
self.register_cb(cb)
116116

117117
if do_handshake:
118118
self.do_handshake()
119119

120-
def do_handshake(self):
121-
"""
122-
Handshake with the server.
123-
"""
124-
self.client_socket.send_string(HELLO)
125-
response = _receive_ack(self.client_socket, self.logger, subject="handshake")
120+
def do_handshake(self):
121+
"""
122+
Handshake with the server.
123+
"""
124+
self._client_socket.send_string(HELLO)
125+
response = _receive_ack(self._client_socket, self._logger, subject="handshake")
126126

127127
def register_cb(self, cb):
128128
"""
@@ -142,7 +142,7 @@ def _receive_image(self):
142142
Receive an image from the server.
143143
"""
144144
header = self._client_socket.recv_json()
145-
self._logger.debug("received header", header)
145+
self._logger.debug(f"received header {header}")
146146

147147
_send_ack(self._client_socket, self._logger, subject="header")
148148

@@ -152,9 +152,8 @@ def _receive_image(self):
152152
self._logger.debug("parsing image data")
153153
buf = memoryview(im_bytes)
154154
im = np.frombuffer(buf, dtype=header['descr'])
155-
im = (im * 255).astype(np.uint8)
156155
im.shape = header['shape']
157-
self._logger.debug("parsed image data", im.shape)
156+
self._logger.debug(f"parsed image data {im.shape}")
158157

159158
_send_ack(self._client_socket, self._logger, subject="image data")
160159

@@ -163,7 +162,7 @@ def _receive_image(self):
163162
self._logger.debug("executed callback")
164163

165164
return_header = np.lib.format.header_data_from_array_1_0(return_data)
166-
self._logger.debug("returning header", return_header)
165+
self._logger.debug(f"returning header {return_header}")
167166
self._client_socket.send_json(return_header)
168167

169168
ack = _receive_ack(self._client_socket, self._logger, subject="return header")
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
import logging
2+
import numpy as np
3+
import skimage as ski
4+
import scipy as sp
5+
6+
from server import ForeignToolClient
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
def run(image_data, image_header):
12+
im = (image_data * 255).astype(np.uint8)
13+
14+
markers = np.zeros_like(im, dtype=np.uint8)
15+
IDK = 0
16+
BG = 1
17+
FG = 2
18+
markers[im < 30] = BG
19+
markers[im > 50] = FG
20+
# rest = IDK
21+
22+
elevation_map = ski.filters.sobel(im)
23+
segmentation = ski.segmentation.watershed(elevation_map, markers)
24+
segmentation = sp.ndimage.binary_fill_holes(segmentation - 1)
25+
26+
labels, _ = sp.ndimage.label(segmentation)
27+
28+
# remove small objects
29+
sizes = np.bincount(labels.ravel())
30+
mask_sizes = sizes > 20
31+
mask_sizes[0] = 0
32+
segmentation = mask_sizes[labels]
33+
34+
labels, _ = sp.ndimage.label(segmentation)
35+
36+
ski.io.imsave("/Users/ngogober/Desktop/thresh.tif", labels)
37+
38+
return labels
39+
40+
41+
def main():
42+
client = ForeignToolClient(7878, cb=run)
43+
client.receive_images()
44+
45+
if __name__ == "__main__":
46+
# init logging
47+
logging.root.setLevel(logging.DEBUG)
48+
stream_handler = logging.StreamHandler()
49+
fmt = logging.Formatter(" [%(process)d|%(levelno)s] %(name)s::%(funcName)s: %(message)s")
50+
stream_handler.setFormatter(fmt)
51+
logging.root.addHandler(stream_handler)
52+
53+
logger.debug("Starting thresh.py")
54+
main()

CP5/active_plugins/cpforeign/zmq_server.ipynb

Lines changed: 39 additions & 45 deletions
Large diffs are not rendered by default.
Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
import shlex
2+
import sys
3+
import os
4+
import re
5+
import subprocess
6+
import threading
7+
import logging
8+
9+
from cellprofiler_core.module.image_segmentation import ImageSegmentation
10+
from cellprofiler_core.setting.text import Text
11+
from cellprofiler_core.setting.text import Filename
12+
from cellprofiler_core.setting.text import Integer
13+
from cellprofiler_core.object import Objects
14+
15+
from cpforeign.server import ForeignToolServer
16+
17+
LOGGER = logging.getLogger(__name__)
18+
19+
HELLO = "Hello"
20+
ACK = "Acknowledge"
21+
DENIED = "Denied"
22+
23+
__doc__ = """\
24+
RunForeignEnv
25+
============
26+
27+
**RunForeign** runs a foreign tool, in a foreign (conda) environment, via sockets.
28+
29+
30+
Assumes there is a client up and running.
31+
32+
|
33+
34+
============ ============ ===============
35+
Supports 2D? Supports 3D? Respects masks?
36+
============ ============ ===============
37+
YES NO YES
38+
============ ============ ===============
39+
40+
"""
41+
42+
def _run_logger(workR):
43+
# this thread shuts itself down by reading from worker's stdout
44+
# which either reads content from stdout or blocks until it can do so
45+
# when the worker is shut down, empty byte string is returned continuously
46+
# which evaluates as None so the break is hit
47+
# I don't really like this approach; we should just shut it down with the other
48+
# threads explicitly
49+
while True:
50+
try:
51+
print('reading')
52+
line = workR.stdout.readline()
53+
if (type(line) == bytes):
54+
line = line.decode("utf-8")
55+
if not line:
56+
break
57+
log_msg_match = re.match(fr"{workR.pid}\|(10|20|30|40|50)\|(.*)", line)
58+
if log_msg_match:
59+
levelno = int(log_msg_match.group(1))
60+
msg = log_msg_match.group(2)
61+
else:
62+
levelno = 20
63+
msg = line
64+
65+
LOGGER.log(levelno, "\n\r [Worker (%d)] %s", workR.pid, msg.rstrip())
66+
67+
except Exception as e:
68+
LOGGER.exception(e)
69+
break
70+
71+
72+
class RunForeignEnv(ImageSegmentation):
73+
category = "Object Processing"
74+
75+
module_name = "RunForeignEnv"
76+
77+
variable_revision_number = 1
78+
79+
def create_settings(self):
80+
super().create_settings()
81+
82+
self._server = None
83+
self._client_launched = False
84+
85+
self.server_port = Integer(
86+
text="Server port number",
87+
value=7878,
88+
minval=0,
89+
doc="""\
90+
The port number which the server is listening on. The server must be launched manually first.
91+
""",
92+
)
93+
94+
self.env_name = Text(text="Conda environment name", value="foreign-thresh")
95+
96+
self.algo_path = Filename(text="Algorithm path", value="/Users/ngogober/Developer/CellProfiler/CellProfiler-plugins/CP5/active_plugins/cpforeign/thresh.py")
97+
98+
def settings(self):
99+
return super().settings() + [self.server_port, self.env_name, self.algo_path]
100+
101+
# ImageSegmentation defines this so we have to overide it
102+
def visible_settings(self):
103+
return self.settings()
104+
105+
# ImageSegmentation defines this so we have to overide it
106+
def volumetric(self):
107+
return False
108+
109+
def prepare_run(self, workspace):
110+
111+
LOGGER.debug(">>> Preparing run")
112+
if not self._server:
113+
LOGGER.debug(">>> Initializing server")
114+
self._server = ForeignToolServer(self.server_port.value, wait_for_handshake=False)
115+
116+
if not self._client_launched:
117+
LOGGER.debug(">>> Launching client")
118+
command = f"conda run --no-capture-output -n {self.env_name.value} python {self.algo_path.value}"
119+
args = shlex.split(command)
120+
env = os.environ.copy()
121+
env["PYTHONUNBUFFERED"] = "1"
122+
self._client_proc = subprocess.Popen(args, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stdout, bufsize=1, universal_newlines=True, env=env, close_fds=False)
123+
#self._client_thread = threading.Thread(target=_run_logger, args=(self._client_proc,), name="foreign client stdout logger thread")
124+
#self._client_thread.start()
125+
126+
self._client_launched = True
127+
self._server.wait_for_handshake()
128+
129+
return True
130+
131+
def post_run(self, workspace):
132+
if self._client_launched:
133+
LOGGER.debug(">>> Shuttding down client")
134+
#self._client_thread.join()
135+
self._client_proc.terminate()
136+
137+
def run(self, workspace):
138+
# TODO: is this supposed to not run in test mode? because it doesn't...
139+
self.prepare_run(workspace)
140+
141+
x_name = self.x_name.value
142+
143+
y_name = self.y_name.value
144+
145+
images = workspace.image_set
146+
147+
x = images.get_image(x_name)
148+
149+
dimensions = x.dimensions
150+
151+
x_data = x.pixel_data
152+
153+
y_data = self._server.serve_one_image(x_data)
154+
155+
y = Objects()
156+
157+
y.segmented = y_data
158+
159+
y.parent_image = x.parent_image
160+
161+
objects = workspace.object_set
162+
163+
objects.add_objects(y, y_name)
164+
165+
self.add_measurements(workspace)
166+
167+
if self.show_window:
168+
workspace.display_data.x_data = x_data
169+
170+
workspace.display_data.y_data = y_data
171+
172+
workspace.display_data.dimensions = dimensions

0 commit comments

Comments
 (0)