@@ -334,8 +334,10 @@ def get_pose3D(path, output_dir, type='image'):
334334 print (f"args.n_joints: { args .n_joints } , args.out_joints: { args .out_joints } " )
335335
336336 ## Reload model
337+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
338+
337339 model = {}
338- model ['CFM' ] = CFM (args ).cuda ( )
340+ model ['CFM' ] = CFM (args ).to ( device )
339341
340342 model_dict = model ['CFM' ].state_dict ()
341343 model_path = args .saved_model_path
@@ -400,7 +402,8 @@ def get_3D_pose_from_image(args, keypoints, i, img, model, output_dir):
400402 input_2D = np .expand_dims (input_2D , axis = 0 ) # (1, J, 2)
401403
402404 # Convert to tensor format matching visualize_animal_poses.py
403- input_2D = torch .from_numpy (input_2D .astype ('float32' )).cuda () # (1, J, 2)
405+ device = torch .device ("cuda" if torch .cuda .is_available () else "cpu" )
406+ input_2D = torch .from_numpy (input_2D .astype ('float32' )).to (device ) # (1, J, 2)
404407 input_2D = input_2D .unsqueeze (0 ) # (1, 1, J, 2)
405408
406409 # Euler sampler for CFM
@@ -418,7 +421,7 @@ def euler_sample(c_2d, y_local, steps, model_3d):
418421
419422 # Single inference without flip augmentation
420423 # Create 3D random noise with shape (1, 1, J, 3)
421- y = torch .randn (input_2D .size (0 ), input_2D .size (1 ), input_2D .size (2 ), 3 ). cuda ( )
424+ y = torch .randn (input_2D .size (0 ), input_2D .size (1 ), input_2D .size (2 ), 3 , device = device )
422425 output_3D = euler_sample (input_2D , y , steps = args .sample_steps , model_3d = model )
423426
424427 output_3D = output_3D [0 :, args .pad ].unsqueeze (1 )
0 commit comments