Skip to content

Commit 57b943d

Browse files
committed
update cuda fallback
1 parent 3bf975d commit 57b943d

6 files changed

Lines changed: 25 additions & 13 deletions

File tree

animals/demo/vis_animals.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -334,8 +334,10 @@ def get_pose3D(path, output_dir, type='image'):
334334
print(f"args.n_joints: {args.n_joints}, args.out_joints: {args.out_joints}")
335335

336336
## Reload model
337+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
338+
337339
model = {}
338-
model['CFM'] = CFM(args).cuda()
340+
model['CFM'] = CFM(args).to(device)
339341

340342
model_dict = model['CFM'].state_dict()
341343
model_path = args.saved_model_path
@@ -400,7 +402,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
400402
input_2D = np.expand_dims(input_2D, axis=0) # (1, J, 2)
401403

402404
# Convert to tensor format matching visualize_animal_poses.py
403-
input_2D = torch.from_numpy(input_2D.astype('float32')).cuda() # (1, J, 2)
405+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
406+
input_2D = torch.from_numpy(input_2D.astype('float32')).to(device) # (1, J, 2)
404407
input_2D = input_2D.unsqueeze(0) # (1, 1, J, 2)
405408

406409
# Euler sampler for CFM
@@ -418,7 +421,7 @@ def euler_sample(c_2d, y_local, steps, model_3d):
418421

419422
# Single inference without flip augmentation
420423
# Create 3D random noise with shape (1, 1, J, 3)
421-
y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3).cuda()
424+
y = torch.randn(input_2D.size(0), input_2D.size(1), input_2D.size(2), 3, device=device)
422425
output_3D = euler_sample(input_2D, y, steps=args.sample_steps, model_3d=model)
423426

424427
output_3D = output_3D[0:, args.pad].unsqueeze(1)

animals/scripts/main_animal3d.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,8 +264,10 @@ def get_parameter_number(net):
264264
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=args.batch_size,
265265
shuffle=False, num_workers=int(args.workers), pin_memory=True)
266266

267+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
268+
267269
model = {}
268-
model['CFM'] = CFM(args).cuda()
270+
model['CFM'] = CFM(args).to(device)
269271

270272
if args.reload:
271273
model_dict = model['CFM'].state_dict()

demo/vis_in_the_wild.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -213,7 +213,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
213213

214214
input_2D = input_2D[np.newaxis, :, :, :, :]
215215

216-
input_2D = torch.from_numpy(input_2D.astype('float32')).cuda()
216+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
217+
input_2D = torch.from_numpy(input_2D.astype('float32')).to(device)
217218

218219
N = input_2D.size(0)
219220

@@ -229,10 +230,10 @@ def euler_sample(c_2d, y_local, steps, model_3d):
229230

230231
## estimation
231232

232-
y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
233+
y = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
233234
output_3D_non_flip = euler_sample(input_2D[:, 0], y, steps=args.sample_steps, model_3d=model)
234235

235-
y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3).cuda()
236+
y_flip = torch.randn(input_2D.size(0), input_2D.size(2), input_2D.size(3), 3, device=device)
236237
output_3D_flip = euler_sample(input_2D[:, 1], y_flip, steps=args.sample_steps, model_3d=model)
237238

238239
output_3D_flip[:, :, :, 0] *= -1
@@ -280,8 +281,10 @@ def get_pose3D(path, output_dir, type='image'):
280281
# args.type = type
281282

282283
## Reload
284+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
285+
283286
model = {}
284-
model['CFM'] = CFM(args).cuda()
287+
model['CFM'] = CFM(args).to(device)
285288

286289
# if args.reload:
287290
model_dict = model['CFM'].state_dict()

fmpose3d/animals/common/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,15 +220,16 @@ def update(self, val, n=1):
220220

221221

222222
def get_varialbe(split, target):
223+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
223224
num = len(target)
224225
var = []
225226
if split == "train":
226227
for i in range(num):
227-
temp = target[i].requires_grad_(False).contiguous().type(torch.cuda.FloatTensor)
228+
temp = target[i].requires_grad_(False).contiguous().float().to(device)
228229
var.append(temp)
229230
else:
230231
for i in range(num):
231-
temp = target[i].contiguous().cuda().type(torch.cuda.FloatTensor)
232+
temp = target[i].contiguous().float().to(device)
232233
var.append(temp)
233234

234235
return var

fmpose3d/common/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -186,15 +186,16 @@ def update(self, val, n=1):
186186

187187

188188
def get_varialbe(split, target):
189+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
189190
num = len(target)
190191
var = []
191192
if split == "train":
192193
for i in range(num):
193-
temp = target[i].requires_grad_(False).contiguous().type(torch.cuda.FloatTensor)
194+
temp = target[i].requires_grad_(False).contiguous().float().to(device)
194195
var.append(temp)
195196
else:
196197
for i in range(num):
197-
temp = target[i].contiguous().cuda().type(torch.cuda.FloatTensor)
198+
temp = target[i].contiguous().float().to(device)
198199
var.append(temp)
199200

200201
return var

scripts/FMPose3D_main.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -335,8 +335,10 @@ def print_error_action(action_error_sum, is_train):
335335
pin_memory=True,
336336
)
337337

338+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
339+
338340
model = {}
339-
model["CFM"] = CFM(args).cuda()
341+
model["CFM"] = CFM(args).to(device)
340342

341343
if args.reload:
342344
model_dict = model["CFM"].state_dict()

0 commit comments

Comments
 (0)