1515from einops import rearrange
1616from fmpose3d .models .graph_frames import Graph
1717from functools import partial
18- from einops import rearrange , repeat
18+ from einops import rearrange
1919from timm .models .layers import DropPath
2020
2121class TimeEmbedding (nn .Module ):
@@ -36,8 +36,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
3636 b , f = t .shape [0 ], t .shape [1 ]
3737 half_dim = self .dim // 2
3838
39- # Gaussian Fourier features: sin(2π B t), cos(2π B t)
40- # t: (B,F,1,1) -> (B,F,1,1,1) -> broadcast with (1,1,1,1,half_dim)
4139 angles = (2 * math .pi ) * t .to (torch .float32 ).unsqueeze (- 1 ) * self .B .view (1 , 1 , 1 , 1 , half_dim )
4240 sin = torch .sin (angles )
4341 cos = torch .cos (angles )
@@ -206,13 +204,13 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
206204 self .fc1 = nn .Linear (dim_in , dim_hid )
207205 self .act = act_layer ()
208206 self .drop = nn .Dropout (drop ) if drop > 0 else nn .Identity ()
209- self .fc5 = nn .Linear (dim_hid , dim_out )
207+ self .fc2 = nn .Linear (dim_hid , dim_out )
210208
def forward(self, x):
    """Apply the two-layer MLP: fc1 -> activation -> dropout -> fc2.

    Args:
        x: input tensor whose last dimension matches fc1's in_features.

    Returns:
        Tensor with the last dimension equal to fc2's out_features.
    """
    # NOTE(review): reconstructed from the post-change side of the diff —
    # the layer was renamed fc5 -> fc2; no reference to fc5 may remain.
    x = self.fc1(x)
    x = self.act(x)
    x = self.drop(x)
    x = self.fc2(x)
    return x
217215
218216class Model (nn .Module ):
@@ -223,10 +221,10 @@ def __init__(self, args):
223221 # Register as buffer (not parameter) to follow module device automatically
224222 self .register_buffer ('A' , torch .tensor (self .graph .A , dtype = torch .float32 ))
225223
226- self .t_embed_dim = 16
224+ self .t_embed_dim = 32
227225 self .time_embed = TimeEmbedding (self .t_embed_dim , hidden_dim = 64 )
228- self . encoder_pose_2d = encoder ( 2 , args . channel // 2 , args . channel // 2 - self . t_embed_dim // 2 )
229- self .encoder_y_t = encoder (3 , args .channel // 2 , args .channel // 2 - self . t_embed_dim // 2 )
226+
227+ self .encoder = encoder (2 + 3 + self . t_embed_dim , args .channel // 2 , args .channel )
230228
231229 self .FMPose3D = FMPose3D (args .layers , args .channel , args .d_hid , args .token_dim , self .A , length = args .n_joints ) # 256
232230 self .pred_mu = decoder (args .channel , args .channel // 2 , 3 )
def forward(self, pose_2d, y_t, t):
    """Predict the flow-matching velocity field for 3D pose estimation.

    Args:
        pose_2d: (B, F, J, 2) 2D keypoint coordinates.
        y_t: (B, F, J, 3) noisy 3D pose at flow time t.
        t: (B, F, 1, 1) flow-matching time values.

    Returns:
        (B, F, J, 3) predicted velocity tensor.
    """
    # NOTE(review): reconstructed from the post-change side of the diff
    # (single joint encoder over concatenated inputs instead of the two
    # separate pose_2d / y_t encoders).
    b, f, j, _ = pose_2d.shape

    # Per-frame time embedding, broadcast to every joint:
    # (B, F, t_dim) -> (B, F, J, t_dim)
    t_emb = self.time_embed(t)
    t_emb = t_emb.unsqueeze(2).expand(b, f, j, self.t_embed_dim).contiguous()

    # Concatenate raw 2D pose, noisy 3D pose and time embedding, then fold
    # frames into the batch axis: (B, F, J, 2+3+t_dim) -> (B*F, J, C).
    x_in = torch.cat([pose_2d, y_t, t_emb], dim=-1)
    x_in = x_in.reshape(b * f, j, -1).contiguous()  # was: rearrange 'b f j c -> (b f) j c'

    # encoder -> backbone -> regression head
    in_emb = self.encoder(x_in)
    features = self.FMPose3D(in_emb)
    v = self.pred_mu(features)  # (B*F, J, 3)

    # Restore the frame axis: (B*F, J, 3) -> (B, F, J, 3)
    v = v.reshape(b, f, j, -1).contiguous()  # was: rearrange '(b f) j c -> b f j c'
    return v
258248
0 commit comments