Skip to content

Commit fdc956b

Browse files
authored
Update metrics.py
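The problem with torchrun and JAX seems to be caused by jax.nn.sigmoid, so the PyTorch DDP path now computes the sigmoid with NumPy (1 / (1 + np.exp(-x))), while the non-DDP path uses jax.nn.sigmoid.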
1 parent 9c189ad commit fdc956b

1 file changed

Lines changed: 4 additions & 2 deletions

File tree

algoperf/workloads/ogbg/metrics.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,23 @@ def compute(self):
3737
labels = values['labels']
3838
logits = values['logits']
3939
mask = values['mask']
40+
sigmoid = jax.nn.sigmoid
4041

4142
if USE_PYTORCH_DDP:
4243
# Sync labels, logits, and masks across devices.
43-
all_values = [np.array(labels), np.array(logits), np.array(mask)]
44+
all_values = [labels, logits, mask]
4445
for idx, array in enumerate(all_values):
4546
tensor = torch.as_tensor(array, device=DEVICE)
4647
# Assumes that the tensors on all devices have the same shape.
4748
all_tensors = [torch.zeros_like(tensor) for _ in range(N_GPUS)]
4849
dist.all_gather(all_tensors, tensor)
4950
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
5051
labels, logits, mask = all_values
52+
sigmoid = lambda x: 1 / (1 + np.exp(-x))
5153

5254
mask = mask.astype(bool)
5355

54-
probs = 1 / (1 + np.exp(-logits))
56+
probs = sigmoid(logits)
5557
num_tasks = labels.shape[1]
5658
average_precisions = np.full(num_tasks, np.nan)
5759

0 commit comments

Comments
 (0)