Skip to content

Commit dd3140e

Browse files
committed
support cp4, explicitly support omnipose
1 parent e386b10 commit dd3140e

1 file changed

Lines changed: 108 additions & 65 deletions

File tree

active_plugins/runcellpose.py

Lines changed: 108 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -50,27 +50,18 @@
5050
**RunCellpose** uses a pre-trained machine learning model (Cellpose) to detect cells or nuclei in an image.
5151
5252
This module is useful for automating simple segmentation tasks in CellProfiler.
53-
The module accepts greyscale input images and produces an object set. Probabilities can also be captured as an image.
53+
The module accepts greyscale input images and produces an object set.
54+
Probabilities can also be captured as an image.
5455
55-
Loading in a model will take slightly longer the first time you run it each session. When evaluating
56-
performance you may want to consider the time taken to predict subsequent images.
56+
Loading in a model will take slightly longer the first time you run it each session.
57+
When evaluating performance you may want to consider the time taken to predict subsequent images.
5758
58-
This module now also supports Ominpose. Omnipose builds on Cellpose, for the purpose of **RunCellpose** it adds 2 additional
59-
features: additional models; bact-omni and cyto2-omni which were trained using the Omnipose architechture, and bact
60-
and the mask reconstruction algorithm for Omnipose that was created to solve over-segemnation of large cells; useful for bacterial cells,
61-
but can be used for other arbitrary and anisotropic shapes. You can mix and match Omnipose models with Cellpose style masking or vice versa.
59+
This module is compatible with Omnipose, Cellpose 2, Cellpose 3, and Cellpose-SAM (4).
6260
63-
The module is compatible with Cellpose 1.0.2 >= 2.3.2. From the old version of the module the 'cells' model corresponds to 'cyto2' model.
61+
You can run this module using Cellpose installed to the same Python environment as CellProfiler.
62+
See our documentation at https://plugins.cellprofiler.org/runcellpose.html for more information on installation.
6463
65-
You can run this module using Cellpose installed to the same Python environment as CellProfiler. Alternatively, you can
66-
run this module using Cellpose in a Docker that the module will automatically download for you so you do not have to perform
67-
any installation yourself.
68-
69-
To install Cellpose in your Python environment:
70-
You'll want to run `pip install cellpose==2.3.2` on your CellProfiler Python environment to setup Cellpose. If you have an older version of Cellpose
71-
run 'python -m pip install --force-reinstall -v cellpose==2.3.2'.
72-
73-
To use Omnipose models, and mask reconstruction method you'll want to install Omnipose 'pip install omnipose' and Cellpose version 1.0.2 'pip install cellpose==1.0.2'.
64+
Alternatively, you can run this module using Cellpose in a Docker that the module will automatically download for you so you do not have to perform any installation yourself.
7465
7566
On the first time loading into CellProfiler, Cellpose will need to download some model files from the internet. This
7667
may take some time. If you want to use a GPU to run the model, you'll need a compatible version of PyTorch and a
@@ -102,7 +93,7 @@
10293

10394
DENOISER_NAMES = ['denoise_cyto3', 'deblur_cyto3', 'upsample_cyto3',
10495
'denoise_nuclei', 'deblur_nuclei', 'upsample_nuclei']
105-
# Only these models support size scaling
96+
# Only these models support size scaling for v2/v3
10697
SIZED_MODELS = {"cyto3", "cyto2", "cyto", "nuclei"}
10798

10899
def get_custom_model_vars(self):
@@ -155,9 +146,9 @@ def create_settings(self):
155146
)
156147
self.cellpose_version = Choice(
157148
text="Select Cellpose version",
158-
choices=['v2', 'v3', 'v4'],
149+
choices=['omnipose', 'v2', 'v3', 'v4'],
159150
value='v3',
160-
doc="Select the version of Cellpose you want to use. Note that v2 is compatible with either v1 or v2.")
151+
doc="Select the version of Cellpose you want to use.")
161152

162153
self.docker_image_v2 = Choice(
163154
text="Select Cellpose docker image",
@@ -578,6 +569,14 @@ def validate_module(self, pipeline):
578569
% model_path, self.model_file_name,
579570
)
580571

572+
def cleanup(self):
573+
from torch import cuda
574+
# Try to clear some GPU memory for other worker processes.
575+
try:
576+
cuda.empty_cache()
577+
except Exception as e:
578+
print(f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. {e}")
579+
581580
def run(self, workspace):
582581
x_name = self.x_name.value
583582
y_name = self.y_name.value
@@ -586,15 +585,20 @@ def run(self, workspace):
586585
dimensions = x.dimensions
587586
x_data = x.pixel_data
588587

588+
if self.cellpose_version.value == 'omnipose':
589+
self.mode = self.mode_v2
590+
self.denoise.value = False # Denoising only supported in v3
589591
if self.cellpose_version.value == 'v2':
590592
self.mode = self.mode_v2
591593
self.docker_image = self.docker_image_v2
594+
self.denoise.value = False # Denoising only supported in v3
592595
elif self.cellpose_version.value == 'v3':
593596
self.mode = self.mode_v3
594597
self.docker_image = self.docker_image_v3
595598
elif self.cellpose_version.value == 'v4':
596599
self.mode = self.mode_v4
597600
self.docker_image = self.docker_image_v4
601+
self.denoise.value = False # Denoising only supported in v3
598602

599603
if self.rescale.value:
600604
rescale_x = x_data.copy()
@@ -639,41 +643,45 @@ def run(self, workspace):
639643
from torch import cuda
640644
cuda.set_per_process_memory_fraction(self.manual_GPU_memory_share.value)
641645

646+
if self.cellpose_version.value == 'omnipose':
647+
assert int(self.cellpose_ver[0])<2, "Cellpose version selected in RunCellpose module doesn't match version in Python"
648+
assert float(self.cellpose_ver[0:3]) >= 0.6, "Cellpose v1/omnipose requires Cellpose >= 0.6"
649+
if self.mode.value != 'custom':
650+
model = models.Cellpose(model_type= self.mode.value,
651+
gpu=self.use_gpu.value)
652+
else:
653+
model_file, model_directory, model_path = get_custom_model_vars(self)
654+
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
655+
try:
656+
y_data, flows, *_ = model.eval(
657+
x_data,
658+
channels=channels,
659+
diameter=diam,
660+
net_avg=self.use_averaging.value,
661+
do_3D=self.do_3D.value,
662+
anisotropy=anisotropy,
663+
flow_threshold=self.flow_threshold.value,
664+
cellprob_threshold=self.cellprob_threshold.value,
665+
stitch_threshold=self.stitch_threshold.value, # is ignored if do_3D=True
666+
min_size=self.min_size.value,
667+
omni=self.omni.value,
668+
invert=self.invert.value,
669+
)
670+
except Exception as a:
671+
print(f"Unable to create masks. Check your module settings. {a}")
672+
finally:
673+
if self.use_gpu.value and model.torch:
674+
cleanup()
675+
642676
if self.cellpose_version.value == 'v2':
643-
assert int(self.cellpose_ver[0])<=2, "Cellpose version selected in RunCellpose module doesn't match version in Python"
644-
if float(self.cellpose_ver[0:3]) >= 0.6 and int(self.cellpose_ver[0])<2:
645-
if self.mode.value != 'custom':
646-
model = models.Cellpose(model_type= self.mode.value,
647-
gpu=self.use_gpu.value)
648-
else:
649-
model_file, model_directory, model_path = get_custom_model_vars(self)
650-
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
651-
677+
assert int(self.cellpose_ver[0])==2, "Cellpose version selected in RunCellpose module doesn't match version in Python"
678+
if self.mode.value != 'custom':
679+
model = models.CellposeModel(model_type= self.mode.value,
680+
gpu=self.use_gpu.value)
652681
else:
653-
if self.mode.value != 'custom':
654-
model = models.CellposeModel(model_type= self.mode.value,
655-
gpu=self.use_gpu.value)
656-
else:
657-
model_file, model_directory, model_path = get_custom_model_vars(self)
658-
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
659-
682+
model_file, model_directory, model_path = get_custom_model_vars(self)
683+
model = models.CellposeModel(pretrained_model=model_path, gpu=self.use_gpu.value)
660684
try:
661-
if float(self.cellpose_ver[0:3]) >= 0.7 and int(self.cellpose_ver[0])<2:
662-
y_data, flows, *_ = model.eval(
663-
x_data,
664-
channels=channels,
665-
diameter=diam,
666-
net_avg=self.use_averaging.value,
667-
do_3D=self.do_3D.value,
668-
anisotropy=anisotropy,
669-
flow_threshold=self.flow_threshold.value,
670-
cellprob_threshold=self.cellprob_threshold.value,
671-
stitch_threshold=self.stitch_threshold.value,
672-
min_size=self.min_size.value,
673-
omni=self.omni.value,
674-
invert=self.invert.value,
675-
)
676-
else:
677685
y_data, flows, *_ = model.eval(
678686
x_data,
679687
channels=channels,
@@ -683,19 +691,15 @@ def run(self, workspace):
683691
anisotropy=anisotropy,
684692
flow_threshold=self.flow_threshold.value,
685693
cellprob_threshold=self.cellprob_threshold.value,
686-
stitch_threshold=self.stitch_threshold.value,
694+
stitch_threshold=self.stitch_threshold.value, # is ignored if do_3D=True
687695
min_size=self.min_size.value,
688696
invert=self.invert.value,
689697
)
690698
except Exception as a:
691699
print(f"Unable to create masks. Check your module settings. {a}")
692700
finally:
693701
if self.use_gpu.value and model.torch:
694-
# Try to clear some GPU memory for other worker processes.
695-
try:
696-
cuda.empty_cache()
697-
except Exception as e:
698-
print(f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. {e}")
702+
cleanup()
699703

700704
elif self.cellpose_version.value == 'v3':
701705
assert int(self.cellpose_ver[0])==3, "Cellpose version selected in RunCellpose module doesn't match version in Python"
@@ -749,7 +753,7 @@ def run(self, workspace):
749753
anisotropy=anisotropy,
750754
flow_threshold=self.flow_threshold.value,
751755
cellprob_threshold=self.cellprob_threshold.value,
752-
stitch_threshold=self.stitch_threshold.value,
756+
stitch_threshold=self.stitch_threshold.value, # is ignored if do_3D=True
753757
min_size=self.min_size.value,
754758
invert=self.invert.value,
755759
)
@@ -762,15 +766,54 @@ def run(self, workspace):
762766
print(f"Unable to create masks. Check your module settings. {a}")
763767
finally:
764768
if self.use_gpu.value and model.torch:
765-
# Try to clear some GPU memory for other worker processes.
766-
try:
767-
cuda.empty_cache()
768-
except Exception as e:
769-
print(f"Unable to clear GPU memory. You may need to restart CellProfiler to change models. {e}")
769+
cleanup()
770770

771771
elif self.cellpose_version.value == 'v4':
772772
assert int(self.cellpose_ver[0])==4, "Cellpose version selected in RunCellpose module doesn't match version in Python"
773-
# TODO
773+
if self.mode.value == 'custom':
774+
model_file, model_directory, model_path = get_custom_model_vars(self)
775+
model_params = (self.mode.value, self.use_gpu.value)
776+
LOGGER.info(f"Loading new model: {self.mode.value}")
777+
self.current_model = models.CellposeModel(gpu=self.use_gpu.value)
778+
self.current_model_params = model_params
779+
780+
if self.specify_diameter.value:
781+
try:
782+
y_data, flows, *_ = self.current_model.eval(
783+
x_data,
784+
diameter=diam,
785+
do_3D=self.do_3D.value,
786+
anisotropy=anisotropy,
787+
flow_threshold=self.flow_threshold.value,
788+
cellprob_threshold=self.cellprob_threshold.value,
789+
stitch_threshold=self.stitch_threshold.value, # is ignored if do_3D=True
790+
min_size=self.min_size.value,
791+
invert=self.invert.value,
792+
)
793+
794+
except Exception as a:
795+
print(f"Unable to create masks. Check your module settings. {a}")
796+
finally:
797+
if self.use_gpu.value and model.torch:
798+
cleanup()
799+
else:
800+
try:
801+
y_data, flows, *_ = self.current_model.eval(
802+
x_data,
803+
do_3D=self.do_3D.value,
804+
anisotropy=anisotropy,
805+
flow_threshold=self.flow_threshold.value,
806+
cellprob_threshold=self.cellprob_threshold.value,
807+
stitch_threshold=self.stitch_threshold.value, # is ignored if do_3D=True
808+
min_size=self.min_size.value,
809+
invert=self.invert.value,
810+
)
811+
812+
except Exception as a:
813+
print(f"Unable to create masks. Check your module settings. {a}")
814+
finally:
815+
if self.use_gpu.value and model.torch:
816+
cleanup()
774817

775818
if self.remove_edge_masks:
776819
y_data = utils.remove_edge_masks(y_data)

0 commit comments

Comments
 (0)