Diff-Integrator is a JAX-accelerated optimization engine designed for integrative structural biology. It acts as the "orchestrator" that combines differentiable observables from diff-biophys into multi-objective loss functions.
By cleanly separating the optimization loop from the underlying biophysical kernels, diff-integrator enables robust joint refinement of protein structures against multiple types of experimental data simultaneously (e.g., SAXS, NMR chemical shifts, and NMR RDCs).
The goal of diff-integrator is to provide a seamless optax-based refinement pipeline that handles:
- Multi-Objective Optimization: Easily weight and combine multiple experimental constraints via `JointLoss`.
- Abstract Parameterization: Optimize arbitrary parameter spaces, from Cartesian coordinates to internal backbone angles (phi/psi), via user-defined `kinematics_fn` mappers.
- Dynamic Fitting: Analytically refit nuisance parameters (like Saupe alignment tensors or SAXS scaling factors) dynamically during gradient descent.
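The SAXS scale factor shows the "analytic refit" idea in its simplest form. The NumPy sketch below is a toy and not the `SAXSLoss` implementation. It assumes unit atomic form factors in the Debye sum and a weighted least-squares fit of a single scale factor. The function names `debye_profile` and `fit_scale` are hypothetical.

```python
import numpy as np


def debye_profile(coords: np.ndarray, q: np.ndarray) -> np.ndarray:
    """Debye intensity I(q) = sum_ij sin(q r_ij) / (q r_ij); unit form factors assumed."""
    r = np.linalg.norm(coords[:, None, :] - coords[None, :, :], axis=-1)  # (N, N) distances
    # np.sinc(x) = sin(pi x) / (pi x) and equals 1 at x = 0, so i == j and q == 0 are safe
    kernel = np.sinc(q[:, None, None] * r[None, :, :] / np.pi)
    return kernel.sum(axis=(1, 2))


def fit_scale(i_calc: np.ndarray, i_exp: np.ndarray, sigma: np.ndarray):
    """Closed-form scale c minimising sum(((i_exp - c * i_calc) / sigma) ** 2)."""
    w = 1.0 / sigma**2
    c = np.sum(w * i_exp * i_calc) / np.sum(w * i_calc**2)
    chi2 = np.mean(((i_exp - c * i_calc) / sigma) ** 2)
    return c, chi2


# Toy usage: synthetic "experimental" data that is exactly 3x the computed profile
rng = np.random.default_rng(0)
coords = rng.normal(scale=5.0, size=(20, 3))
q = np.linspace(0.0, 0.5, 50)
i_calc = debye_profile(coords, q)
i_exp = 3.0 * i_calc
sigma = 0.01 * i_exp + 1e-6
scale, chi2 = fit_scale(i_calc, i_exp, sigma)
print(f"scale = {scale:.3f}, chi2 = {chi2:.2e}")
```

Because `c` is the exact minimizer at every step, the gradient of the minimized chi² with respect to the coordinates equals its partial derivative at fixed `c` (the envelope theorem). Differentiating through the closed-form fit and treating `c` as a constant therefore give the same gradient.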
Experience Diff-Integrator directly in your browser with our Colab tutorials:
| Tutorial | Audience | Description | Action |
|---|---|---|---|
| 📉 Results Dashboard | Graduate / researcher | Visualizes the loss descent, Q-factors, chemical shift accuracy, NeRF drift, and a Cartesian vs. NeRF comparison across all four benchmarks (2KZV, GmR58A, HR2876B NeRF, HR2876B Cartesian). | |
| 🧪 Refinement Concepts | Student / researcher | Educational notebook explaining NMR observables, NeRF coordinate parameterization, RDC tensor degeneracy, and the fixed-tensor protocol. | |
| | Reviewer / scientist | Honest quantitative assessment of the current method's failure modes: NeRF geometric drift, RDC overfitting on PEG data, and degrees-of-freedom imbalance. | |
`IntegrativeRefiner` is the core optimization engine. Built on optax (defaulting to the Adam optimizer), it manages the training loop, gradient calculation, and loss tracking.
- Abstract Support: Accepts arbitrary parameter sets (`init_params`) and maps them to Cartesian space using an optional `kinematics_fn`.
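The params → `kinematics_fn` → coordinates → loss pattern can be sketched in plain JAX. This is a hypothetical toy and not the `IntegrativeRefiner` source:

- `kinematics_fn` here builds a planar chain of unit-length bonds from turn angles, standing in for a NeRF builder.
- Plain gradient descent stands in for the optax Adam optimizer the engine uses by default.

```python
import jax
import jax.numpy as jnp


def kinematics_fn(angles: jnp.ndarray) -> jnp.ndarray:
    """Toy NeRF stand-in: planar chain of unit-length bonds built from turn angles."""
    headings = jnp.cumsum(angles)                                       # absolute bond directions
    steps = jnp.stack([jnp.cos(headings), jnp.sin(headings)], axis=-1)  # (n_bonds, 2) unit vectors
    atoms = jnp.cumsum(steps, axis=0)                                   # positions after each bond
    return jnp.concatenate([jnp.zeros((1, 2)), atoms], axis=0)          # prepend the origin atom


def make_loss(target_coords: jnp.ndarray):
    """Harmonic restraint evaluated on the coordinates generated from params."""
    def loss(params: jnp.ndarray) -> jnp.ndarray:
        coords = kinematics_fn(params)  # parameter space -> Cartesian space
        return jnp.mean(jnp.sum((coords - target_coords) ** 2, axis=-1))
    return loss


def refine(loss, init_params, epochs=2000, learning_rate=0.03):
    """Plain gradient descent (stand-in for optax Adam) with per-epoch loss tracking."""
    value_and_grad = jax.jit(jax.value_and_grad(loss))
    params, history = init_params, []
    for _ in range(epochs):
        value, grads = value_and_grad(params)
        params = params - learning_rate * grads
        history.append(float(value))
    return params, history


# Toy usage: recover a bent chain starting from a straight one
true_angles = jnp.array([0.3, -0.2, 0.5, 0.1, -0.4])
target = kinematics_fn(true_angles)
params, history = refine(make_loss(target), jnp.zeros(5))
print(f"loss: {history[0]:.4f} -> {history[-1]:.2e}")
```

Since the loss only ever sees coordinates, swapping the toy builder for a real NeRF function or an identity map (Cartesian refinement) leaves the loop unchanged.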
`JointLoss` is a container for combining multiple `LossTerm` objects. It computes the total weighted loss by evaluating each term on the current parameters and generated coordinates.

`LossTerm` is an abstract base class for defining differentiable constraints. All terms implement `__call__(self, params, coords)`.
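Below is a minimal pure-NumPy sketch of the `LossTerm` / `JointLoss` contract. It assumes that a frozen term simply drops out of the weighted sum. All class names are hypothetical, and real terms would return JAX scalars so the total stays differentiable. The second toy term is a flat-bottomed distance restraint in the spirit of `NOELoss`: zero inside `[d_lower, d_upper]` and harmonic outside.

```python
from abc import ABC, abstractmethod

import numpy as np


class LossTermSketch(ABC):
    """Hypothetical stand-in for the LossTerm interface: terms are called as term(params, coords)."""

    @abstractmethod
    def __call__(self, params, coords): ...


class HarmonicAnchor(LossTermSketch):
    """Mean squared deviation from a target structure (GeometryLoss-like toy)."""

    def __init__(self, target_coords):
        self.target = np.asarray(target_coords, dtype=float)

    def __call__(self, params, coords):
        return np.mean(np.sum((coords - self.target) ** 2, axis=-1))


class FlatBottomDistance(LossTermSketch):
    """Zero inside [d_lower, d_upper], harmonic outside (flat-bottomed restraint toy)."""

    def __init__(self, atom_pairs, d_upper, d_lower=None):
        self.pairs = np.asarray(atom_pairs)
        self.d_upper = np.asarray(d_upper, dtype=float)
        self.d_lower = (np.zeros_like(self.d_upper) if d_lower is None
                        else np.asarray(d_lower, dtype=float))

    def __call__(self, params, coords):
        d = np.linalg.norm(coords[self.pairs[:, 0]] - coords[self.pairs[:, 1]], axis=-1)
        over = np.maximum(d - self.d_upper, 0.0)   # upper-bound violation
        under = np.maximum(self.d_lower - d, 0.0)  # lower-bound violation
        return np.sum(over**2 + under**2)


class WeightedSumSketch:
    """Weighted sum of (term, weight) pairs; frozen indices are skipped entirely."""

    def __init__(self, weighted_terms):
        self.weighted_terms = list(weighted_terms)
        self.frozen = set()

    def freeze_term(self, index):
        self.frozen.add(index)

    def unfreeze_term(self, index):
        self.frozen.discard(index)

    def __call__(self, params, coords):
        return sum(
            weight * term(params, coords)
            for index, (term, weight) in enumerate(self.weighted_terms)
            if index not in self.frozen
        )


coords = np.array([[0.0, 0.0, 0.0], [5.0, 0.0, 0.0], [0.0, 3.0, 0.0]])
joint = WeightedSumSketch([
    (HarmonicAnchor(coords), 5.0),                                        # 0: anchor (zero here)
    (FlatBottomDistance([[0, 1], [0, 2]], [4.0, 4.0], [1.8, 1.8]), 2.0),  # 1: 5 Å pair is 1 Å too long
])
total = joint(None, coords)         # 5.0 * 0 + 2.0 * 1.0**2 = 2.0
joint.freeze_term(1)
frozen_total = joint(None, coords)  # only the anchor remains: 0.0
```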
- `GeometryLoss`: Implements basic structural priors, including harmonic restraints to a target Cartesian structure.
- `SAXSLoss`: Dynamically scales and fits theoretical SAXS profiles against experimental data using Debye kernels.
- `FixedTensorRDCLoss`: Fixed-tensor RDC loss that holds the Saupe alignment tensor frozen during backpropagation (using `jax.lax.stop_gradient`) and re-fits it every `update_interval` epochs. Includes a `cv_fraction` cross-validation split and `suggested_weight()` auto-scaling by overdetermination ratio.
- `CAShiftLoss`: Wraps the ring-current and secondary-structure shift predictor to compute $C_\alpha$ chemical shift RMSDs from backbone torsion angles.
- `CartesianCAShiftLoss`: Cartesian-space variant that extracts φ/ψ on the fly from raw coordinates via `compute_phi_psi`, enabling chemical shift refinement without a NeRF builder.
- `BondLengthPenalty` / `BondAnglePenalty`: Harmonic restraints on backbone bond lengths and angles to Engh & Huber ideal values. Used in Cartesian refinement to replace the hard geometric constraints of the NeRF builder.
- `RamachandranLoss`: Sequence-aware Ramachandran prior with residue-specific Gaussian wells. Correctly handles the GLY ε-basin (φ > 0) and the PRO ring constraint.
- `NOELoss`: Flat-bottomed harmonic NOE distance restraints (standard XPLOR/CNS convention). Accepts `atom_pairs`, `d_upper`, and optional `d_lower` arrays; reports `count_violations()` and `rms_violation()` diagnostics. Use `make_noe_restraints()` to map `(res_id, atom_name)` pairs directly to atom indices.
- `ChiralityPenalty`: Half-harmonic Cα chirality guard for Cartesian refinement. Prevents silent L→D inversion during gradient descent using the signed scalar triple product `chi = dot(cross(N - CA, C - CA), C_prev - CA)`. Use `make_backbone_chirality(n_residues)` for the standard N–CA–C layout.
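The fixed-tensor protocol relies on one property of the Saupe tensor. It is traceless and symmetric, so it has five independent elements that can be fitted to observed RDCs by linear least squares.

The NumPy sketch below fits the tensor on a training split and reports the Q-factor on held-out couplings, mirroring the 20% split used in the example further down. It makes these assumptions:

- The helper names are hypothetical.
- The dipolar prefactor is folded into `d_max`.
- Q uses one common definition, sqrt(Σ(D_obs − D_calc)² / ΣD_obs²).

In a JAX loss, the fitted tensor would be held constant between refits, which is what `jax.lax.stop_gradient` achieves.

```python
import numpy as np


def saupe_design_matrix(unit_vectors: np.ndarray) -> np.ndarray:
    """Rows map (Sxx, Syy, Sxy, Sxz, Syz) to an RDC, using Szz = -Sxx - Syy (traceless)."""
    x, y, z = unit_vectors.T
    return np.stack([x**2 - z**2, y**2 - z**2, 2 * x * y, 2 * x * z, 2 * y * z], axis=1)


def fit_saupe(unit_vectors, d_obs, d_max=1.0):
    """Linear least-squares fit of the five independent Saupe elements."""
    s, *_ = np.linalg.lstsq(d_max * saupe_design_matrix(unit_vectors), d_obs, rcond=None)
    return s


def back_calculate(unit_vectors, s, d_max=1.0):
    return d_max * saupe_design_matrix(unit_vectors) @ s


def q_factor(d_obs, d_calc):
    """Q = sqrt(sum((D_obs - D_calc)^2) / sum(D_obs^2))."""
    return np.sqrt(np.sum((d_obs - d_calc) ** 2) / np.sum(d_obs**2))


rng = np.random.default_rng(1)
vectors = rng.normal(size=(50, 3))
vectors /= np.linalg.norm(vectors, axis=1, keepdims=True)  # bond unit vectors
s_true = np.array([0.3, -0.1, 0.2, 0.05, -0.15])
d_obs = back_calculate(vectors, s_true)                     # noise-free synthetic RDCs

order = rng.permutation(len(d_obs))
train, val = order[:40], order[40:]                         # hold out 20% for validation
s_fit = fit_saupe(vectors[train], d_obs[train])             # refit on training RDCs only
# Between refits this tensor is treated as a constant (in JAX: jax.lax.stop_gradient(s_fit)).
q_train = q_factor(d_obs[train], back_calculate(vectors[train], s_fit))
q_val = q_factor(d_obs[val], back_calculate(vectors[val], s_fit))
overdetermination = len(train) / 5                          # observations per free tensor element
```

With noisy data, a training Q far below the validation Q is the overfitting signal that the held-out split is meant to expose.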
```python
from diff_integrator.loss import JointLoss
from diff_integrator.optimizer import EarlyStopping, IntegrativeRefiner
from diff_integrator.schedules import ExponentialDecaySchedule
from diff_integrator.terms.geometry import GeometryLoss
from diff_integrator.terms.nmr import FixedTensorRDCLoss, make_rdc_cv_refinement_fns

# starting_coords, rdc_res_ids, exp_rdcs and struct_res_ids are assumed to be
# loaded already (starting model coordinates plus the measured RDC data).

# 1. Build loss terms
geom_term = GeometryLoss(target_coords=starting_coords)

# FixedTensorRDCLoss holds the Saupe tensor fixed during backprop,
# preventing the degeneracy exploit that drives Q→0 unphysically.
rdc_loss_fn, q_eval, tensor_fn, val_q_fn, n_train, n_val = make_rdc_cv_refinement_fns(
    rdc_res_ids, exp_rdcs, struct_res_ids, cv_fraction=0.2
)
rdc_term = FixedTensorRDCLoss(
    rdc_loss_fn, tensor_fn, update_interval=50,  # re-fit the tensor every 50 epochs
    n_rdcs=n_train, val_q_eval_fn=val_q_fn,
)
# Auto-scale weight by overdetermination ratio (ideal = 10×)
rdc_weight = rdc_term.suggested_weight(base_weight=1.0)

# 2. Combine into a joint loss of (term, weight) pairs
joint_loss = JointLoss([
    (geom_term, 5.0),        # term 0: geometry anchor
    (rdc_term, rdc_weight),  # term 1: RDC
])

# 3. Annealed geometry weight: strong early, relaxed late
anchor_schedule = ExponentialDecaySchedule(
    initial_weight=10.0, final_weight=0.1, decay_epochs=100
)

# 4. Refine; the result is a RefinementResult dataclass
refiner = IntegrativeRefiner(loss_fn=joint_loss)
result = refiner.run(
    init_params=starting_coords,
    epochs=2000,
    learning_rate=0.005,
    weight_schedules={0: anchor_schedule},  # anneal geometry anchor (term 0)
    early_stopping=EarlyStopping(           # stop when RDC Q (term 1) plateaus
        term_index=1, patience=50, min_delta=1e-4
    ),
)

print(f"Best checkpoint: epoch {result.best_epoch}")
print(f"Stopped early: {result.stopped_early} ({result.early_stopping_triggered_by})")
refined_coords = result.best_params
```

Freeze experimental terms for an initial geometry-only phase, then thaw them for full joint refinement without rebuilding the loss:
```python
from diff_integrator.loss import JointLoss
from diff_integrator.optimizer import IntegrativeRefiner
from diff_integrator.terms.chirality import make_backbone_chirality
from diff_integrator.terms.noe import make_noe_restraints

# Reuses geom_term, rdc_term, rdc_weight, struct_res_ids and starting_coords from above.
# bond_pen / angle_pen (BondLengthPenalty / BondAnglePenalty instances), n_residues
# and noe_observations are assumed to be built already for your system.
chirality_pen = make_backbone_chirality(n_residues)
noe_term = make_noe_restraints(noe_observations, struct_res_ids)

joint_loss = JointLoss([
    (geom_term, 5.0),        # 0: position anchor
    (bond_pen, 50.0),        # 1: bond lengths
    (angle_pen, 10.0),       # 2: bond angles
    (chirality_pen, 20.0),   # 3: chirality guard (always on)
    (rdc_term, rdc_weight),  # 4: RDC
    (noe_term, 5.0),         # 5: NOE
])
refiner = IntegrativeRefiner(loss_fn=joint_loss)  # the refiner must wrap this new loss

# Phase 1: geometry + chirality only
joint_loss.freeze_term(4)  # freeze RDC
joint_loss.freeze_term(5)  # freeze NOE
result_p1 = refiner.run(init_params=starting_coords, epochs=200)

# Phase 2: add experimental restraints
joint_loss.unfreeze_term(4)
joint_loss.unfreeze_term(5)
result = refiner.run(init_params=result_p1.final_params, epochs=1000)
```

diff-integrator is being validated against several experimental NMR datasets:
- 2KZV (CvR118A): Joint refinement using $C_\alpha$ chemical shifts and dual-medium (PAG/PEG) RDCs, lowering the $C_\alpha$ chemical shift RMSD and bringing RDC Q-factors near zero.
- GmR58A & HR2876B: Successful gradient-based minimization of $C_\alpha$ shift RMSD using internal coordinates (dihedrals).
- HR2876B Cartesian (Sprint 2): Cartesian + bond-geometry + chirality-guard refinement over 2000 epochs achieved an 11× larger $C_\alpha$ shift RMSD improvement (−0.123 ppm vs −0.011 ppm) and 12× less structural drift (0.545 Å vs 6.4 Å) than the NeRF approach. RDC Q-factors dropped 63% in both alignment media (PEG: 0.440→0.163, Pf1: 0.443→0.162). The `ChiralityPenalty` corrected all 5 D-inverted Cα centers present in the raw PDB model 1 (the pre-Sprint-2 run had silently introduced a 6th).
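The chirality guard's scalar triple product can be checked directly on a backbone array. This NumPy sketch assumes a `(n_residues, 3, 3)` layout of N, CA, C positions and a hypothetical `expected_sign` argument. Which sign corresponds to L residues under this definition is not stated above, so in practice it would be taken from a trusted all-L reference model. Mirroring the structure flips the sign of `chi`, as expected for a chirality measure.

```python
import numpy as np


def chirality_values(backbone: np.ndarray) -> np.ndarray:
    """chi = dot(cross(N - CA, C - CA), C_prev - CA) for residues 1..n-1.

    backbone: (n_residues, 3, 3) array of N, CA, C positions per residue (assumed layout).
    """
    n, ca, c = backbone[1:, 0], backbone[1:, 1], backbone[1:, 2]
    c_prev = backbone[:-1, 2]  # C atom of the preceding residue
    return np.einsum("ij,ij->i", np.cross(n - ca, c - ca), c_prev - ca)


def chirality_penalty(chi, expected_sign=1.0, k=1.0):
    """Half-harmonic: zero when chi has the expected sign, quadratic in the wrong direction."""
    wrong = np.maximum(-expected_sign * chi, 0.0)
    return k * np.sum(wrong**2)


def count_inverted(chi, expected_sign=1.0):
    return int(np.sum(expected_sign * chi < 0))


backbone = np.array([
    [[5.0, 5.0, 5.0], [4.0, 4.0, 4.0], [0.0, 0.0, 1.0]],  # residue 0 (only its C is used)
    [[1.0, 0.0, 0.0], [0.0, 0.0, 0.0], [0.0, 1.0, 0.0]],  # residue 1
])
mirrored = backbone * np.array([1.0, 1.0, -1.0])  # reflect through the xy-plane
chi = chirality_values(backbone)                  # [1.0]
chi_mirror = chirality_values(mirrored)           # [-1.0]
print(chirality_penalty(chi), chirality_penalty(chi_mirror), count_inverted(chi_mirror))
```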
```text
diff-integrator/
├── diff_integrator/          # Core package
│   ├── loss.py               # JointLoss and LossTerm interface
│   ├── optimizer.py          # IntegrativeRefiner engine
│   └── terms/                # Concrete loss implementations
│       ├── geometry.py       # Harmonic restraints, RMSD
│       ├── bond_geometry.py  # Cartesian bond/angle penalties (Engh & Huber)
│       ├── chirality.py      # Cα chirality penalty (L→D inversion guard)
│       ├── saxs.py           # Debye scattering loss
│       ├── nmr.py            # RDC and Q-factor loss
│       ├── noe.py            # NOE flat-bottomed distance restraints
│       └── chemical_shifts.py  # C-alpha shift loss
├── benchmarks/               # Real-world optimization tests
├── tests/                    # Unit tests (100% coverage)
├── docs/                     # MkDocs documentation
└── pyproject.toml            # Build configuration
```
Ensure you have JAX installed, then install diff-integrator locally:

```bash
pip install -e .
```

Contributions are welcome! Please run the test suite and ensure mypy type checking passes before submitting PRs:

```bash
pytest --cov=diff_integrator
mypy .
```

MIT License; see LICENSE for details.