1+ import numpy as np
2+ import logging
3+
4+ from bioimageio .spec import InvalidDescr , load_description
5+ from bioimageio .spec .model .v0_5 import ModelDescr
6+ import bioimageio .core .prediction as bi_pred
7+
8+ from skimage .filters import threshold_otsu
9+ from skimage .measure import label
10+ from skimage .morphology import closing , square
11+ from skimage .segmentation import clear_border
12+
13+ from server import ForeignToolClient
14+
15+ logger = logging .getLogger (__name__ )
16+
17+ # https://bioimage.io/#/?tags=affable-shark&id=10.5281%2Fzenodo.5764892
18+ MODEL_ID = "affable-shark"
19+ MODEL_DOI = "10.5281/zenodo.11092561"
20+
21+ def load_model ():
22+ loaded_description = load_description (MODEL_ID )
23+ if isinstance (loaded_description , InvalidDescr ):
24+ raise ValueError (f"Failed to load { MODEL_ID } " )
25+ elif not isinstance (loaded_description , ModelDescr ):
26+ raise ValueError ("This notebook expects a model 0.5 description" )
27+
28+ model = loaded_description
29+ example_model_id = model .id
30+ assert example_model_id is not None
31+
32+ try :
33+ descr = load_description (MODEL_ID )
34+ except InvalidDescr as e :
35+ logger .error (f"Invalid description: { e } " )
36+ return None
37+
38+ return descr
39+
40+ def predict (input_image , model ):
41+ out = bi_pred .predict (model = model , inputs = {'input0' : input_image }, skip_postprocessing = True , skip_preprocessing = True )
42+ return np .array (out .members ['output0' ].data [0 ])
43+
44+ def run (image_data , image_header ):
45+ model = load_model ()
46+
47+ logger .debug ("loaded model" )
48+
49+ # scaled image
50+ im = image_data .copy ()
51+ logger .debug (f"provided image of shape { im .shape } , type { im .dtype } " )
52+ # im = (image_data / np.iinfo(image_data.dtype).max).astype(np.float32)
53+
54+ pad_y = (64 - image_data .shape [0 ] % 64 ) % 64
55+ pad_x = (64 - image_data .shape [1 ] % 64 ) % 64
56+ # padded image
57+ im = np .pad (im , ((0 , pad_y ), (0 , pad_x )), mode = 'constant' , constant_values = 0 )
58+ logger .debug (f"padded image of shape { im .shape } , type { im .dtype } " )
59+
60+ # input image
61+ im = im .reshape ([1 ,1 ,im .shape [0 ],im .shape [1 ]])
62+ logger .debug (f"input image of shape { im .shape } , type { im .dtype } " )
63+
64+ # output image
65+ logger .debug ("running prediction" )
66+ res = predict (im , model )
67+ del im
68+ logger .debug (f"output image of shape { res .shape } , dtype { res .dtype } " )
69+
70+ # unpadded result
71+ res = res [:, :image_data .shape [0 ], :image_data .shape [1 ]]
72+ logger .debug (f"de-padded output image of shape { res .shape } , dtype { res .dtype } " )
73+
74+ # just the foreground probabilities, ignore boundaries
75+ res = res [0 ]
76+ logger .debug (f"using only fg prob of shape { res .shape } , dtype { res .dtype } " )
77+
78+ # threshold above certain prob
79+ thresh = threshold_otsu (res )
80+ logger .debug (f"threshold image shape { thresh .shape } , dtype { thresh .dtype } " )
81+ # make binary, with closing (remove small holes in fg with dilate then erode)
82+ bw = closing (res > thresh , square (3 ))
83+ logger .debug (f"binary image of shape { bw .shape } , type { bw .dtype } " )
84+
85+ # remove border cells
86+ # cleared = clear_border(bw)
87+ # labels = label(cleared)
88+
89+ # convert to labels
90+ labels = label (bw )
91+ logger .debug (f"labels of shape { labels .shape } , dtype { labels .dtype } " )
92+
93+ return labels
94+
95+
96+ def main ():
97+ client = ForeignToolClient (7878 , cb = run )
98+ client .receive_images ()
99+
100+ if __name__ == "__main__" :
101+ # init logging
102+ logging .root .setLevel (logging .DEBUG )
103+ stream_handler = logging .StreamHandler ()
104+ fmt = logging .Formatter (" [%(process)d|%(levelno)s] %(name)s::%(funcName)s: %(message)s" )
105+ stream_handler .setFormatter (fmt )
106+ logging .root .addHandler (stream_handler )
107+
108+ logger .debug ("Starting bimz.py" )
109+ main ()
0 commit comments