File tree Expand file tree Collapse file tree
Expand file tree Collapse file tree Original file line number Diff line number Diff line change @@ -105,9 +105,11 @@ def model_fn(
105105 rng : spec .RandomState ,
106106 update_batch_norm : bool ,
107107 use_running_average_bn : Optional [bool ] = None ,
108+ dropout_rate : float = 0.0 ,
108109 ) -> Tuple [spec .Tensor , spec .ModelAuxiliaryState ]:
109110 del mode
110111 del rng
112+ del dropout_rate
111113 variables = {'params' : params , ** model_state }
112114 if update_batch_norm :
113115 logits , new_model_state = self ._model .apply (
Original file line number Diff line number Diff line change @@ -48,10 +48,12 @@ def model_fn(
4848 mode : spec .ForwardPassMode ,
4949 rng : spec .RandomState ,
5050 update_batch_norm : bool ,
51+ dropout_rate : float = 0.0 ,
5152 ) -> Tuple [spec .Tensor , spec .ModelAuxiliaryState ]:
5253 del model_state
5354 del rng
5455 del update_batch_norm
56+ del dropout_rate
5557 train = mode == spec .ForwardPassMode .TRAIN
5658 logits_batch = self ._model .apply (
5759 {'params' : params },
You can’t perform that action at this time.
0 commit comments