Dear Authors,
First of all, thank you for your outstanding work on PerturbDiff! The concept and the methodology are truly inspiring.
I am currently trying to apply your pre-trained model (specifically the finetuned_tahoe100m_fixed.ckpt) to a novel, independent single-cell dataset (Colorectal Cancer data). My goal is to perform pure inference: simulating the knockout of a specific gene (e.g., TP53) on this unseen dataset, without having any actual ground-truth perturbation data or paired control cells.
My primary question is: Does the current theoretical framework and the pre-trained checkpoint support this kind of "zero-shot" simulation on a completely novel dataset?
While attempting to implement this inference pipeline, I ran into several engineering challenges. It seems that the current codebase is heavily optimized for training and reproducing benchmark metrics, which makes it quite difficult to decouple for pure, out-of-distribution inference. I would love to share a brief summary of the roadblocks I encountered, in hopes it might be helpful for future updates or an inference-only API:
Tight Coupling in DataLoader & Sampler: dataset_core.py and sampler.py strictly expect paired "control" and "perturbed" cells in order to compute metrics. Bypassing this to feed a plain .h5ad file of raw cells requires heavy modification of the internal dictionary mappings (e.g., grouped_num_cell, data_indices) to avoid KeyErrors and AssertionErrors.
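For reference, the workaround I ended up with looks roughly like the sketch below: it sidesteps the paired sampler entirely and batches raw cells myself. The random matrix stands in for adata.X loaded via anndata.read_h5ad, and all names here are mine, not the repo's:

```python
import numpy as np

# Hypothetical stand-in for the expression matrix of my CRC .h5ad
# (in practice: adata = anndata.read_h5ad("crc_cells.h5ad"); X = adata.X)
X = np.random.default_rng(0).random((256, 2000), dtype=np.float32)  # 256 cells x 2000 HVGs

def iter_batches(X, batch_size=64):
    """Plain batch iterator: no paired control/perturbed bookkeeping."""
    for start in range(0, X.shape[0], batch_size):
        yield X[start:start + batch_size]

batches = list(iter_batches(X))
```

Each batch is then handed straight to the sampling loop, without ever touching grouped_num_cell or data_indices.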
Hardcoded Dataset Names & Metadata: The codebase heavily relies on predefined dataset names (pbmc, tahoe100m, etc.). When feeding novel data, the framework automatically assigns names like dummy_plate_9, which later causes AssertionErrors in functions like get_short_dsname and embedder mapping.
Strict Checkpoint Loading & Embedder Dimensions: When adapting the model to my dataset's dimensions (e.g., 2000 HVGs), resizing the nn.Linear layers causes Unexpected key(s) in load_state_dict errors, because the checkpoint contains hardcoded dataset-specific embedders (e.g., x_embedder.pbmc.weight). I had to resort to strict=False, which silently skips the mismatched keys and leaves the new embedder randomly initialized.
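For context, this is roughly what I did. The module names mirror the checkpoint keys I saw (x_embedder.pbmc.*), but the model structure below is a minimal hypothetical stand-in, not your actual class:

```python
import torch
import torch.nn as nn

# Hypothetical minimal model with an embedder for my new dataset ("crc")
model = nn.ModuleDict({"x_embedder": nn.ModuleDict({"crc": nn.Linear(2000, 512)})})

# Simulated checkpoint fragment carrying embedders for a different dataset
ckpt = {"x_embedder.pbmc.weight": torch.zeros(512, 1877),
        "x_embedder.pbmc.bias": torch.zeros(512)}

# Option 1: strict=False skips mismatches (what I resorted to);
# load_state_dict reports what was skipped.
missing, unexpected = model.load_state_dict(ckpt, strict=False)

# Option 2 (cleaner): filter out keys for datasets you don't use before loading.
filtered = {k: v for k, v in ckpt.items() if ".pbmc." not in k}
```

Inspecting the returned missing/unexpected key lists at least makes the silent skipping visible.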
Shape Assertions in Diffusion Core: During the forward pass, the Transformer blocks often output a 3D tensor [Batch, 1, Dim], but diffusion_core.py strictly asserts x_t.shape == eps.shape (expecting [Batch, Dim]). This required manual squeeze/reshape operations at the model's output to prevent runtime crashes.
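The reshape I inserted is essentially the following (shapes are illustrative; numpy here stands in for the actual tensors):

```python
import numpy as np

x_t = np.zeros((32, 2000))     # [Batch, Dim], what the assertion in diffusion_core expects
eps = np.zeros((32, 1, 2000))  # [Batch, 1, Dim], what the Transformer blocks emit

# Collapse the singleton token axis before the shape assertion fires
if eps.ndim == 3 and eps.shape[1] == 1:
    eps = eps.squeeze(1)

assert x_t.shape == eps.shape
```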
I wanted to ask if you have any plans to release a simplified predict.py script for users who just want to input an .h5ad and a perturbation condition to get the simulated results.
Thank you again for your time, your amazing research, and for making this repository open-source!