
elkins-lab/diff-integrator

⚙️ Diff-Integrator: The Integrative Refinement Engine


Diff-Integrator is a JAX-accelerated optimization engine designed for integrative structural biology. It acts as the "orchestrator" that combines differentiable observables from diff-biophys into multi-objective loss functions.

By cleanly separating the optimization loop from the underlying biophysical kernels, diff-integrator enables joint refinement of protein structures against several kinds of experimental data at once (e.g., SAXS, NMR chemical shifts, NMR RDCs).


🎯 Vision

The goal of diff-integrator is to provide a seamless optax-based refinement pipeline that handles:

  1. Multi-Objective Optimization: Easily weight and combine multiple experimental constraints via JointLoss.
  2. Abstract Parameterization: Optimize arbitrary parameter spaces—from Cartesian coordinates to internal backbone angles (phi/psi)—via user-defined kinematics_fn mappers.
  3. Dynamic Fitting: Refit nuisance parameters (such as Saupe alignment tensors or SAXS scaling factors) analytically during gradient descent.
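For a single multiplicative scale factor c, the analytic refit in item 3 has a closed form: minimizing Σ((c·I_calc − I_exp)/σ)² gives c = Σ(I_calc·I_exp/σ²) / Σ(I_calc²/σ²). The sketch below assumes this plain weighted least-squares form (no constant background term) is representative. `fit_scale`, `scaled_chi2`, and the Guinier-shaped toy profile (standing in for a Debye calculation) are illustrative only and are not SAXSLoss's code.

```python
import jax
import jax.numpy as jnp

def fit_scale(i_calc, i_exp, sigma):
    """Closed-form c minimizing sum(((c * i_calc - i_exp) / sigma) ** 2)."""
    w = 1.0 / sigma ** 2
    return jnp.sum(w * i_calc * i_exp) / jnp.sum(w * i_calc ** 2)

def scaled_chi2(i_calc, i_exp, sigma):
    """Mean chi^2 per point after refitting the nuisance scale on the current prediction."""
    c = fit_scale(i_calc, i_exp, sigma)
    return jnp.mean(((c * i_calc - i_exp) / sigma) ** 2), c

# Toy data: a Guinier-like curve (Rg = 20 Å) and an "experiment" that is 2.5x larger.
q = jnp.linspace(0.01, 0.3, 50)
i_calc = jnp.exp(-(q * 20.0) ** 2 / 3.0)
i_exp = 2.5 * i_calc
sigma = jnp.full_like(q, 0.01)

chi2, c = scaled_chi2(i_calc, i_exp, sigma)

# Gradient of the refit chi^2 with respect to the predicted profile.
grad_i_calc = jax.grad(lambda ic: scaled_chi2(ic, i_exp, sigma)[0])(i_calc)
```

Because c is the exact minimizer, ∂χ²/∂c = 0 at the fit, so differentiating through `fit_scale` yields the same gradient with respect to I_calc as treating c as a constant. This is one reason an analytic refit can sit inside a gradient loop without distorting the descent direction.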

📚 Interactive Tutorials

Experience Diff-Integrator directly in your browser with our Colab tutorials:

  • 📉 Results Dashboard (graduate / researcher): Visualizes the loss descent, Q-factors, chemical shift accuracy, NeRF drift, and a Cartesian vs. NeRF comparison across all four benchmarks (2KZV, GmR58A, HR2876B NeRF, HR2876B Cartesian). (Open in Colab)
  • 🧪 Refinement Concepts (student / researcher): Educational notebook explaining NMR observables, NeRF coordinate parameterization, RDC tensor degeneracy, and the fixed-tensor protocol. (Open in Colab)
  • ⚠️ Method Limitations (reviewer / scientist): Honest quantitative assessment of the current method's failure modes: NeRF geometric drift, RDC overfitting on PEG data, and degrees-of-freedom imbalance. (Open in Colab)

⚡ Core Components

IntegrativeRefiner

The core optimization engine. Built on optax (defaulting to the Adam optimizer), it manages the training loop, gradient calculation, and loss tracking.

  • Abstract Support: Accepts arbitrary parameter sets (init_params) and maps them to Cartesian space using an optional kinematics_fn.
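IntegrativeRefiner's internals are not shown here, so the following is only a sketch of the pattern described above: parameters → kinematics_fn → coordinates → loss, with gradients flowing back through the mapper, plus best-checkpoint tracking and patience-based early stopping. The 2D chain builder, `refine`, and all hyperparameters are hypothetical, and plain gradient descent stands in for optax's Adam to keep the example dependency-light.

```python
import jax
import jax.numpy as jnp

def kinematics_fn(angles):
    """Hypothetical mapper: turning angles -> 2D coordinates of a unit-bond chain."""
    headings = jnp.cumsum(angles)                      # absolute direction of each bond
    bonds = jnp.stack([jnp.cos(headings), jnp.sin(headings)], axis=-1)
    return jnp.concatenate([jnp.zeros((1, 2)), jnp.cumsum(bonds, axis=0)], axis=0)

def loss_fn(angles, target):
    coords = kinematics_fn(angles)                     # params -> Cartesian space
    end_to_end = jnp.linalg.norm(coords[-1] - coords[0])
    return (end_to_end - target) ** 2

step_fn = jax.jit(jax.value_and_grad(loss_fn))         # gradient w.r.t. the angles

def refine(init_params, target, epochs=500, lr=0.01, patience=50, min_delta=1e-8):
    params, best_params, best_loss, wait = init_params, init_params, float("inf"), 0
    history = []
    for _ in range(epochs):
        value, grads = step_fn(params, target)
        history.append(float(value))
        if float(value) < best_loss - min_delta:       # improvement: checkpoint these params
            best_loss, best_params, wait = float(value), params, 0
        else:
            wait += 1
            if wait >= patience:                       # plateau: stop early
                break
        params = params - lr * grads                   # plain gradient-descent step
    return best_params, best_loss, history

init = jnp.full((8,), 0.6)                             # tightly curled chain of 8 unit bonds
best_params, best_loss, history = refine(init, target=5.0)
final_d = jnp.linalg.norm(kinematics_fn(best_params)[-1])
```

The optimizer never sees Cartesian coordinates directly; swapping in a different `kinematics_fn` changes the parameter space without touching the loss.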

JointLoss

A container for combining multiple LossTerm objects. It computes the total weighted loss by evaluating each term on the current parameters and generated coordinates.

LossTerm (Interface)

An abstract base class for defining differentiable constraints. All terms implement __call__(self, params, coords).
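The JointLoss/LossTerm contract can be illustrated without the package: any callable on (params, coords) that returns a scalar can be weighted and summed, and the total stays differentiable. The classes below (`LossTermSketch`, `HarmonicAnchor`, `CentroidRestraint`, `WeightedSum`) are hypothetical stand-ins written for this example, not diff-integrator's own classes.

```python
from abc import ABC, abstractmethod

import jax
import jax.numpy as jnp

class LossTermSketch(ABC):
    """Hypothetical stand-in for LossTerm: a differentiable callable on (params, coords)."""

    @abstractmethod
    def __call__(self, params, coords): ...

class HarmonicAnchor(LossTermSketch):
    """Mean squared per-atom displacement from a reference structure."""

    def __init__(self, target):
        self.target = target

    def __call__(self, params, coords):
        return jnp.mean(jnp.sum((coords - self.target) ** 2, axis=-1))

class CentroidRestraint(LossTermSketch):
    """Toy term pulling the centroid toward the origin (ignores params)."""

    def __call__(self, params, coords):
        return jnp.sum(jnp.mean(coords, axis=0) ** 2)

class WeightedSum:
    """Hypothetical JointLoss-like container: total = sum_i w_i * term_i(params, coords)."""

    def __init__(self, terms_and_weights):
        self.terms_and_weights = list(terms_and_weights)

    def __call__(self, params, coords):
        return sum(w * term(params, coords) for term, w in self.terms_and_weights)

coords = jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
joint = WeightedSum([
    (HarmonicAnchor(jnp.zeros_like(coords)), 5.0),
    (CentroidRestraint(), 2.0),
])
total = joint(None, coords)                       # 5 * 1.0 + 2 * 0.0
shifted_total = joint(None, coords + jnp.array([1.0, 0.0, 0.0]))
grads = jax.grad(lambda x: joint(None, x))(coords)
```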

Included Observables

  • GeometryLoss: Implements basic structural priors, including harmonic restraints to a target Cartesian structure.
  • SAXSLoss: Dynamically scales and fits theoretical SAXS profiles against experimental data using Debye kernels.
  • FixedTensorRDCLoss: Fixed-tensor RDC loss that holds the Saupe alignment tensor frozen during backpropagation (using jax.lax.stop_gradient) and re-fits it every update_interval epochs. Includes cv_fraction cross-validation split and suggested_weight() auto-scaling by overdetermination ratio.
  • CAShiftLoss: Wraps the ring-current and secondary-structure shift predictor to compute Cα chemical shift RMSDs from backbone torsion angles.
  • CartesianCAShiftLoss: Cartesian-space variant that extracts φ/ψ on-the-fly from raw coordinates via compute_phi_psi, enabling chemical shift refinement without a NeRF builder.
  • BondLengthPenalty / BondAnglePenalty: Harmonic restraints on backbone bond lengths and angles to Engh & Huber ideal values. Used in Cartesian refinement to replace the hard geometric constraints of the NeRF builder.
  • RamachandranLoss: Sequence-aware Ramachandran prior with residue-specific Gaussian wells. Handles the GLY ε-basin (φ > 0) and the PRO ring constraint correctly.
  • NOELoss: Flat-bottomed harmonic NOE distance restraints (standard XPLOR/CNS convention). Accepts atom_pairs, d_upper, and optional d_lower arrays; reports count_violations() and rms_violation() diagnostics. Use make_noe_restraints() to map (res_id, atom_name) pairs directly to atom indices.
  • ChiralityPenalty: Half-harmonic Cα chirality guard for Cartesian refinement. Prevents silent L→D inversion during gradient descent using the signed scalar triple product chi = dot(cross(N−CA, C−CA), C_prev−CA). Use make_backbone_chirality(n_residues) for the standard N–CA–C layout.
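The NOELoss and ChiralityPenalty entries describe two simple penalty shapes. The sketch below reimplements both in plain JAX under stated assumptions: NOE violations are squared and averaged over restraints; the rms violation is taken over all restraints; backbone atoms are ordered N, CA, C per residue (assumed to match the layout make_backbone_chirality targets); and `expected_sign`, which selects the sign of chi treated as L, is a hypothetical parameter because the text does not state the sign convention. All function names are illustrative.

```python
import jax.numpy as jnp

def flat_bottom_noe(coords, atom_pairs, d_upper, d_lower=None):
    """Zero inside [d_lower, d_upper]; harmonic (squared violation) outside."""
    i, j = atom_pairs[:, 0], atom_pairs[:, 1]
    d = jnp.linalg.norm(coords[i] - coords[j], axis=-1)
    viol = jnp.maximum(d - d_upper, 0.0)                  # upper-bound violations
    if d_lower is not None:
        viol = viol + jnp.maximum(d_lower - d, 0.0)       # lower-bound violations
    return jnp.mean(viol ** 2), viol

def backbone_chi(coords):
    """chi = dot(cross(N - CA, C - CA), C_prev - CA) for residues 2..n.

    Assumes coords are ordered N, CA, C for each residue (shape (3 * n_res, 3)).
    The first residue has no C_prev and is skipped.
    """
    bb = coords.reshape(-1, 3, 3)                         # (n_res, [N, CA, C], xyz)
    n, ca, c = bb[1:, 0], bb[1:, 1], bb[1:, 2]
    c_prev = bb[:-1, 2]
    return jnp.sum(jnp.cross(n - ca, c - ca) * (c_prev - ca), axis=-1)

def half_harmonic(chi, expected_sign=1.0, k=1.0):
    """Penalize only chi on the wrong side of zero; expected_sign is an assumption."""
    return k * jnp.sum(jnp.maximum(-expected_sign * chi, 0.0) ** 2)

# NOE example: one satisfied restraint (1 Å <= 2 Å) and one violated by 1 Å.
noe_coords = jnp.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [3.0, 0.0, 0.0]])
pairs = jnp.array([[0, 1], [0, 2]])
noe_loss, viol = flat_bottom_noe(noe_coords, pairs, d_upper=jnp.array([2.0, 2.0]))
n_violations = int(jnp.sum(viol > 1e-6))
rms_violation = float(jnp.sqrt(jnp.mean(viol ** 2)))

# Chirality example: a two-residue backbone and its mirror image.
backbone = jnp.array([
    [0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 1.0, 0.0],   # residue 1: N, CA, C
    [2.0, 1.0, 0.0], [2.0, 2.0, 0.0], [2.0, 2.0, 1.0],   # residue 2: N, CA, C
])
mirror = backbone * jnp.array([-1.0, 1.0, 1.0])          # reflection flips chi's sign
chi, chi_mirror = backbone_chi(backbone), backbone_chi(mirror)
```

Because the triple product changes sign under reflection, a one-sided (half-harmonic) penalty leaves correctly handed centers untouched and only pushes inverted ones back across zero.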

🚀 Usage Example

from diff_integrator.loss import JointLoss
from diff_integrator.optimizer import EarlyStopping, IntegrativeRefiner
from diff_integrator.schedules import ExponentialDecaySchedule
from diff_integrator.terms.geometry import GeometryLoss
from diff_integrator.terms.nmr import FixedTensorRDCLoss, make_rdc_cv_refinement_fns

# 1. Build loss terms
# (starting_coords, rdc_res_ids, exp_rdcs and struct_res_ids are your own inputs)
geom_term = GeometryLoss(target_coords=starting_coords)

# FixedTensorRDCLoss holds the Saupe tensor fixed during backprop,
# preventing the degeneracy exploit that drives Q→0 unphysically.
rdc_loss_fn, q_eval, tensor_fn, val_q_fn, n_train, n_val = make_rdc_cv_refinement_fns(
    rdc_res_ids, exp_rdcs, struct_res_ids, cv_fraction=0.2
)
rdc_term = FixedTensorRDCLoss(
    rdc_loss_fn, tensor_fn, update_interval=50,
    n_rdcs=n_train, val_q_eval_fn=val_q_fn
)
# Auto-scale weight by overdetermination ratio (ideal = 10×)
rdc_weight = rdc_term.suggested_weight(base_weight=1.0)

# 2. Combine into a joint loss
joint_loss = JointLoss([
    (geom_term, 5.0),
    (rdc_term, rdc_weight),
])

# 3. Annealed geometry weight: strong early, relaxed late
anchor_schedule = ExponentialDecaySchedule(
    initial_weight=10.0, final_weight=0.1, decay_epochs=100
)

# 4. Refine — result is a RefinementResult dataclass
refiner = IntegrativeRefiner(loss_fn=joint_loss)
result = refiner.run(
    init_params=starting_coords,
    epochs=2000,
    learning_rate=0.005,
    weight_schedules={0: anchor_schedule},      # anneal geometry anchor
    early_stopping=EarlyStopping(              # stop when RDC Q plateaus
        term_index=1, patience=50, min_delta=1e-4
    ),
)
print(f"Best checkpoint: epoch {result.best_epoch}")
print(f"Stopped early:   {result.stopped_early} ({result.early_stopping_triggered_by})")
refined_coords = result.best_params
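FixedTensorRDCLoss is described as freezing the Saupe tensor with `jax.lax.stop_gradient` and refitting it every `update_interval` epochs. Below is a minimal, self-contained sketch of that protocol under simplifying assumptions: each RDC is modeled as D = uᵀSu with the dipolar prefactor folded into the 5 independent elements of a symmetric, traceless S; the tensor is fit by linear least squares; and Q uses one common definition, rms(D_calc − D_exp) / rms(D_exp). The names (`design_matrix`, `fit_saupe`, `rdc_loss`, `q_factor`) are hypothetical, this is not the `make_rdc_cv_refinement_fns` API, and cross-validation is omitted.

```python
import jax
import jax.numpy as jnp

def design_matrix(u):
    """Rows map (Sxx, Syy, Sxy, Sxz, Syz) to u^T S u for symmetric, traceless S."""
    x, y, z = u[:, 0], u[:, 1], u[:, 2]
    return jnp.stack([x**2 - z**2, y**2 - z**2, 2*x*y, 2*x*z, 2*y*z], axis=-1)

def unit_vectors(coords, pairs):
    v = coords[pairs[:, 1]] - coords[pairs[:, 0]]
    return v / jnp.linalg.norm(v, axis=-1, keepdims=True)

def fit_saupe(coords, pairs, d_exp):
    """Linear least-squares fit of the 5 Saupe parameters (prefactor folded in)."""
    s, *_ = jnp.linalg.lstsq(design_matrix(unit_vectors(coords, pairs)), d_exp)
    return s

def predict_rdcs(coords, pairs, s):
    return design_matrix(unit_vectors(coords, pairs)) @ s

def rdc_loss(coords, pairs, d_exp, s):
    # Treat the tensor as a constant: no gradient flows into it (or into any fit
    # that produced it, if s were computed from coords inside this function).
    s = jax.lax.stop_gradient(s)
    return jnp.mean((predict_rdcs(coords, pairs, s) - d_exp) ** 2)

def q_factor(d_calc, d_exp):
    """One common definition: rms(D_calc - D_exp) / rms(D_exp)."""
    return jnp.sqrt(jnp.mean((d_calc - d_exp) ** 2) / jnp.mean(d_exp ** 2))

# Synthetic check: RDCs generated from a known tensor are recovered by the fit.
coords = jax.random.normal(jax.random.PRNGKey(0), (20, 3))
pairs = jnp.arange(20).reshape(10, 2)                  # 10 bond vectors
s_true = jnp.array([3.0, -1.0, 2.0, 0.5, -1.5])
d_exp = predict_rdcs(coords, pairs, s_true)
s_fit = fit_saupe(coords, pairs, d_exp)

# Fixed-tensor protocol: refit every update_interval steps, descend in between.
update_interval = 50
grad_fn = jax.jit(jax.grad(rdc_loss))                  # gradient w.r.t. coords only
x = coords + 0.1 * jax.random.normal(jax.random.PRNGKey(1), coords.shape)
for epoch in range(200):
    if epoch % update_interval == 0:
        s_cached = fit_saupe(x, pairs, d_exp)          # nuisance refit, outside autodiff
    x = x - 0.01 * grad_fn(x, pairs, d_exp, s_cached)
```

Between refits the tensor cannot move to absorb errors, so the coordinates alone must account for any improvement in the fit.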

Multi-phase refinement with freeze_term

Freeze experimental terms for an initial geometry-only phase, then thaw them for full joint refinement — without rebuilding the loss:

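# Continues the example above (geom_term, rdc_term, rdc_weight, JointLoss, IntegrativeRefiner).
# bond_pen / angle_pen are BondLengthPenalty / BondAnglePenalty instances
# from diff_integrator.terms.bond_geometry (construction not shown).
from diff_integrator.terms.chirality import make_backbone_chirality
from diff_integrator.terms.noe import make_noe_restraints

chirality_pen = make_backbone_chirality(n_residues)
noe_term      = make_noe_restraints(noe_observations, struct_res_ids)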

joint_loss = JointLoss([
    (geom_term,      5.0),   # 0 — position anchor
    (bond_pen,      50.0),   # 1 — bond lengths
    (angle_pen,     10.0),   # 2 — bond angles
    (chirality_pen, 20.0),   # 3 — chirality guard (always on)
    (rdc_term,  rdc_weight), # 4 — RDC
    (noe_term,       5.0),   # 5 — NOE
])

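# Bind a refiner to the new loss; the refiner created earlier still points at the old JointLoss.
refiner = IntegrativeRefiner(loss_fn=joint_loss)

# Phase 1: geometry + chirality only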
joint_loss.freeze_term(4)   # freeze RDC
joint_loss.freeze_term(5)   # freeze NOE
result_p1 = refiner.run(init_params=starting_coords, epochs=200)

# Phase 2: add experimental restraints
joint_loss.unfreeze_term(4)
joint_loss.unfreeze_term(5)
result = refiner.run(init_params=result_p1.final_params, epochs=1000)
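How freeze_term works internally is not described here. One jit-friendly way to get the same effect, sketched below as an assumption rather than the library's mechanism, is to pass per-term weights as an array argument and zero out frozen entries. The array keeps its shape and dtype, so switching phases does not force a retrace, whereas a Python-level `if frozen:` read inside a jitted closure would be baked in at trace time and silently go stale. `make_masked_loss` and the toy terms are hypothetical.

```python
import jax
import jax.numpy as jnp

def make_masked_loss(terms):
    """Hypothetical helper: jitted weighted sum whose weights (0 = frozen) are an input."""

    @jax.jit
    def loss(x, weights):
        values = jnp.stack([term(x) for term in terms])   # per-term diagnostics
        return jnp.dot(weights, values), values

    return loss

terms = [
    lambda x: jnp.sum(x ** 2),            # stand-in for a geometry term
    lambda x: jnp.sum((x - 1.0) ** 2),    # stand-in for an experimental term
]
loss = make_masked_loss(terms)
grad_fn = jax.grad(loss, has_aux=True)    # differentiate the total, carry per-term values

base_weights = jnp.array([5.0, 2.0])
phase1 = base_weights * jnp.array([1.0, 0.0])   # experimental term frozen
phase2 = base_weights * jnp.array([1.0, 1.0])   # everything active

x = jnp.full((3,), 0.5)
total1, _ = loss(x, phase1)
total2, per_term = loss(x, phase2)
g1, _ = grad_fn(x, phase1)
g2, _ = grad_fn(x, phase2)
```

Two trade-offs of this design: a frozen term is still evaluated (only its contribution is zeroed), and multiplying by zero does not suppress NaNs, since 0 × NaN is NaN, so a term that can blow up needs its own guard.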

🔬 Scientific Validation

diff-integrator is being validated against several experimental NMR datasets:

  • 2KZV (CvR118A): Joint refinement using Cα chemical shifts and dual-medium (PAG/PEG) RDCs, lowering the Cα chemical shift RMSD and bringing RDC Q-factors near zero.
  • GmR58A & HR2876B: Successful gradient-based minimization of Cα chemical shift RMSD using internal coordinates (backbone dihedrals).
  • HR2876B Cartesian (Sprint 2): Over 2000 epochs, Cartesian refinement with bond-geometry and chirality-guard terms improved the Cα chemical shift RMSD 11× more than the NeRF approach (−0.123 ppm vs −0.011 ppm), with about 12× less structural drift (0.545 Å vs 6.4 Å). RDC Q-factors dropped by 63% on both alignment media (PEG: 0.440→0.163; Pf1: 0.443→0.162). The ChiralityPenalty corrected all 5 D-inverted Cα centers present in model 1 of the raw PDB entry; the pre-Sprint-2 run had silently introduced a sixth.

📂 Repository Structure

diff-integrator/
├── diff_integrator/       # Core package
│   ├── loss.py            # JointLoss and LossTerm interface
│   ├── optimizer.py       # IntegrativeRefiner engine
│   └── terms/             # Concrete loss implementations
│       ├── geometry.py    # Harmonic restraints, RMSD
│       ├── bond_geometry.py # Cartesian bond/angle penalties (Engh & Huber)
│       ├── chirality.py   # Cα chirality penalty (L→D inversion guard)
│       ├── saxs.py        # Debye scattering loss
│       ├── nmr.py         # RDC and Q-factor loss
│       ├── noe.py         # NOE flat-bottomed distance restraints
│       └── chemical_shifts.py # C-alpha shift loss
├── benchmarks/            # Real-world optimization tests
├── tests/                 # Unit tests (100% coverage)
├── docs/                  # MkDocs documentation
└── pyproject.toml         # Build configuration

🚀 Installation

Ensure you have JAX installed, then install diff-integrator locally:

pip install -e .

🤝 Contributing

Contributions are welcome! Please run the test suite and make sure mypy type checking passes before submitting PRs:

pytest --cov=diff_integrator
mypy .

⚖️ License

MIT License — see LICENSE for details.