Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .github/workflows/test-pip-installation.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ jobs:
- name: Install MobileSAM
run: python -m pip install git+https://github.com/ChaoningZhang/MobileSAM.git

# SAM2 (hvit_* backend) is not on PyPI; install it from GitHub to keep SAM2 coverage.
- name: Install SAM2
run: python -m pip install git+https://github.com/facebookresearch/sam2.git

# A standalone mask prompt is build-dependent with the PyPI torch wheels used here;
# MICRO_SAM_SKIP_STANDALONE_MASK_PROMPT skips that single assertion (see the test).
# Set it via GITHUB_ENV so it is exported to the headless-gui step's environment.
Expand Down
1 change: 1 addition & 0 deletions environment.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,4 @@ dependencies:
- zarr
- pip:
- git+https://github.com/ChaoningZhang/MobileSAM.git
- git+https://github.com/facebookresearch/sam2.git
114 changes: 58 additions & 56 deletions micro_sam/precompute_state.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Precompute image embeddings and automatic mask generator state for image data.
"""Precompute and cache the SAM2 image embeddings for image data.
"""

import os
Expand All @@ -22,7 +22,7 @@
from tqdm import tqdm

from . import util
from .v1.util import get_sam_model, precompute_image_embeddings, get_model_names
from .v1.util import precompute_image_embeddings
from .v1 import instance_segmentation


Expand Down Expand Up @@ -227,67 +227,82 @@ def precompute_state(
input_path: Union[os.PathLike, str],
output_path: Union[os.PathLike, str],
pattern: Optional[str] = None,
model_type: str = util._DEFAULT_MODEL,
model_type: str = "hvit_t",
checkpoint_path: Optional[Union[os.PathLike, str]] = None,
key: Optional[str] = None,
ndim: Optional[int] = None,
tile_shape: Optional[Tuple[int, int]] = None,
halo: Optional[Tuple[int, int]] = None,
precompute_amg_state: bool = False,
) -> None:
"""Precompute the image embeddings and other optional state for the input image(s).
"""Precompute and cache the SAM2 image embeddings for the input image(s).

The embeddings are saved in the same zarr format the annotators use, so the output can be loaded
directly by the `micro_sam.annotator` CLI and the napari GUI by passing the same path as the
embedding path (with a matching model and image).

Args:
input_path: The input image file(s). Can either be a single image file (e.g. tif or png),
a container file (e.g. hdf5 or zarr) or a folder with images files.
a container file (e.g. hdf5 or zarr) or a folder with image files.
In case of a container file the argument `key` must be given. In case of a folder
it can be given to provide a glob pattern to subselect files from the folder.
output_path: The output path where the embeddings and other state will be saved.
the `pattern` argument must be given to subselect files.
output_path: The output path where the embeddings will be saved. For a single input this is the path
to the embeddings zarr; for a folder of inputs this is the directory the embeddings are saved in.
pattern: Glob pattern to select files in a folder. The embeddings will be computed
for each of these files. To select all files in a folder pass "*".
model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default.
model_type: The SAM2 model to use. By default the `hvit_t` model is used.
checkpoint_path: Path to a checkpoint for a custom model.
key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr)
key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr)
or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case.
ndim: The dimensionality of the data. By default, computes it from the input data.
tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling.
halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling.
precompute_amg_state: Whether to precompute the state for automatic instance segmentation
in addition to the image embeddings.
ndim: The dimensionality of the data. By default, computed from the input data.
"""
predictor, state = get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True)
from micro_sam.v2.util import precompute_image_embeddings, SUPPORTED_MODELS
# Imported lazily to avoid a circular import ('_state' imports from this module).
from micro_sam.sam_annotator._state import _get_sam_model

if "decoder_state" in state:
decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"])
else:
decoder = None

# Check if we precompute the state for a single file or for a folder with image files.
if pattern is None:
_precompute_state_for_file(
predictor, input_path, output_path, key,
ndim=ndim, tile_shape=tile_shape, halo=halo,
precompute_amg_state=precompute_amg_state,
decoder=decoder, verbose=True,
if not model_type.startswith("h"):
raise ValueError(
f"Embedding precomputation only supports SAM2 models ({', '.join(SUPPORTED_MODELS)}), got '{model_type}'."
)

# Determine the input files and matching output embedding paths.
single = pattern is None
if single:
input_files, output_paths = [input_path], [output_path]
else:
input_files = glob(os.path.join(input_path, pattern))
_precompute_state_for_files(
predictor, input_files, output_path, key=key,
ndim=ndim, tile_shape=tile_shape, halo=halo,
precompute_amg_state=precompute_amg_state,
decoder=decoder,
input_files = sorted(glob(os.path.join(input_path, pattern)))
if len(input_files) == 0:
raise ValueError(f"Could not find any files matching the pattern '{pattern}' in '{input_path}'.")
os.makedirs(output_path, exist_ok=True)
output_paths = [os.path.join(output_path, os.path.basename(f)) for f in input_files]

predictor, current_ndim = None, None
for input_file, out_path in tqdm(
zip(input_files, output_paths), total=len(input_files), desc="Precompute embeddings", disable=single
):
image_data = input_file if isinstance(input_file, np.ndarray) else util.load_image_data(input_file, key)
file_ndim = image_data.ndim if ndim is None else ndim

# Build the SAM2 predictor for the data dimensionality (2d image vs. 3d video predictor).
# We reuse the annotator's model loader so the embeddings match what the GUI / CLI expect.
if predictor is None or file_ndim != current_ndim:
predictor, _ = _get_sam_model(
model_type=model_type, ndim=file_ndim, device=None,
checkpoint_path=checkpoint_path, decoder_path=None, use_cli=True,
)
current_ndim = file_ndim

save_path = str(Path(out_path).with_suffix(".zarr"))
precompute_image_embeddings(
predictor=predictor, input_=image_data, save_path=save_path, ndim=file_ndim, verbose=single
)


def main():
"""@private"""
import argparse
from micro_sam.v2.util import SUPPORTED_MODELS, _DEFAULT_MODEL

available_models = list(get_model_names())
available_models = ", ".join(available_models)
available_models = ", ".join(SUPPORTED_MODELS)

parser = argparse.ArgumentParser(description="Compute the embeddings for an image.")
parser = argparse.ArgumentParser(description="Precompute and cache the SAM2 image embeddings for image data.")
parser.add_argument(
"-i", "--input_path", required=True,
help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) "
Expand All @@ -305,36 +320,23 @@ def main():
"for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'."
)
parser.add_argument(
"-m", "--model_type", default=util._DEFAULT_MODEL,
help=f"The segment anything model that will be used, one of {available_models}."
)
parser.add_argument(
"-c", "--checkpoint", default=None,
help="Checkpoint from which the SAM model will be loaded."
"-m", "--model_type", default=_DEFAULT_MODEL,
help=f"The SAM2 model that will be used, one of {available_models}."
)
parser.add_argument(
"--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None
)
parser.add_argument(
"--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None
"-c", "--checkpoint", default=None, help="Checkpoint from which the SAM2 model will be loaded."
)
parser.add_argument(
"-n", "--ndim", type=int, default=None,
help="The number of spatial dimensions in the data. "
"Please specify this if your data has a channel dimension."
)
parser.add_argument(
"-p", "--precompute_amg_state", action="store_true",
help="Whether to precompute the state for automatic instance segmentation."
)

args = parser.parse_args()
precompute_state(
args.input_path, args.embedding_path,
model_type=args.model_type, checkpoint_path=args.checkpoint,
pattern=args.pattern, key=args.key,
tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim,
precompute_amg_state=args.precompute_amg_state,
pattern=args.pattern, key=args.key, ndim=args.ndim,
)


Expand Down
15 changes: 12 additions & 3 deletions micro_sam/sam_annotator/_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,10 @@ def _get_sam_model(model_type, ndim, device, checkpoint_path, decoder_path, use_
if model_type.startswith("h"): # i.e. SAM2 models.
from micro_sam.v2.util import get_sam2_model

# 'device=None' lets 'get_sam2_model' auto-detect the best device (cuda > mps > cpu);
# an explicit device (e.g. from the '--device' CLI argument) is forwarded and honored.
if ndim == 2: # Get the SAM2 model and prepare the image predictor.
model = get_sam2_model(model_type=model_type, input_type="images")
model = get_sam2_model(model_type=model_type, input_type="images", device=device)
# Prepare the SAM2 predictor.
from sam2.sam2_image_predictor import SAM2ImagePredictor
predictor = SAM2ImagePredictor(model)
Expand All @@ -54,7 +56,7 @@ def _get_sam_model(model_type, ndim, device, checkpoint_path, decoder_path, use_
predictor.model_type = model_type
predictor.model_name = model_type
elif ndim == 3: # Get SAM2 video predictor
predictor = get_sam2_model(model_type=model_type, input_type="videos")
predictor = get_sam2_model(model_type=model_type, input_type="videos", device=device)
else:
raise ValueError
state = {}
Expand Down Expand Up @@ -172,6 +174,12 @@ def initialize_predictor(
else:
_comp_embed_fn = precompute_image_embeddings

# For SAM2 volumes, load the embeddings lazily from the zarr so the high-resolution
# per-slice features stay on disk and are streamed one slice at a time during tracking.
# This keeps memory bounded for large volumes (materialising all slices costs
# ~200 MB/slice and OOMs); it only applies when the embeddings are cached on disk.
lazy_loading = self.is_sam2 and ndim == 3 and isinstance(save_path, str)

self.image_embeddings = _comp_embed_fn(
predictor=self.predictor,
input_=image_data,
Expand All @@ -180,6 +188,7 @@ def initialize_predictor(
tile_shape=tile_shape,
halo=halo,
verbose=True,
lazy_loading=lazy_loading,
pbar_init=pbar_init,
pbar_update=pbar_update,
)
Expand All @@ -189,7 +198,7 @@ def initialize_predictor(
if self.is_sam2 and ndim == 3:
from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D
self.interactive_segmenter = PromptableSegmentation3D(
predictor=self.predictor, volume=image_data, volume_embeddings=self.image_embeddings,
predictor=self.predictor, volume=image_data, volume_embeddings=self.image_embeddings, device=device,
)

# If we have an embedding path the data signature has already been computed,
Expand Down
10 changes: 8 additions & 2 deletions micro_sam/sam_annotator/_widgets.py
Original file line number Diff line number Diff line change
Expand Up @@ -1382,8 +1382,10 @@ def __init__(self, parent=None):
# Section 1: Image and Model.
section1_layout = QtWidgets.QHBoxLayout()
section1_layout.addLayout(self._create_image_section())
# Default to the natural-image SAM2 family. The widget encodes the default choice as
# 'vit_<size><suffix>', so 'vit_t_sam2' selects 'Natural Images (SAM2)' at the tiny size.
section1_layout.addLayout(
self._create_model_section()
self._create_model_section(default_model="vit_t_sam2")
) # Creates the model family widget section.
self.layout().addLayout(section1_layout)

Expand Down Expand Up @@ -2157,7 +2159,11 @@ def volumetric_segmentation_impl():
)

# Propagate the prompts throughout the volume and combine the propagated segmentations.
seg = state.interactive_segmenter.predict()
# Report per-slice progress so the user can see the propagation advancing.
pbar_signals.pbar_description.emit("Propagate in volume")
seg = state.interactive_segmenter.predict(
update_progress=lambda update: pbar_signals.pbar_update.emit(update)
)

else:
# Step 1: Segment all slices with prompts.
Expand Down
23 changes: 13 additions & 10 deletions micro_sam/sam_annotator/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ def toggle_label(prompts):

def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True):

available_models = list(get_model_names())
available_models = ", ".join(available_models)
from micro_sam.v2.util import SUPPORTED_MODELS
available_models = ", ".join(SUPPORTED_MODELS + list(get_model_names()))

parser = argparse.ArgumentParser(description=description)

Expand Down Expand Up @@ -78,7 +78,7 @@ def _initialize_parser(description, with_segmentation_result=True, with_instance
)

parser.add_argument(
"-m", "--model_type", default=util._DEFAULT_MODEL,
"-m", "--model_type", default="hvit_t",
help=f"The segment anything model that will be used, one of {available_models}."
)
parser.add_argument(
Expand Down Expand Up @@ -687,14 +687,17 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic
"histopathology": "Histopathology",
}

model_family = "Natural Images (SAM)" # If no suffix patterns match, stick to 'Natural Images (SAM)' family.
for k, v in supported_dropdown_maps.items():
if model_type.endswith(k):
model_family = v
break
if model_type.startswith("hvit"): # SAM2 models, eg. 'hvit_t', are all natural-image models.
model_family = "Natural Images (SAM2)"
else:
model_family = "Natural Images (SAM)" # If no suffix patterns match, stick to 'Natural Images (SAM)' family.
for k, v in supported_dropdown_maps.items():
if model_type.endswith(k):
model_family = v
break

index = widget.model_family_dropdown.findText(model_family)
if index > 0:
if index >= 0:
widget.model_family_dropdown.setCurrentIndex(index)

# Update the index for model size, eg. 'base', 'tiny', etc.
Expand All @@ -703,7 +706,7 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic
model_size = size_map[model_type[size_idx]]

index = widget.model_size_dropdown.findText(model_size)
if index > 0:
if index >= 0:
widget.model_size_dropdown.setCurrentIndex(index)

if save_path is not None and isinstance(save_path, str):
Expand Down
9 changes: 8 additions & 1 deletion micro_sam/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def _load_checkpoint(checkpoint_path):
#


def _to_image(image):
def _to_image(image, normalization="minmax"):
input_ = image
ndim = input_.ndim
n_channels = 1 if ndim == 2 else input_.shape[-1]
Expand All @@ -225,6 +225,13 @@ def _to_image(image):
)
assert input_.ndim == 3 and input_.shape[-1] == 3

if normalization == "percentile":
# Percentile-normalize each channel to [0, 1], matching the SAM2 3D frame normalization.
# Clip since percentile normalization maps the 2nd / 98th percentiles to 0 / 1 and overshoots.
from torch_em.transform.raw import normalize_percentile
input_ = normalize_percentile(input_.astype("float32"), lower=2.0, upper=98.0, axis=(0, 1))
return np.clip(np.array(input_), 0.0, 1.0)

# Normalize the input per channel and bring it to uint8.
input_ = input_.astype("float32")
input_ -= input_.min(axis=(0, 1))[None, None]
Expand Down
Loading
Loading