Skip to content

feat(modconv): add conditional modulated convolutions for MRI reconstruction#309

Open
georgeyiasemis wants to merge 13 commits into
mainfrom
modulated-convolution
Open

feat(modconv): add conditional modulated convolutions for MRI reconstruction#309
georgeyiasemis wants to merge 13 commits into
mainfrom
modulated-convolution

Conversation

@georgeyiasemis

Copy link
Copy Markdown
Contributor

Description

This PR ports conditional learned reconstruction into DIRECT by introducing modulated convolutions — convolutional layers whose weights are adapted by small MLPs from acquisition metadata (acceleration factor, ACS fraction, field strength).

The implementation follows Conditional Learned Reconstruction for Medical Imaging (Moriakov et al., MIDL 2026; OpenReview PDF). A single trained backbone can be conditioned at inference time on the actual undersampling pattern instead of training separate models per acceleration.

What's New

Modulated convolution package (direct/nn/conv/modulated/)

  • ModConv2d / ModConv3d and transposed variants with modulation types: NONE, FEATURES, FULL, PARTIAL_IN, PARTIAL_OUT, SUM
  • auxiliary_data.py — registry-based auxiliary feature pipeline:
    • prepare_auxiliary_data(data, cfg) builds (batch, aux_in_features) conditioning vectors
    • register_auxiliary_feature() for custom conditioning channels
    • Default features: acceleration, center_fraction, field_strength
  • AdaIN2d / AdaIN3d adaptive instance normalization modules (direct/nn/adain/)

Auxiliary conditioning pipeline

  • CreateSamplingMask requests return_acceleration from BaseMaskFunc masks
  • Batch keys: acceleration, center_fraction
  • MRIModelEngine._attach_auxiliary_data() attaches auxiliary_data each iteration (supervised, SSL, JSSL)
  • FastMRI datasets expose field_strength (1.5 T / 3.0 T from filename)

Triangular acceleration sampling

  • linear_range: true in masking config samples acceleration from a triangular distribution (paper Section 4.3.2)
  • direct/utils/distributions.pytriangular_distribution()
  • BaseMaskFunc returns sampled acceleration metadata via return_acceleration

Model support

Modulated convolutions wired end-to-end through conv-based models:

Model Config field Engine passes auxiliary_data
vSHARP (2D/3D) conv_modulation Yes
VarNet conv_modulation Yes
XPDNet conv_modulation Yes
KIKINet conv_modulation Yes
LPD conv_modulation Yes
JointICNet conv_modulation Yes
IterDualNet conv_modulation Yes
Unet2d conv_modulation Yes
Conv2d, DIDN, MWCNN modulation (layer-level) Via parent model

Shared config fields: conv_modulation, aux_in_features, auxiliary_features, log_aux, fc_hidden_features, fc_activation, fc_groups, num_weights

Documentation & configs

  • projects/modulated_convolution/README.rst — tutorial with paper figures and section references
  • Example configs for vSHARP (knee/prostate) and VarNet (prostate) with triangular acceleration sampling

What's Changed

Masking (direct/common/subsample.py)

  • Unified _draw_acceleration_value() / _draw_acceleration_pair() for uniform_range and linear_range
  • Centralized return_acceleration in BaseMaskFunc.__call__
  • Bug fix: equispaced masks early-return when ACS already covers target acceleration (high <= 0 guard)
  • Bug fix: CartesianMagicMaskFunc integer center line counts no longer capped by min(1/accel, cf); ACS-only edge case handled
  • Bug fix: corrected len(center_fractions) != len(accelerations) validation

U-Net backbone

  • Unet2d / Unet3d / NormUnet* swap Conv2dModConv2d when conv_modulation != NONE
  • Modulated transposed convolutions in decoder blocks

Paper Mapping

Paper (Sec. / Eq.) DIRECT
Eq. 6 — modulated convolution ModConv2d with FEATURES type
Eq. 7 — z = log([R, 100·r_acs]) log_aux: true + acceleration / center_fraction batch keys
Sec. 4.3.2 — triangular R ∈ [4,16] linear_range: true in masking config
MOD S/M/L — MLP [32,8] / [32,16] / [32,32] fc_hidden_features in model config

Related

georgeyiasemis and others added 8 commits June 15, 2026 16:23
- Add ModConv2d/3d and ModConvTranspose2d/3d with multiple modulation types
- Add AdaIN2d/AdaIN3d adaptive instance normalization modules
- Integrate modulated convolutions into UNet2d/3d and VSharpNet/3D
- Add IntOrTuple type alias
- Fix SyntaxWarnings from invalid escape sequences in docstrings
- Fix DeprecationWarning in engine.py (numpy/torch interop)
- Fix UserWarning in gradloss_test.py (tensor from list of ndarrays)

Co-authored-by: Cursor <cursoragent@cursor.com>
Wire conditional modulation through VarNet, KIKINet, JointICNet, IterDualNet,
and LPD, plus Conv2d, DIDN, and MWCNN backbones. Add modulated conv unit tests,
publication references in docstrings, and modulation MLP layout fix. Apply
black/isort formatting across the codebase.

Co-authored-by: Cursor <cursoragent@cursor.com>
Introduce triangular acceleration sampling for training configs, return
sampled acceleration metadata from BaseMaskFunc, and fix integer center
line counts and ACS-only edge cases in MagicMaskFunc.

Co-authored-by: Cursor <cursoragent@cursor.com>
Reorganize modulated convolution layers into a subpackage and add a
central auxiliary-data registry used by conditional reconstruction models.

Co-authored-by: Cursor <cursoragent@cursor.com>
Expose sampled acceleration and center fraction in batches, attach
auxiliary tensors in MRIModelEngine, and register field strength in
FastMRI datasets for configurable conditioning features.

Co-authored-by: Cursor <cursoragent@cursor.com>
…odels

Enable auxiliary conditioning in VarNet, vSHARP, XPDNet, KIKINet, LPD,
JointICNet, IterDualNet, and Unet2d engines and align model configs
with the shared conv_modulation settings.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add a project README and knee/prostate training configs demonstrating
feature-based modulated convolutions with triangular acceleration.

Co-authored-by: Cursor <cursoragent@cursor.com>
Embed architecture and result figures from the MIDL paper, map config
fields to Eq. 7 and Section 3.1, and link to the OpenReview PDF.

Co-authored-by: Cursor <cursoragent@cursor.com>
@codecov

codecov Bot commented Jun 19, 2026

Copy link
Copy Markdown

Codecov Report

❌ Patch coverage is 79.73471% with 275 lines in your changes missing coverage. Please review.
✅ Project coverage is 84.71%. Comparing base (f83b7bb) to head (b4372b9).

Files with missing lines Patch % Lines
direct/nn/conv/modulated/modulated_conv.py 75.11% 106 Missing ⚠️
direct/nn/adain/adain.py 18.42% 62 Missing ⚠️
direct/nn/unet/unet_3d.py 71.87% 27 Missing ⚠️
direct/nn/unet/unet_2d.py 81.65% 20 Missing ⚠️
direct/nn/vsharp/vsharp.py 76.56% 15 Missing ⚠️
direct/common/subsample.py 84.21% 12 Missing ⚠️
direct/nn/crossdomain/multicoil.py 60.00% 4 Missing ⚠️
direct/nn/didn/didn.py 96.03% 4 Missing ⚠️
direct/nn/iterdualnet/iterdualnet.py 83.33% 3 Missing ⚠️
direct/nn/jointicnet/jointicnet.py 81.25% 3 Missing ⚠️
... and 12 more
Additional details and impacted files
@@            Coverage Diff             @@
##             main     #309      +/-   ##
==========================================
- Coverage   85.73%   84.71%   -1.03%     
==========================================
  Files         103      110       +7     
  Lines        9041    10118    +1077     
==========================================
+ Hits         7751     8571     +820     
- Misses       1290     1547     +257     

☔ View full report in Codecov by Harness.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.
  • 📦 JS Bundle Analysis: Save yourself from yourself by tracking and limiting bundle sizes in JS merges.

georgeyiasemis and others added 5 commits June 19, 2026 17:45
Add typed modulated-conv factories, normalize optional modulation types, and apply formatting fixes so type checking and tests pass on the modulated-convolution branch.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add .prospector.yml so Codacy skips conflicting pydocstyle rules (D203/D213 etc.) that Ruff does not enforce; fix unused loop variable in DIDN ReconBlock.

Co-authored-by: Cursor <cursoragent@cursor.com>
Disable pydocstyle and noisy pylint limits that conflict with numpydoc-style
docstrings and typical model __init__ signatures, and fix an unused loop variable.

Co-authored-by: Cursor <cursoragent@cursor.com>
Assign acceleration fields in the same branch where they are unpacked, and
wrap long DIDN bibliography lines to satisfy line-length checks.

Co-authored-by: Cursor <cursoragent@cursor.com>
Add factory, modulation-type, validation, UNet/DIDN, auxiliary-data, and
CreateSamplingMask acceleration tests to improve coverage on new modconv code.

Co-authored-by: Cursor <cursoragent@cursor.com>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant