We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent fdc956b commit 6c888dfCopy full SHA for 6c888df
1 file changed
algoperf/workloads/ogbg/metrics.py
@@ -31,6 +31,9 @@ class MeanAveragePrecision(
31
metrics.CollectingMetric.from_outputs(('logits', 'labels', 'mask'))):
32
"""Computes the mean average precision (mAP) over different tasks."""
33
34
+ def sigmoid_np(x):
35
+ return 1 / (1 + np.exp(-x))
36
+
37
def compute(self):
38
# Matches the official OGB evaluation scheme for mean average precision.
39
values = super().compute()
@@ -49,7 +52,7 @@ def compute(self):
49
52
dist.all_gather(all_tensors, tensor)
50
53
all_values[idx] = torch.cat(all_tensors).cpu().numpy()
51
54
labels, logits, mask = all_values
- sigmoid = lambda x: 1 / (1 + np.exp(-x))
55
+ sigmoid = sigmoid_np
56
57
mask = mask.astype(bool)
58
0 commit comments