
Commit ae036e3

Add RFpeptides code
1 parent b44206a commit ae036e3

7 files changed

Lines changed: 143 additions & 18 deletions

README.md

Lines changed: 46 additions & 0 deletions
@@ -45,6 +45,7 @@ RFdiffusion is an open source method for structure generation, with or without c
 - [Generation of Symmetric Oligomers](#generation-of-symmetric-oligomers)
 - [Using Auxiliary Potentials](#using-auxiliary-potentials)
 - [Symmetric Motif Scaffolding.](#symmetric-motif-scaffolding)
+- [RFpeptides macrocycle design](#macrocyclic-peptide-design-with-rfpeptides)
 - [A Note on Model Weights](#a-note-on-model-weights)
 - [Things you might want to play with at inference time](#things-you-might-want-to-play-with-at-inference-time)
 - [Understanding the output files](#understanding-the-output-files)
@@ -466,6 +467,51 @@ Note that the contigs should specify something that is precisely symmetric. Thin
 
 ---
 
+### Macrocyclic peptide design with RFpeptides
+<img src="./img/rfpeptides_fig1.png" alt="alt text" width="400px" align="right"/>
+We have recently published the RFpeptides protocol, which uses RFdiffusion to design macrocyclic peptides that bind target proteins with atomic accuracy (Rettie, Juergens, Adebomi, et al., 2025). This section briefly outlines how to run the inference protocol. We provide two examples of macrocycle design with RFpeptides: one for monomer design and one for binder design.
+
+NOTE: Until the pull request is merged, you can find this code in the branch `rfpeptides`.
+
+```
+examples/design_macrocyclic_monomer.sh
+examples/design_macrocyclic_binder.sh
+```
+#### RFpeptides binder design
+<img src="./img/rfpeptides_binder.png" alt="alt text" width="1100" align="center"/>
+
+To design a macrocyclic peptide that binds a target, the flags are very similar to those for classic binder design, with two additions:
+```
+#!/bin/bash
+
+prefix=./outputs/diffused_binder_cyclic2
+
+# Note that the indices in this pdb file have been
+# shifted by +2 in chain A relative to pdbID 7zkr.
+pdb='./input_pdbs/7zkr_GABARAP.pdb'
+
+num_designs=10
+script="../scripts/run_inference.py"
+$script --config-name base \
+    inference.output_prefix=$prefix \
+    inference.num_designs=$num_designs \
+    'contigmap.contigs=[12-18 A3-117/0]' \
+    inference.input_pdb=$pdb \
+    inference.cyclic=True \
+    diffuser.T=50 \
+    inference.cyc_chains='a' \
+    ppi.hotspot_res=[\'A51\',\'A52\',\'A50\',\'A48\',\'A62\',\'A65\']
+```
+
+The new flags are `inference.cyclic=True` and `inference.cyc_chains`. Yes, they are somewhat redundant.
+
+`inference.cyclic` simply notifies the program that the user would like to design at least one macrocycle, and `inference.cyc_chains` is a string containing the letter of every chain you would like to design as a cyclic peptide. In the example above, only chain `A` (`inference.cyc_chains='a'`) is cyclized, but you could pass `inference.cyc_chains='abcd'` if you so desired (provided the contigs were compatible with this, which the one above is not).
+
+#### RFpeptides monomer design
+For monomer design, simply adjust the contigs to contain only a single generated chain, e.g., `contigmap.contigs=[12-18]`, keep `inference.cyclic=True` and `inference.cyc_chains='a'`, and you're off to the races making monomers.
+
+---
+
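The monomer case described above can be sketched as a minimal script. This is a sketch only: the output path is hypothetical, and the flags follow the binder example above with the contigs reduced to a single generated chain.

```shell
#!/bin/bash
# Sketch: design 12-18 residue cyclic monomers (no target chain).
# Output prefix is hypothetical; flags mirror the binder example above.
../scripts/run_inference.py --config-name base \
    inference.output_prefix=./outputs/cyclic_monomer \
    inference.num_designs=10 \
    'contigmap.contigs=[12-18]' \
    inference.cyclic=True \
    inference.cyc_chains='a' \
    diffuser.T=50
```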
 
 ### A Note on Model Weights
 
 Because of everything we want diffusion to be able to do, there is not *One Model To Rule Them All*. E.g., if you want to run with secondary structure conditioning, this requires a different model than if you don't. Under the hood, we take care of most of this by default - we parse your input and work out the most appropriate checkpoint.

config/inference/base.yaml

Lines changed: 2 additions & 0 deletions
@@ -21,6 +21,8 @@ inference:
   trb_save_ckpt_path: null
   schedule_directory_path: null
   model_directory_path: null
+  cyclic: False
+  cyc_chains: 'a'
 
 contigmap:
   contigs: null

rfdiffusion/Embeddings.py

Lines changed: 28 additions & 4 deletions
@@ -4,7 +4,7 @@
 from opt_einsum import contract as einsum
 import torch.utils.checkpoint as checkpoint
 from rfdiffusion.util import get_tips
-from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal
+from rfdiffusion.util_module import Dropout, create_custom_forward, rbf, init_lecun_normal, find_breaks
 from rfdiffusion.Attention_module import Attention, FeedForwardLayer, AttentionWithBias
 from rfdiffusion.Track_module import PairStr2Pair
 import math
@@ -21,10 +21,34 @@ def __init__(self, d_model, minpos=-32, maxpos=32, p_drop=0.1):
         self.emb = nn.Embedding(self.nbin, d_model)
         self.drop = nn.Dropout(p_drop)
 
-    def forward(self, x, idx):
+    def forward(self, x, idx, cyclize=None):
         bins = torch.arange(self.minpos, self.maxpos, device=x.device)
         seqsep = idx[:,None,:] - idx[:,:,None] # (B, L, L)
         #
+
+
+        # adding support for multi-chain cyclic peptides:
+        # find chain breaks and label chain ids
+        breaks = find_breaks(idx.squeeze().cpu().numpy(), thresh=35) # NOTE: hard-coded threshold for defining chain breaks here.
+        # The typical index jump at a chain break is +200.
+        # Assumes monotonically increasing absolute idx.
+
+        chainids = np.zeros_like(idx.squeeze().cpu().numpy())
+        for i, b in enumerate(breaks):
+            chainids[b:] = i+1
+        chainids = torch.from_numpy(chainids).to(device=idx.device)
+
+        # cyclic peptide: wrap sequence separations around each cyclized chain
+        if cyclize is not None:
+            for chid in torch.unique(chainids):
+                is_chid = chainids==chid
+                cur_cyclize = cyclize*is_chid
+                cur_mask = cur_cyclize[:,None]*cur_cyclize[None,:] # (L,L)
+                cur_ncyc = torch.sum(cur_cyclize)
+
+                seqsep[:,cur_mask*(seqsep[0]>cur_ncyc//2)] -= cur_ncyc
+                seqsep[:,cur_mask*(seqsep[0]<-cur_ncyc//2)] += cur_ncyc
+
         ib = torch.bucketize(seqsep, bins).long() # (B, L, L)
         emb = self.emb(ib) #(B, L, L, d_model)
         x = x + emb # add relative positional encoding
@@ -56,7 +80,7 @@ def reset_parameter(self):
 
         nn.init.zeros_(self.emb.bias)
 
-    def forward(self, msa, seq, idx):
+    def forward(self, msa, seq, idx, cyclize):
         # Inputs:
         #   - msa: Input MSA (B, N, L, d_init)
         #   - seq: Input Sequence (B, L)
@@ -82,7 +106,7 @@ def forward(self, msa, seq, idx):
         right = (seq @ self.emb_right.weight)[:,:,None] # (B, L, 1, d_pair)
 
         pair = left + right # (B, L, L, d_pair)
-        pair = self.pos(pair, idx) # add relative position
+        pair = self.pos(pair, idx, cyclize) # add relative position
 
         # state embedding
         # Sergey's one hot trick
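The wrap-around adjustment in the hunk above can be demonstrated in isolation. The following is a minimal sketch (in numpy, for a single cyclic chain with contiguous indices; the function name is ours, not part of the codebase):

```python
import numpy as np

def cyclic_seqsep(idx):
    """Signed sequence separation with cyclic wrap-around.

    For a cyclic chain of length L, the first and last residues are
    bonded, so separations longer than L//2 wrap to the short way
    around the ring -- mirroring the seqsep adjustment above.
    """
    L = len(idx)
    seqsep = idx[None, :] - idx[:, None]        # seqsep[i, j] = idx[j] - idx[i]
    seqsep = seqsep - L * (seqsep > L // 2)     # wrap large positive gaps
    seqsep = seqsep + L * (seqsep < -(L // 2))  # wrap large negative gaps
    return seqsep

sep = cyclic_seqsep(np.arange(6))  # a 6-residue macrocycle
print(sep[0, 5], sep[5, 0])        # residues 0 and 5 are now sequence neighbors: -1 1
```

With the wrap applied, the relative positional embedding sees the N- and C-terminal residues as adjacent, which is what encodes the head-to-tail cyclization.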

rfdiffusion/RoseTTAFoldModel.py

Lines changed: 4 additions & 3 deletions
@@ -69,11 +69,12 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
                 t1d=None, t2d=None, xyz_t=None, alpha_t=None,
                 msa_prev=None, pair_prev=None, state_prev=None,
                 return_raw=False, return_full=False, return_infer=False,
-                use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None):
+                use_checkpoint=False, motif_mask=None, i_cycle=None, n_cycle=None,
+                cyclic_reses=None):
 
         B, N, L = msa_latent.shape[:3]
         # Get embeddings
-        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx)
+        msa_latent, pair, state = self.latent_emb(msa_latent, seq, idx, cyclic_reses)
         msa_full = self.full_emb(msa_full, seq, idx)
 
         # Do recycling
@@ -101,7 +102,7 @@ def forward(self, msa_latent, msa_full, seq, xyz, idx, t,
         is_frozen_residue = motif_mask if self.freeze_track_motif else torch.zeros_like(motif_mask).bool()
         msa, pair, R, T, alpha_s, state = self.simulator(seq, msa_latent, msa_full, pair, xyz[:,:,:3],
                                                          state, idx, use_checkpoint=use_checkpoint,
-                                                         motif_mask=is_frozen_residue)
+                                                         motif_mask=is_frozen_residue, cyclic_reses=cyclic_reses)
 
         if return_raw:
             # get last structure

rfdiffusion/Track_module.py

Lines changed: 12 additions & 9 deletions
@@ -234,7 +234,7 @@ def reset_parameter(self):
         nn.init.zeros_(self.embed_e2.bias)
 
     @torch.cuda.amp.autocast(enabled=False)
-    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64, eps=1e-5):
+    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses=None, top_k=64, eps=1e-5):
         B, N, L = msa.shape[:3]
 
         if motif_mask is None:
@@ -249,7 +249,7 @@ def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, top_k=64,
         node = self.norm_node(self.embed_x(node))
         pair = self.norm_edge1(self.embed_e1(pair))
 
-        neighbor = get_seqsep(idx)
+        neighbor = get_seqsep(idx, cyclic_reses)
         rbf_feat = rbf(torch.cdist(xyz[:,:,1], xyz[:,:,1]))
         pair = torch.cat((pair, rbf_feat, neighbor), dim=-1)
         pair = self.norm_edge2(self.embed_e2(pair))
@@ -318,18 +318,18 @@ def __init__(self, d_msa=256, d_pair=128,
                  SE3_param=SE3_param,
                  p_drop=p_drop)
 
-    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False):
+    def forward(self, msa, pair, R_in, T_in, xyz, state, idx, motif_mask, use_checkpoint=False, cyclic_reses=None):
         rbf_feat = rbf(torch.cdist(xyz[:,:,1,:], xyz[:,:,1,:]))
         if use_checkpoint:
             msa = checkpoint.checkpoint(create_custom_forward(self.msa2msa), msa, pair, rbf_feat, state)
             pair = checkpoint.checkpoint(create_custom_forward(self.msa2pair), msa, pair)
             pair = checkpoint.checkpoint(create_custom_forward(self.pair2pair), pair, rbf_feat)
-            R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask)
+            R, T, state, alpha = checkpoint.checkpoint(create_custom_forward(self.str2str, top_k=0), msa, pair, R_in, T_in, xyz, state, idx, motif_mask, cyclic_reses)
         else:
             msa = self.msa2msa(msa, pair, rbf_feat, state)
             pair = self.msa2pair(msa, pair)
             pair = self.pair2pair(pair, rbf_feat)
-            R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, top_k=0)
+            R, T, state, alpha = self.str2str(msa, pair, R_in, T_in, xyz, state, idx, motif_mask=motif_mask, cyclic_reses=cyclic_reses, top_k=0)
 
         return msa, pair, R, T, state, alpha
 
@@ -384,7 +384,7 @@ def reset_parameter(self):
         self.proj_state2 = init_lecun_normal(self.proj_state2)
         nn.init.zeros_(self.proj_state2.bias)
 
-    def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=False, motif_mask=None):
+    def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, cyclic_reses=None, use_checkpoint=False, motif_mask=None):
         """
         input:
            seq: query sequence (B, L)
@@ -425,7 +425,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
                 state,
                 idx,
                 motif_mask=motif_mask,
-                use_checkpoint=use_checkpoint)
+                use_checkpoint=use_checkpoint,
+                cyclic_reses=cyclic_reses)
             R_s.append(R_in)
             T_s.append(T_in)
             alpha_s.append(alpha)
@@ -444,7 +445,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
                 state,
                 idx,
                 motif_mask=motif_mask,
-                use_checkpoint=use_checkpoint)
+                use_checkpoint=use_checkpoint,
+                cyclic_reses=cyclic_reses)
             R_s.append(R_in)
             T_s.append(T_in)
             alpha_s.append(alpha)
@@ -462,7 +464,8 @@ def forward(self, seq, msa, msa_full, pair, xyz_in, state, idx, use_checkpoint=F
                 state,
                 idx,
                 top_k=64,
-                motif_mask=motif_mask)
+                motif_mask=motif_mask,
+                cyclic_reses=cyclic_reses)
             R_s.append(R_in)
             T_s.append(T_in)
             alpha_s.append(alpha)

rfdiffusion/inference/model_runners.py

Lines changed: 25 additions & 2 deletions
@@ -278,7 +278,29 @@ def sample_init(self, return_forward_trajectory=False):
         self.mappings = self.contig_map.get_mappings()
         self.mask_seq = torch.from_numpy(self.contig_map.inpaint_seq)[None,:]
         self.mask_str = torch.from_numpy(self.contig_map.inpaint_str)[None,:]
-        self.binderlen = len(self.contig_map.inpaint)
+        self.binderlen = len(self.contig_map.inpaint)
+
+        ######################################
+        ### Resolve cyclic peptide indices ###
+        ######################################
+        if self._conf.inference.cyclic:
+            if self._conf.inference.cyc_chains is None:
+                # default to all residues being cyclized
+                self.cyclic_reses = ~self.mask_str.to(self.device).squeeze()
+            else:
+                # use the cyc_chains arg to determine the cyclic_reses mask
+                assert type(self._conf.inference.cyc_chains) is str, 'cyc_chains arg must be a string'
+                cyc_chains = self._conf.inference.cyc_chains
+                cyc_chains = [i.upper() for i in cyc_chains]
+                hal_idx = self.contig_map.hal # the pdb indices of the output; encodes which chain each residue belongs to
+                is_cyclized = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze() # initially empty
+
+                for ch in cyc_chains:
+                    ch_mask = torch.tensor([idx[0] == ch for idx in hal_idx]).bool()
+                    is_cyclized[ch_mask] = True # set this whole chain to be cyclic
+                self.cyclic_reses = is_cyclized
+        else:
+            self.cyclic_reses = torch.zeros_like(self.mask_str).bool().to(self.device).squeeze()
 
         ####################
         ### Get Hotspots ###
@@ -675,7 +697,8 @@ def sample_step(self, *, t, x_t, seq_init, final_step):
                               state_prev = None,
                               t=torch.tensor(t),
                               return_infer=True,
-                              motif_mask=self.diffusion_mask.squeeze().to(self.device))
+                              motif_mask=self.diffusion_mask.squeeze().to(self.device),
+                              cyclic_reses=self.cyclic_reses)
 
         if self.symmetry is not None and self.inf_conf.symmetric_self_cond:
             px0 = self.symmetrise_prev_pred(px0=px0,seq_in=seq_in, alpha=alpha)[:,:,:3]
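The chain-letter resolution in `sample_init` above amounts to a per-residue membership test. A standalone sketch (the function name and `hal_idx` data are hypothetical, mirroring the `(chain, residue)` pairs in `contig_map.hal`):

```python
def resolve_cyclic_mask(hal_idx, cyc_chains):
    """Mark every residue whose chain letter appears in cyc_chains.

    hal_idx pairs each output residue with its (chain, residue number);
    cyc_chains is the flag string, matched case-insensitively as in
    sample_init above.
    """
    wanted = {c.upper() for c in cyc_chains}
    return [chain.upper() in wanted for chain, _ in hal_idx]

# Hypothetical two-chain output: a 3-residue chain A and a 2-residue chain B.
hal_idx = [('A', 1), ('A', 2), ('A', 3), ('B', 1), ('B', 2)]
print(resolve_cyclic_mask(hal_idx, 'a'))  # [True, True, True, False, False]
```

The resulting boolean mask is what gets threaded through the model as `cyclic_reses`.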

rfdiffusion/util_module.py

Lines changed: 26 additions & 0 deletions
@@ -7,6 +7,13 @@
 import dgl
 from rfdiffusion.util import base_indices, RTs_by_torsion, xyzs_in_base_frame, rigid_from_3_points
 
+
+def find_breaks(ix, thresh=35):
+    # finds positions in ix where the jump is greater than thresh
+    breaks = np.where(np.diff(ix) > thresh)[0]
+    return np.array(breaks)+1
+
+
 def init_lecun_normal(module):
     def truncated_normal(uniform, mu=0.0, sigma=1.0, a=-2, b=2):
         normal = torch.distributions.normal.Normal(0, 1)
@@ -104,6 +111,25 @@ def get_seqsep(idx):
     neigh = torch.abs(seqsep)
     neigh[neigh > 1] = 0.0 # if bonded -- 1.0 / else 0.0
     neigh = sign * neigh
+
+    # add cyclic edges
+    breaks = find_breaks(idx.squeeze().cpu().numpy())
+    chainids = np.zeros_like(idx.squeeze().cpu().numpy())
+    for i, b in enumerate(breaks):
+        chainids[b:] = i+1
+    chainids = torch.from_numpy(chainids).to(device=idx.device)
+
+    # add cyclic edges with multiple chains
+    if (cyclic is not None):
+        for chid in torch.unique(chainids):
+            is_chid = chainids==chid
+            cur_cyclic = cyclic*is_chid
+            cur_cres = cur_cyclic.nonzero()
+
+            if cur_cyclic.sum()>=2:
+                neigh[:,cur_cres[-1],cur_cres[0]] = 1
+                neigh[:,cur_cres[0],cur_cres[-1]] = -1
+
     return neigh.unsqueeze(-1)
 
 def make_full_graph(xyz, pair, idx, top_k=64, kmin=9):
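`find_breaks` relies on the large residue-index offset that separates chains (typically +200, well above the default threshold of 35). A quick standalone check of its behavior, using only numpy:

```python
import numpy as np

def find_breaks(ix, thresh=35):
    # positions in ix where the residue-index jump exceeds thresh,
    # i.e. the first index of each new chain
    breaks = np.where(np.diff(ix) > thresh)[0]
    return np.array(breaks) + 1

# Two chains: residues 0-4, then a +200 jump typical of a chain break.
idx = np.array([0, 1, 2, 3, 4, 205, 206, 207])
print(find_breaks(idx))  # [5] -> the second chain starts at position 5
```

Each break position then seeds the `chainids` labels in `get_seqsep`, so the head-to-tail bond edge is added independently per cyclized chain.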
