File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -37,21 +37,23 @@ def compute(self):
3737 labels = values ['labels' ]
3838 logits = values ['logits' ]
3939 mask = values ['mask' ]
40+ sigmoid = jax .nn .sigmoid
4041
4142 if USE_PYTORCH_DDP :
4243 # Sync labels, logits, and masks across devices.
43- all_values = [np . array ( labels ), np . array ( logits ), np . array ( mask ) ]
44+ all_values = [labels , logits , mask ]
4445 for idx , array in enumerate (all_values ):
4546 tensor = torch .as_tensor (array , device = DEVICE )
4647 # Assumes that the tensors on all devices have the same shape.
4748 all_tensors = [torch .zeros_like (tensor ) for _ in range (N_GPUS )]
4849 dist .all_gather (all_tensors , tensor )
4950 all_values [idx ] = torch .cat (all_tensors ).cpu ().numpy ()
5051 labels , logits , mask = all_values
52+ sigmoid = lambda x : 1 / (1 + np .exp (- x ))
5153
5254 mask = mask .astype (bool )
5355
54- probs = 1 / ( 1 + np . exp ( - logits ) )
56+ probs = sigmoid ( logits )
5557 num_tasks = labels .shape [1 ]
5658 average_precisions = np .full (num_tasks , np .nan )
5759
You can’t perform that action at this time.
0 commit comments