From 44b73874420065e041f7a1ba14195e3b46d387d9 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 16 Jun 2026 15:36:16 +0200 Subject: [PATCH 1/5] Enable precompute embeddings CLI using SAM2 models --- micro_sam/precompute_state.py | 112 +++++++++++----------- micro_sam/sam_annotator/util.py | 6 +- micro_sam/util.py | 9 +- micro_sam/v2/models/_video_predictor.py | 18 ++-- micro_sam/v2/prompt_based_segmentation.py | 19 +++- micro_sam/v2/util.py | 2 +- 6 files changed, 93 insertions(+), 73 deletions(-) diff --git a/micro_sam/precompute_state.py b/micro_sam/precompute_state.py index 04f21fc1..b048170b 100644 --- a/micro_sam/precompute_state.py +++ b/micro_sam/precompute_state.py @@ -1,4 +1,4 @@ -"""Precompute image embeddings and automatic mask generator state for image data. +"""Precompute and cache the SAM2 image embeddings for image data. """ import os @@ -225,67 +225,82 @@ def precompute_state( input_path: Union[os.PathLike, str], output_path: Union[os.PathLike, str], pattern: Optional[str] = None, - model_type: str = util._DEFAULT_MODEL, + model_type: str = "hvit_t", checkpoint_path: Optional[Union[os.PathLike, str]] = None, key: Optional[str] = None, ndim: Optional[int] = None, - tile_shape: Optional[Tuple[int, int]] = None, - halo: Optional[Tuple[int, int]] = None, - precompute_amg_state: bool = False, ) -> None: - """Precompute the image embeddings and other optional state for the input image(s). + """Precompute and cache the SAM2 image embeddings for the input image(s). + + The embeddings are saved in the same zarr format the annotators use, so the output can be loaded + directly by the `micro_sam.annotator` CLI and the napari GUI by passing the same path as the + embedding path (with a matching model and image). Args: input_path: The input image file(s). Can either be a single image file (e.g. tif or png), - a container file (e.g. hdf5 or zarr) or a folder with images files. + a container file (e.g. hdf5 or zarr) or a folder with image files. In case of a container file the argument `key` must be given. In case of a folder - it can be given to provide a glob pattern to subselect files from the folder. - output_path: The output path where the embeddings and other state will be saved. + the `pattern` argument must be given to subselect files. + output_path: The output path where the embeddings will be saved. For a single input this is the path + to the embeddings zarr; for a folder of inputs this is the directory the embeddings are saved in. pattern: Glob pattern to select files in a folder. The embeddings will be computed for each of these files. To select all files in a folder pass "*". - model_type: The Segment Anything model to use. Will use the `vit_b_lm` model by default. + model_type: The SAM2 model to use. By default the `hvit_t` model is used. checkpoint_path: Path to a checkpoint for a custom model. - key: The key to the input file. This is needed for contaner files (e.g. hdf5 or zarr) + key: The key to the input file. This is needed for container files (e.g. hdf5 or zarr) or to load several images as 3d volume. Provide a glob pattern, e.g. "*.tif", for this case. - ndim: The dimensionality of the data. By default, computes it from the input data. - tile_shape: Shape of tiles for tiled prediction. By default prediction is run without tiling. - halo: Overlap of the tiles for tiled prediction. By default prediction is run without tiling. - precompute_amg_state: Whether to precompute the state for automatic instance segmentation - in addition to the image embeddings. + ndim: The dimensionality of the data. By default, computed from the input data. """ - predictor, state = util.get_sam_model(model_type=model_type, checkpoint_path=checkpoint_path, return_state=True) + from micro_sam.v2.util import precompute_image_embeddings, SUPPORTED_MODELS + # Imported lazily to avoid a circular import ('_state' imports from this module). + from micro_sam.sam_annotator._state import _get_sam_model - if "decoder_state" in state: - decoder = instance_segmentation.get_decoder(predictor.model.image_encoder, state["decoder_state"]) - else: - decoder = None - - # Check if we precompute the state for a single file or for a folder with image files. - if pattern is None: - _precompute_state_for_file( - predictor, input_path, output_path, key, - ndim=ndim, tile_shape=tile_shape, halo=halo, - precompute_amg_state=precompute_amg_state, - decoder=decoder, verbose=True, + if not model_type.startswith("h"): + raise ValueError( + f"Embedding precomputation only supports SAM2 models ({', '.join(SUPPORTED_MODELS)}), got '{model_type}'." ) + + # Determine the input files and matching output embedding paths. + single = pattern is None + if single: + input_files, output_paths = [input_path], [output_path] else: - input_files = glob(os.path.join(input_path, pattern)) - _precompute_state_for_files( - predictor, input_files, output_path, key=key, - ndim=ndim, tile_shape=tile_shape, halo=halo, - precompute_amg_state=precompute_amg_state, - decoder=decoder, + input_files = sorted(glob(os.path.join(input_path, pattern))) + if len(input_files) == 0: + raise ValueError(f"Could not find any files matching the pattern '{pattern}' in '{input_path}'.") + os.makedirs(output_path, exist_ok=True) + output_paths = [os.path.join(output_path, os.path.basename(f)) for f in input_files] + + predictor, current_ndim = None, None + for input_file, out_path in tqdm( + zip(input_files, output_paths), total=len(input_files), desc="Precompute embeddings", disable=single + ): + image_data = input_file if isinstance(input_file, np.ndarray) else util.load_image_data(input_file, key) + file_ndim = image_data.ndim if ndim is None else ndim + + # Build the SAM2 predictor for the data dimensionality (2d image vs. 3d video predictor). + # We reuse the annotator's model loader so the embeddings match what the GUI / CLI expect. + if predictor is None or file_ndim != current_ndim: + predictor, _ = _get_sam_model( + model_type=model_type, ndim=file_ndim, device=None, + checkpoint_path=checkpoint_path, decoder_path=None, use_cli=True, + ) + current_ndim = file_ndim + + save_path = str(Path(out_path).with_suffix(".zarr")) + precompute_image_embeddings( + predictor=predictor, input_=image_data, save_path=save_path, ndim=file_ndim, verbose=single ) def main(): """@private""" import argparse + from micro_sam.v2.util import SUPPORTED_MODELS, _DEFAULT_MODEL - available_models = list(util.get_model_names()) - available_models = ", ".join(available_models) + available_models = ", ".join(SUPPORTED_MODELS) - parser = argparse.ArgumentParser(description="Compute the embeddings for an image.") + parser = argparse.ArgumentParser(description="Precompute and cache the SAM2 image embeddings for image data.") parser.add_argument( "-i", "--input_path", required=True, help="The filepath to the image data. Supports all data types that can be read by imageio (e.g. tif, png, ...) " @@ -303,36 +318,23 @@ def main(): "for an image stack it is a wild-card, e.g. '*.png' and for mrc it is 'data'." ) parser.add_argument( - "-m", "--model_type", default=util._DEFAULT_MODEL, - help=f"The segment anything model that will be used, one of {available_models}." - ) - parser.add_argument( - "-c", "--checkpoint", default=None, - help="Checkpoint from which the SAM model will be loaded." + "-m", "--model_type", default=_DEFAULT_MODEL, + help=f"The SAM2 model that will be used, one of {available_models}." ) parser.add_argument( - "--tile_shape", nargs="+", type=int, help="The tile shape for using tiled prediction.", default=None - ) - parser.add_argument( - "--halo", nargs="+", type=int, help="The halo for using tiled prediction.", default=None + "-c", "--checkpoint", default=None, help="Checkpoint from which the SAM2 model will be loaded." ) parser.add_argument( "-n", "--ndim", type=int, default=None, help="The number of spatial dimensions in the data. " "Please specify this if your data has a channel dimension." ) - parser.add_argument( - "-p", "--precompute_amg_state", action="store_true", - help="Whether to precompute the state for automatic instance segmentation." - ) args = parser.parse_args() precompute_state( args.input_path, args.embedding_path, model_type=args.model_type, checkpoint_path=args.checkpoint, - pattern=args.pattern, key=args.key, - tile_shape=args.tile_shape, halo=args.halo, ndim=args.ndim, - precompute_amg_state=args.precompute_amg_state, + pattern=args.pattern, key=args.key, ndim=args.ndim, ) diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index 092cf5dd..264205cb 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -40,8 +40,8 @@ def toggle_label(prompts): def _initialize_parser(description, with_segmentation_result=True, with_instance_segmentation=True): - available_models = list(util.get_model_names()) - available_models = ", ".join(available_models) + from micro_sam.v2.util import SUPPORTED_MODELS + available_models = ", ".join(SUPPORTED_MODELS + list(util.get_model_names())) parser = argparse.ArgumentParser(description=description) @@ -76,7 +76,7 @@ def _initialize_parser(description, with_segmentation_result=True, with_instance ) parser.add_argument( - "-m", "--model_type", default=util._DEFAULT_MODEL, + "-m", "--model_type", default="hvit_t", help=f"The segment anything model that will be used, one of {available_models}." ) parser.add_argument( diff --git a/micro_sam/util.py b/micro_sam/util.py index 010d442b..aed83e32 100644 --- a/micro_sam/util.py +++ b/micro_sam/util.py @@ -615,7 +615,7 @@ def get_model_names() -> Iterable: # -def _to_image(image): +def _to_image(image, normalization="minmax"): input_ = image ndim = input_.ndim n_channels = 1 if ndim == 2 else input_.shape[-1] @@ -640,6 +640,13 @@ def _to_image(image): ) assert input_.ndim == 3 and input_.shape[-1] == 3 + if normalization == "percentile": + # Percentile-normalize each channel to [0, 1], matching the SAM2 3D frame normalization. + # Clip since percentile normalization maps the 2nd / 98th percentiles to 0 / 1 and overshoots. + from torch_em.transform.raw import normalize_percentile + input_ = normalize_percentile(input_.astype("float32"), lower=2.0, upper=98.0, axis=(0, 1)) + return np.clip(np.array(input_), 0.0, 1.0) + # Normalize the input per channel and bring it to uint8. input_ = input_.astype("float32") input_ -= input_.min(axis=(0, 1))[None, None] diff --git a/micro_sam/v2/models/_video_predictor.py b/micro_sam/v2/models/_video_predictor.py index fb426a38..2f3029b6 100644 --- a/micro_sam/v2/models/_video_predictor.py +++ b/micro_sam/v2/models/_video_predictor.py @@ -18,9 +18,9 @@ def _load_img_as_tensor(img_path, image_size): """Load a single frame as a float32 [0, 1] tensor of shape (3, image_size, image_size). For file-path inputs: PIL loads the image, resizes via plain square stretch (JPEG convention). - For numpy inputs: accepts uint8 [0, 255] or float32 [0, 1]; resizes using aspect-ratio - preserving scale to image_size on the longest side, then zero-pads to a square - matching - ConvertToSam2VideoBatch._to_sam2_size used during training. + For numpy inputs: percentile-normalizes any dtype to [0, 1] (2nd / 98th percentile per channel); + resizes using aspect-ratio preserving scale to image_size on the longest side, then zero-pads to a + square - matching ConvertToSam2VideoBatch._to_sam2_size used during training. Returns: img: (3, image_size, image_size) float32 tensor, ImageNet-normalised by the caller. @@ -37,12 +37,12 @@ def _load_img_as_tensor(img_path, image_size): img_np = img_path img_np = np.stack([img_np] * 3, axis=-1) if img_np.ndim == 2 else img_np - if img_np.dtype == np.uint8: - img_np = img_np.astype(np.float32) / 255.0 - elif img_np.dtype in (np.float32, np.float64): - img_np = img_np.astype(np.float32) - else: - raise RuntimeError(f"Unsupported image dtype: {img_np.dtype}") + # Percentile-normalize each channel to [0, 1], so any input dtype (e.g. uint16 microscopy data) + # is mapped to the range SAM2's ImageNet normalization expects. Clip since percentile + # normalization maps the 2nd / 98th percentiles to 0 / 1 and overshoots outside that range. + from torch_em.transform.raw import normalize_percentile + img_np = normalize_percentile(img_np.astype(np.float32), lower=2.0, upper=98.0, axis=(0, 1)) + img_np = np.clip(img_np, 0.0, 1.0) # Aspect-ratio preserving scale + zero-pad, matching _to_sam2_size in training. # video_height/video_width are set to max(H, W) so SAM2's coordinate normalization diff --git a/micro_sam/v2/prompt_based_segmentation.py b/micro_sam/v2/prompt_based_segmentation.py index 88955a6d..e9c0b9a1 100644 --- a/micro_sam/v2/prompt_based_segmentation.py +++ b/micro_sam/v2/prompt_based_segmentation.py @@ -5,6 +5,16 @@ from micro_sam.prompt_based_segmentation import _process_box +def _crop_to_original_shape(mask, shape): + """Crop a SAM2 video-predictor mask back to the original slice shape. + + The video predictor pads non-square frames to a square of side max(H, W) (padding appended at the + bottom/right) and returns masks at that padded size. The image content occupies the top-left + [0:H, 0:W] region, so cropping recovers the original (H, W) mask. For square volumes this is a no-op. + """ + return mask[:shape[0], :shape[1]] + + def promptable_segmentation_2d( predictor, image: Optional[np.ndarray] = None, @@ -202,7 +212,7 @@ def promptable_segmentation_3d( for slice_idx in video_segments.keys(): per_slice_seg = np.zeros(volume.shape[-2:]) for _instance_idx, _instance_mask in video_segments[slice_idx].items(): - per_slice_seg[_instance_mask.squeeze()] = _instance_idx + per_slice_seg[_crop_to_original_shape(_instance_mask.squeeze(), volume.shape[-2:])] = _instance_idx segmentation.append(per_slice_seg) segmentation = (np.stack(segmentation) > 0).astype("uint64") @@ -528,8 +538,8 @@ def segment_slice( mask_logits = out_mask_logits[0] # Get first object seg = (mask_logits.squeeze() > 0.0).cpu().numpy() - # Ensure correct output type - seg = seg.astype("uint32") + # Crop back to the original slice shape (the video predictor pads non-square frames). + seg = _crop_to_original_shape(seg, self.volume.shape[-2:]).astype("uint32") finally: # Reset the state to clear this object's prompts @@ -547,7 +557,8 @@ def predict(self): for slice_idx in video_segments.keys(): per_slice_seg = np.zeros(self.volume.shape[-2:]) for _instance_idx, _instance_mask in video_segments[slice_idx].items(): - per_slice_seg[_instance_mask.squeeze()] = _instance_idx + mask = _crop_to_original_shape(_instance_mask.squeeze(), self.volume.shape[-2:]) + per_slice_seg[mask] = _instance_idx segmentation.append(per_slice_seg) segmentation = np.stack(segmentation).astype("uint64") diff --git a/micro_sam/v2/util.py b/micro_sam/v2/util.py index d5dbc3df..836e1b77 100644 --- a/micro_sam/v2/util.py +++ b/micro_sam/v2/util.py @@ -210,7 +210,7 @@ def _compute_2d(input_, predictor, f, save_path, pbar_init, pbar_update): predictor.reset_predictor() from micro_sam.util import _to_image - predictor.set_image(_to_image(input_)) + predictor.set_image(_to_image(input_, normalization="percentile")) features = predictor.get_image_embedding().cpu().numpy() high_res_features = predictor._features.get("high_res_feats") original_size = predictor._orig_hw From d91cc3e0f3772663f192f341ed50e977cc7b5382 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Tue, 16 Jun 2026 15:44:10 +0200 Subject: [PATCH 2/5] Fix precompute embeddings CLI --- test/test_cli.py | 25 ++++++++----------------- 1 file changed, 8 insertions(+), 17 deletions(-) diff --git a/test/test_cli.py b/test/test_cli.py index 05ad34dd..453ec4a6 100644 --- a/test/test_cli.py +++ b/test/test_cli.py @@ -46,43 +46,34 @@ def test_precompute_embeddings(self): image_data = binary_blobs(512).astype("uint8") * 255 imageio.imwrite(im_path, image_data) - # Test precomputation with a single image. + # Test precomputation with a single (2d) image. emb_path1 = os.path.join(self.tmp_folder, "embedddings1.zarr") run([ - "micro_sam.precompute_embeddings", "-i", im_path, "-e", emb_path1, - "-m", self.model_type, "--precompute_amg_state" + "micro_sam.precompute_embeddings", "-i", im_path, "-e", emb_path1, "-m", "hvit_t", ]) self.assertTrue(os.path.exists(emb_path1)) f = zarr.open(emb_path1, mode="r") self.assertIn("features", f) + self.assertIn("high_res_feats", f) - ais_path = os.path.join(emb_path1, "is_state.h5") - self.assertTrue(os.path.exists(ais_path)) - - # Test precomputation with image stack. + # Test precomputation with an image stack (loaded as a 3d volume). emb_path2 = os.path.join(self.tmp_folder, "embedddings2.zarr") run([ - "micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path2, - "-m", self.model_type, "-k", "*.tif", "--precompute_amg_state" + "micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path2, "-m", "hvit_t", "-k", "*.tif", ]) self.assertTrue(os.path.exists(emb_path2)) f = zarr.open(emb_path2, mode="r") self.assertIn("features", f) self.assertEqual(f["features"].shape[0], n_images) - ais_path = os.path.join(emb_path2, "is_state.h5") - self.assertTrue(os.path.exists(ais_path)) - - # Test precomputation with pattern to process multiple image. + # Test precomputation with a pattern to process multiple images. emb_path3 = os.path.join(self.tmp_folder, "embedddings3") run([ - "micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path3, - "-m", self.model_type, "--pattern", "*.tif", "--precompute_amg_state" + "micro_sam.precompute_embeddings", "-i", self.tmp_folder, "-e", emb_path3, "-m", "hvit_t", + "--pattern", "*.tif", ]) for i in range(n_images): self.assertTrue(os.path.exists(os.path.join(emb_path3, f"image-{i}.zarr"))) - ais_path = os.path.join(emb_path3, f"image-{i}.zarr", "is_state.h5") - self.assertTrue(os.path.exists(ais_path)) @pytest.mark.skipif(platform.system() == "Windows", reason="CLI test is not working on windows.") def test_automatic_segmentation(self): From 88d35451c1dc50c2d7fb090d2025464b50341c07 Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 17 Jun 2026 13:20:12 +0200 Subject: [PATCH 3/5] Add sam2 dependency for CIs and more SAM2 embeddings tests --- environment.yaml | 1 + test/test_util.py | 108 ++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 109 insertions(+) diff --git a/environment.yaml b/environment.yaml index fb94c049..b6df8c4e 100644 --- a/environment.yaml +++ b/environment.yaml @@ -33,3 +33,4 @@ dependencies: - zarr - pip: - git+https://github.com/ChaoningZhang/MobileSAM.git + - git+https://github.com/facebookresearch/sam2.git diff --git a/test/test_util.py b/test/test_util.py index df64fc5f..36958808 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -277,5 +277,113 @@ def test_get_device(self): self.assertEqual(device.type, "cpu") +try: + import sam2 # noqa + SAM2_SUPPORT = True +except ImportError: + SAM2_SUPPORT = False + + +@unittest.skipUnless(SAM2_SUPPORT, "Requires the sam2 package.") +class TestSAM2Util(unittest.TestCase): + model_type = "hvit_t" + tmp_folder = "tmp-files-sam2" + + def setUp(self): + os.makedirs(self.tmp_folder, exist_ok=True) + + def tearDown(self): + rmtree(self.tmp_folder) + + def _get_predictor(self, ndim): + # Build the SAM2 predictor exactly as the precompute CLI / annotator do. + from micro_sam.sam_annotator._state import _get_sam_model + predictor, _ = _get_sam_model( + model_type=self.model_type, ndim=ndim, device="cpu", + checkpoint_path=None, decoder_path=None, use_cli=True, + ) + return predictor + + def _check_predictor_initialization_2d(self, predictor, embeddings): + from micro_sam.v2.util import set_precomputed + predictor.reset_predictor() + set_precomputed(predictor, embeddings) + self.assertTrue(predictor._is_image_set) + self.assertIsNotNone(predictor._features) + self.assertIsNotNone(predictor._orig_hw) + predictor.reset_predictor() + + def test_precompute_image_embeddings_2d(self): + from micro_sam.v2.util import precompute_image_embeddings + + predictor = self._get_predictor(ndim=2) + input_ = np.random.rand(512, 512).astype("float32") + + # Compute the image embeddings without save path. + embeddings = precompute_image_embeddings(predictor, input_, ndim=2) + for key in ("features", "high_res_feats", "input_size", "original_size"): + self.assertIn(key, embeddings) + self.assertEqual(embeddings["features"].ndim, 4) + self.assertEqual(embeddings["features"].shape, (1, 256, 64, 64)) + self._check_predictor_initialization_2d(predictor, embeddings) + + # Compute the image embeddings with save path. + save_path = os.path.join(self.tmp_folder, "embed.zarr") + embeddings = precompute_image_embeddings(predictor, input_, save_path=save_path, ndim=2) + self._check_predictor_initialization_2d(predictor, embeddings) + + # Check the contents of the saved embeddings. + self.assertTrue(os.path.exists(save_path)) + f = zarr.open(save_path, mode="r") + self.assertIn("features", f) + self.assertIn("high_res_feats", f) + self.assertEqual(f["features"].shape, (1, 256, 64, 64)) + # The signature is written so the GUI / CLI can validate a reload. + self.assertEqual(f.attrs["model_name"], self.model_type) + self.assertIn("data_signature", f.attrs) + + # Check that everything still works when we load the image embeddings from file. + embeddings = precompute_image_embeddings(predictor, input_, save_path=save_path, ndim=2) + self.assertEqual(embeddings["features"].shape, (1, 256, 64, 64)) + self._check_predictor_initialization_2d(predictor, embeddings) + + def test_precompute_image_embeddings_3d(self): + from micro_sam.v2.util import precompute_image_embeddings, set_precomputed + + predictor = self._get_predictor(ndim=3) + input_ = np.random.rand(2, 256, 256).astype("float32") + + def check_slices(embeddings): + for i in range(input_.shape[0]): + _, inference_state = set_precomputed(predictor, embeddings, i=i, input_=input_) + self.assertIn("cached_features", inference_state) + + # Compute the image embeddings without save path. + # Note: the in-memory form stacks the per-slice features along z (4 dims), + # while the saved form keeps an explicit singleton dim (5 dims). + embeddings = precompute_image_embeddings(predictor, input_, ndim=3) + for key in ("features", "pos_enc", "fpn", "input_size", "original_size"): + self.assertIn(key, embeddings) + self.assertEqual(embeddings["features"].shape[0], input_.shape[0]) + + # Compute the image embeddings with save path. + save_path = os.path.join(self.tmp_folder, "embed_3d.zarr") + embeddings = precompute_image_embeddings(predictor, input_, save_path=save_path, ndim=3) + check_slices(embeddings) + + # Check the contents of the saved embeddings. + self.assertTrue(os.path.exists(save_path)) + f = zarr.open(save_path, mode="r") + self.assertIn("features", f) + self.assertIn("pos_enc", f) + self.assertIn("fpn", f) + self.assertEqual(f["features"].shape, (2, 1, 256, 64, 64)) + self.assertEqual(f.attrs["model_name"], self.model_type) + + # Check that everything still works when we load the image embeddings from file. + embeddings = precompute_image_embeddings(predictor, input_, save_path=save_path, ndim=3) + check_slices(embeddings) + + if __name__ == "__main__": unittest.main() From 5805751f0d00858220420cf251933b0af67ad4fa Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 17 Jun 2026 14:54:57 +0200 Subject: [PATCH 4/5] Fix SAM2 3d annotator for propagation feedback and better device handling --- micro_sam/sam_annotator/_state.py | 15 ++++- micro_sam/sam_annotator/_widgets.py | 10 +++- micro_sam/sam_annotator/util.py | 17 +++--- micro_sam/v2/models/_video_predictor.py | 67 ++++++++++++++++------- micro_sam/v2/prompt_based_segmentation.py | 27 +++++++-- 5 files changed, 100 insertions(+), 36 deletions(-) diff --git a/micro_sam/sam_annotator/_state.py b/micro_sam/sam_annotator/_state.py index 76066890..cb06e773 100644 --- a/micro_sam/sam_annotator/_state.py +++ b/micro_sam/sam_annotator/_state.py @@ -44,8 +44,10 @@ def _get_sam_model(model_type, ndim, device, checkpoint_path, decoder_path, use_ if model_type.startswith("h"): # i.e. SAM2 models. from micro_sam.v2.util import get_sam2_model + # 'device=None' lets 'get_sam2_model' auto-detect the best device (cuda > mps > cpu); + # an explicit device (e.g. from the '--device' CLI argument) is forwarded and honored. if ndim == 2: # Get the SAM2 model and prepare the image predictor. - model = get_sam2_model(model_type=model_type, input_type="images") + model = get_sam2_model(model_type=model_type, input_type="images", device=device) # Prepare the SAM2 predictor. from sam2.sam2_image_predictor import SAM2ImagePredictor predictor = SAM2ImagePredictor(model) @@ -54,7 +56,7 @@ def _get_sam_model(model_type, ndim, device, checkpoint_path, decoder_path, use_ predictor.model_type = model_type predictor.model_name = model_type elif ndim == 3: # Get SAM2 video predictor - predictor = get_sam2_model(model_type=model_type, input_type="videos") + predictor = get_sam2_model(model_type=model_type, input_type="videos", device=device) else: raise ValueError state = {} @@ -172,6 +174,12 @@ def initialize_predictor( else: _comp_embed_fn = precompute_image_embeddings + # For SAM2 volumes, load the embeddings lazily from the zarr so the high-resolution + # per-slice features stay on disk and are streamed one slice at a time during tracking. + # This keeps memory bounded for large volumes (materialising all slices costs + # ~200 MB/slice and OOMs); it only applies when the embeddings are cached on disk. + lazy_loading = self.is_sam2 and ndim == 3 and isinstance(save_path, str) + self.image_embeddings = _comp_embed_fn( predictor=self.predictor, input_=image_data, @@ -180,6 +188,7 @@ def initialize_predictor( tile_shape=tile_shape, halo=halo, verbose=True, + lazy_loading=lazy_loading, pbar_init=pbar_init, pbar_update=pbar_update, ) @@ -189,7 +198,7 @@ def initialize_predictor( if self.is_sam2 and ndim == 3: from micro_sam.v2.prompt_based_segmentation import PromptableSegmentation3D self.interactive_segmenter = PromptableSegmentation3D( - predictor=self.predictor, volume=image_data, volume_embeddings=self.image_embeddings, + predictor=self.predictor, volume=image_data, volume_embeddings=self.image_embeddings, device=device, ) # If we have an embedding path the data signature has already been computed, diff --git a/micro_sam/sam_annotator/_widgets.py b/micro_sam/sam_annotator/_widgets.py index 3f5e132c..a4153476 100644 --- a/micro_sam/sam_annotator/_widgets.py +++ b/micro_sam/sam_annotator/_widgets.py @@ -1382,8 +1382,10 @@ def __init__(self, parent=None): # Section 1: Image and Model. section1_layout = QtWidgets.QHBoxLayout() section1_layout.addLayout(self._create_image_section()) + # Default to the natural-image SAM2 family. The widget encodes the default choice as + # 'vit_', so 'vit_t_sam2' selects 'Natural Images (SAM2)' at the tiny size. section1_layout.addLayout( - self._create_model_section() + self._create_model_section(default_model="vit_t_sam2") ) # Creates the model family widget section. self.layout().addLayout(section1_layout) @@ -2157,7 +2159,11 @@ def volumetric_segmentation_impl(): ) # Propagate the prompts throughout the volume and combine the propagated segmentations. - seg = state.interactive_segmenter.predict() + # Report per-slice progress so the user can see the propagation advancing. + pbar_signals.pbar_description.emit("Propagate in volume") + seg = state.interactive_segmenter.predict( + update_progress=lambda update: pbar_signals.pbar_update.emit(update) + ) else: # Step 1: Segment all slices with prompts. diff --git a/micro_sam/sam_annotator/util.py b/micro_sam/sam_annotator/util.py index d7fc0dd2..025e7964 100644 --- a/micro_sam/sam_annotator/util.py +++ b/micro_sam/sam_annotator/util.py @@ -687,14 +687,17 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic "histopathology": "Histopathology", } - model_family = "Natural Images (SAM)" # If no suffix patterns match, stick to 'Natural Images (SAM)' family. - for k, v in supported_dropdown_maps.items(): - if model_type.endswith(k): - model_family = v - break + if model_type.startswith("hvit"): # SAM2 models, eg. 'hvit_t', are all natural-image models. + model_family = "Natural Images (SAM2)" + else: + model_family = "Natural Images (SAM)" # If no suffix patterns match, stick to 'Natural Images (SAM)' family. + for k, v in supported_dropdown_maps.items(): + if model_type.endswith(k): + model_family = v + break index = widget.model_family_dropdown.findText(model_family) - if index > 0: + if index >= 0: widget.model_family_dropdown.setCurrentIndex(index) # Update the index for model size, eg. 'base', 'tiny', etc. @@ -703,7 +706,7 @@ def _sync_embedding_widget(widget, model_type, save_path, checkpoint_path, devic model_size = size_map[model_type[size_idx]] index = widget.model_size_dropdown.findText(model_size) - if index > 0: + if index >= 0: widget.model_size_dropdown.setCurrentIndex(index) if save_path is not None and isinstance(save_path, str): diff --git a/micro_sam/v2/models/_video_predictor.py b/micro_sam/v2/models/_video_predictor.py index 2f3029b6..75d7500e 100644 --- a/micro_sam/v2/models/_video_predictor.py +++ b/micro_sam/v2/models/_video_predictor.py @@ -233,28 +233,57 @@ def init_state( inference_state["cached_features"] = {} # Create an empty 'cached_features' dictionary to warm up. return inference_state - # Visual features on all frames (slices) for faster interactions. - feats = volume_embeddings["features"] - pos_list = volume_embeddings["pos_enc"] - fpn_list = volume_embeddings["fpn"] - - # Embeddings have been provided. We just need to pass stuff to 'inference_state' as expected. - running_features = {} - for frame_idx in range(inference_state["num_frames"]): - image = images[frame_idx].to(device).float().unsqueeze(0) - - vision_features = torch.as_tensor(np.asarray(feats[frame_idx]), device=device).float() - vision_pos_enc = [torch.as_tensor(np.asarray(t[frame_idx]), device=device).float() for t in pos_list] - backbone_fpn = [torch.as_tensor(np.asarray(t[frame_idx]), device=device).float() for t in fpn_list] - backbone_out = { - "vision_features": vision_features, "vision_pos_enc": vision_pos_enc, "backbone_fpn": backbone_fpn, - } - running_features[frame_idx] = (image, backbone_out) - - inference_state["cached_features"] = running_features + # Store the precomputed embeddings and load each frame's features lazily during tracking + # (see '_get_image_feature'). Materialising every slice's high-resolution features up-front + # costs ~200 MB/slice and OOMs for large volumes; the lazy single-frame cache keeps memory + # bounded. When the embeddings are backed by a zarr on disk (lazy_loading=True), only one + # slice is held in memory at a time. + inference_state["precomputed_embeddings"] = volume_embeddings + inference_state["cached_features"] = {} return inference_state + def _get_image_feature(self, inference_state, frame_idx, batch_size): + """Compute or look up the image features for a frame. + + Overrides 'SAM2VideoPredictor._get_image_feature' to source per-frame features from the + precomputed embeddings (if stored on the inference state) instead of running the image + encoder. A single-frame cache bounds memory for large volumes. Falls back to the parent + behaviour (run the encoder) when no precomputed embeddings are available. + """ + image, backbone_out = inference_state["cached_features"].get(frame_idx, (None, None)) + if backbone_out is None: + embeddings = inference_state.get("precomputed_embeddings") + if embeddings is None: + return super()._get_image_feature(inference_state, frame_idx, batch_size) + + device = inference_state["device"] + image = inference_state["images"][frame_idx].to(device).float().unsqueeze(0) + vision_pos_enc = [ + torch.as_tensor(np.asarray(t[frame_idx]), device=device).float() for t in embeddings["pos_enc"] + ] + backbone_fpn = [ + torch.as_tensor(np.asarray(t[frame_idx]), device=device).float() for t in embeddings["fpn"] + ] + backbone_out = {"backbone_fpn": backbone_fpn, "vision_pos_enc": vision_pos_enc} + # Keep only the most recent frame's features, matching upstream SAM2's single-frame cache. + inference_state["cached_features"] = {frame_idx: (image, backbone_out)} + + # Expand the features to the number of objects being tracked (mirrors upstream SAM2). + expanded_image = image.expand(batch_size, -1, -1, -1) + expanded_backbone_out = { + "backbone_fpn": backbone_out["backbone_fpn"].copy(), + "vision_pos_enc": backbone_out["vision_pos_enc"].copy(), + } + for i, feat in enumerate(expanded_backbone_out["backbone_fpn"]): + expanded_backbone_out["backbone_fpn"][i] = feat.expand(batch_size, -1, -1, -1) + for i, pos in enumerate(expanded_backbone_out["vision_pos_enc"]): + expanded_backbone_out["vision_pos_enc"][i] = pos.expand(batch_size, -1, -1, -1) + + features = self._prepare_backbone_features(expanded_backbone_out) + features = (expanded_image,) + features + return features + def _build_sam2_video_predictor( config_file, diff --git a/micro_sam/v2/prompt_based_segmentation.py b/micro_sam/v2/prompt_based_segmentation.py index de290f29..d33bd5c1 100644 --- a/micro_sam/v2/prompt_based_segmentation.py +++ b/micro_sam/v2/prompt_based_segmentation.py @@ -226,10 +226,18 @@ def promptable_segmentation_3d( class PromptableSegmentation3D: """Promptable segmentation class for volumetric data. """ - def __init__(self, predictor, volume, volume_embeddings): + def __init__( + self, predictor, volume, volume_embeddings, device=None, + offload_video_to_cpu=True, offload_state_to_cpu=True, + ): self.predictor = predictor self.volume = volume self.volume_embeddings = volume_embeddings + # 'device=None' uses the predictor's auto-detected device. Offloading the frames and tracking + # state to CPU keeps GPU memory bounded for large volumes (a no-op when already on CPU). + self.device = device + self.offload_video_to_cpu = offload_video_to_cpu + self.offload_state_to_cpu = offload_state_to_cpu if self.volume.ndim != 3: raise AssertionError(f"The dimensionality of the volume should be 3, got '{self.volume.ndim}'") @@ -249,7 +257,10 @@ def __init__(self, predictor, volume, volume_embeddings): def init_predictor(self): # Initialize the inference state. - self.inference_state = self.predictor.init_state(volume=self.volume, volume_embeddings=self.volume_embeddings) + self.inference_state = self.predictor.init_state( + volume=self.volume, volume_embeddings=self.volume_embeddings, device=self.device, + offload_video_to_cpu=self.offload_video_to_cpu, offload_state_to_cpu=self.offload_state_to_cpu, + ) def reset_predictor(self): # Reset the state after finishing the segmentation round. @@ -466,13 +477,17 @@ def add_mask_prompts( ): raise NotImplementedError - def propagate_prompts(self): + def propagate_prompts(self, update_progress=None): # First, we propagate the masklets throughout the frames using the input prompts in selected frames. + # 'update_progress' is an optional callback that is called with the number of newly processed + # frames, so callers (e.g. the napari annotator) can report propagation progress to the user. forward_video_segments = {} for out_frame_idx, out_obj_ids, out_mask_logits in self.predictor.propagate_in_video(self.inference_state): forward_video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } + if update_progress is not None: + update_progress(1) # Next, we do the propagation reverse in time. reverse_video_segments = {} @@ -483,6 +498,8 @@ def propagate_prompts(self): reverse_video_segments[out_frame_idx] = { out_obj_id: (out_mask_logits[i] > 0.0).cpu().numpy() for i, out_obj_id in enumerate(out_obj_ids) } + if update_progress is not None: + update_progress(1) # NOTE: The order is reversed to stitch the reverse propagation with forward. reverse_video_segments = dict(reversed(list(reverse_video_segments.items()))) @@ -548,9 +565,9 @@ def segment_slice( return seg - def predict(self): + def predict(self, update_progress=None): # First, we propagate prompts. - video_segments = self.propagate_prompts() + video_segments = self.propagate_prompts(update_progress=update_progress) # Next, let's merge the segmented objects per frame back together as instances per slice. segmentation = [] From a0c659a9b21e075b28ae892a0241fda72508ecda Mon Sep 17 00:00:00 2001 From: Anwai Archit Date: Wed, 17 Jun 2026 14:57:47 +0200 Subject: [PATCH 5/5] Add SAM2 dep to pip yaml workflow to fix broken tests --- .github/workflows/test-pip-installation.yaml | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/.github/workflows/test-pip-installation.yaml b/.github/workflows/test-pip-installation.yaml index a3917e1b..e1702f5c 100644 --- a/.github/workflows/test-pip-installation.yaml +++ b/.github/workflows/test-pip-installation.yaml @@ -42,6 +42,10 @@ jobs: - name: Install MobileSAM run: python -m pip install git+https://github.com/ChaoningZhang/MobileSAM.git + # SAM2 (hvit_* backend) is not on PyPI; install it from GitHub to keep SAM2 coverage. + - name: Install SAM2 + run: python -m pip install git+https://github.com/facebookresearch/sam2.git + # A standalone mask prompt is build-dependent with the PyPI torch wheels used here; # MICRO_SAM_SKIP_STANDALONE_MASK_PROMPT skips that single assertion (see the test). # Set it via GITHUB_ENV so it is exported to the headless-gui step's environment.