Skip to content

Commit d76b239

Browse files
committed
update the model structure
1 parent c6c148c commit d76b239

1 file changed

Lines changed: 11 additions & 21 deletions

File tree

fmpose3d/models/model_GAMLP.py

Lines changed: 11 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from einops import rearrange
1616
from fmpose3d.models.graph_frames import Graph
1717
from functools import partial
18-
from einops import rearrange, repeat
18+
from einops import rearrange
1919
from timm.models.layers import DropPath
2020

2121
class TimeEmbedding(nn.Module):
@@ -36,8 +36,6 @@ def forward(self, t: torch.Tensor) -> torch.Tensor:
3636
b, f = t.shape[0], t.shape[1]
3737
half_dim = self.dim // 2
3838

39-
# Gaussian Fourier features: sin(2π B t), cos(2π B t)
40-
# t: (B,F,1,1) -> (B,F,1,1,1) -> broadcast with (1,1,1,1,half_dim)
4139
angles = (2 * math.pi) * t.to(torch.float32).unsqueeze(-1) * self.B.view(1, 1, 1, 1, half_dim)
4240
sin = torch.sin(angles)
4341
cos = torch.cos(angles)
@@ -206,13 +204,13 @@ def __init__(self, in_features, hidden_features=None, out_features=None, act_lay
206204
self.fc1 = nn.Linear(dim_in, dim_hid)
207205
self.act = act_layer()
208206
self.drop = nn.Dropout(drop) if drop > 0 else nn.Identity()
209-
self.fc5 = nn.Linear(dim_hid, dim_out)
207+
self.fc2= nn.Linear(dim_hid, dim_out)
210208

211209
def forward(self, x):
212210
x = self.fc1(x)
213211
x = self.act(x)
214212
x = self.drop(x)
215-
x = self.fc5(x)
213+
x = self.fc2(x)
216214
return x
217215

218216
class Model(nn.Module):
@@ -223,10 +221,10 @@ def __init__(self, args):
223221
# Register as buffer (not parameter) to follow module device automatically
224222
self.register_buffer('A', torch.tensor(self.graph.A, dtype=torch.float32))
225223

226-
self.t_embed_dim = 16
224+
self.t_embed_dim = 32
227225
self.time_embed = TimeEmbedding(self.t_embed_dim, hidden_dim=64)
228-
self.encoder_pose_2d = encoder(2, args.channel//2, args.channel//2-self.t_embed_dim//2)
229-
self.encoder_y_t = encoder(3, args.channel//2, args.channel//2-self.t_embed_dim//2)
226+
227+
self.encoder = encoder(2 + 3 + self.t_embed_dim, args.channel//2, args.channel)
230228

231229
self.FMPose3D = FMPose3D(args.layers, args.channel, args.d_hid, args.token_dim, self.A, length=args.n_joints) # 256
232230
self.pred_mu = decoder(args.channel, args.channel//2, 3)
@@ -235,24 +233,16 @@ def forward(self, pose_2d, y_t, t):
235233
# pose_2d: (B,F,J,2) y_t: (B,F,J,3) t: (B,F,1,1)
236234
b, f, j, _ = pose_2d.shape
237235

238-
# Ensure t has the correct shape (B,F,1,1)
239-
if t.shape[1] == 1 and f > 1:
240-
t = t.expand(b, f, 1, 1).contiguous()
241-
242236
# build time embedding
243237
t_emb = self.time_embed(t) # (B,F,t_dim)
244238
t_emb = t_emb.unsqueeze(2).expand(b, f, j, self.t_embed_dim).contiguous() # (B,F,J,t_dim)
245239

246-
pose_2d_emb = self.encoder_pose_2d(pose_2d)
247-
y_t_emb = self.encoder_y_t(y_t)
248-
249-
in_emb = torch.cat([pose_2d_emb, y_t_emb, t_emb], dim=-1) # (B,F,J,dim)
250-
in_emb = rearrange(in_emb, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)
251-
252-
# encoder -> model -> regression head
253-
h = self.FMPose3D(in_emb)
254-
v = self.pred_mu(h) # (B*F,J,3)
240+
x_in = torch.cat([pose_2d, y_t, t_emb], dim=-1) # (B,F,J,2+3+t_dim)
241+
x_in = rearrange(x_in, 'b f j c -> (b f) j c').contiguous() # (B*F,J,in)
255242

243+
in_emb = self.encoder(x_in)
244+
features = self.FMPose3D(in_emb)
245+
v = self.pred_mu(features) # (B*F,J,3)
256246
v = rearrange(v, '(b f) j c -> b f j c', b=b, f=f).contiguous() # (B,F,J,3)
257247
return v
258248

0 commit comments

Comments (0)