Skip to content

Commit f9fbbab

Browse files
Merge pull request mlcommons#875 from mlcommons/dropout_support
Dropout support
2 parents 8723937 + a151382 commit f9fbbab

41 files changed

Lines changed: 1158 additions & 807 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

algoperf/jax_utils.py

Lines changed: 130 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,130 @@
1+
from collections.abc import Sequence
2+
3+
import flax.linen as nn
4+
from flax.linen.module import compact
5+
from flax.linen.module import merge_param
6+
from flax.linen.module import Module
7+
from flax.typing import PRNGKey
8+
import jax
9+
from jax import lax
10+
from jax import random
11+
import jax.numpy as jnp
12+
13+
14+
# Custom Layers
15+
class Dropout(Module):
  # pylint: disable=line-too-long
  """Dropout layer whose rate can be overridden on every call.

  Forked from
  https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout.
  The reference dropout implementation is modified to support changes
  to the dropout rate during training by:
  1) adding a ``rate`` argument to the ``__call__`` method.
  2) removing the if-else condition to check for edge cases, which
     will trigger a recompile for jitted code. The edge-case checks remain
     available behind the ``legacy`` attribute.

  .. note::
    When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure
    to include an RNG seed named ``'dropout'``. Dropout isn't necessary for
    variable initialization.

  Example usage::

    >>> import flax.linen as nn
    >>> import jax, jax.numpy as jnp

    >>> class MLP(nn.Module):
    ...   @nn.compact
    ...   def __call__(self, x, train, dropout_rate=0.5):
    ...     x = nn.Dense(4)(x)
    ...     x = Dropout(deterministic=not train)(x, rate=dropout_rate)
    ...     return x

    >>> model = MLP()
    >>> x = jnp.ones((1, 3))
    >>> variables = model.init(jax.random.key(0), x, train=False)  # don't use dropout
    >>> y = model.apply(variables, x, train=False)  # don't use dropout
    >>> y = model.apply(variables, x, train=True, dropout_rate=0.1,
    ...                 rngs={'dropout': jax.random.key(1)})  # use dropout

  Attributes:
    rate: the default dropout probability. (_not_ the keep rate!) Used when
      no ``rate`` is passed to ``__call__``.
    broadcast_dims: dimensions that will share the same dropout mask
    deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
      and masked, whereas if true, no mask is applied and the inputs are
      returned as is.
    rng_collection: the rng collection name to use when requesting an rng
      key.
    legacy: if true, re-enable the reference implementation's Python-level
      short-circuits for ``rate == 0.0`` (inputs are returned unchanged) and
      ``rate == 1.0`` (zeros are returned). As noted above, these checks
      trigger a recompile for jitted code.
  """

  rate: float | None = None
  broadcast_dims: Sequence[int] = ()
  deterministic: bool | None = None
  rng_collection: str = "dropout"
  legacy: bool = False

  @compact
  def __call__(
      self,
      inputs,
      deterministic: bool | None = None,
      rate: float | None = None,
      rng: PRNGKey | None = None,
  ):
    """Applies a random dropout mask to the input.

    Args:
      inputs: the inputs that should be randomly masked.
      deterministic: if false the inputs are scaled by ``1 / (1 - rate)``
        and masked, whereas if true, no mask is applied and the inputs are
        returned as is.
      rate: the dropout probability. (_not_ the keep rate!) Overrides the
        ``rate`` attribute when given.
      rng: an optional PRNGKey used as the random key, if not specified,
        one will be generated using ``make_rng`` with the
        ``rng_collection`` name.

    Returns:
      The masked inputs reweighted to preserve mean.

    Raises:
      ValueError: if dropout has to be applied but no rate was supplied,
        neither as the module attribute nor as the call argument.
    """
    deterministic = merge_param("deterministic",
                                self.deterministic,
                                deterministic)

    # Deterministic mode must win over every rate-based shortcut, as in the
    # reference Flax implementation: evaluating with rate == 1.0 returns the
    # inputs unchanged rather than zeros.
    if deterministic:
      return inputs

    # Fall back to self.rate when no rate is passed to __call__.
    if rate is None:
      rate = self.rate
    if rate is None:
      raise ValueError(
          "No dropout rate specified: set the `rate` attribute or pass "
          "`rate` to __call__.")

    if self.legacy:
      if rate == 0.0:
        return inputs

      # Prevent gradient NaNs in 1.0 edge-case.
      if rate == 1.0:
        return jnp.zeros_like(inputs)

    # NOTE(review): outside legacy mode a rate of exactly 1.0 makes keep_prob
    # zero, so the division below divides by zero. The forward output is still
    # masked to zeros, but gradients may be NaN (see the legacy guard above).
    keep_prob = 1.0 - rate
    if rng is None:
      rng = self.make_rng(self.rng_collection)

    # Dimensions listed in broadcast_dims get size 1 so that a single mask
    # value is shared along them once the mask is broadcast back.
    broadcast_shape = list(inputs.shape)
    for dim in self.broadcast_dims:
      broadcast_shape[dim] = 1
    mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape)
    mask = jnp.broadcast_to(mask, inputs.shape)
    return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs))
118+
119+
120+
# Utilities for debugging
121+
def print_jax_model_summary(model, fake_inputs):
  """Prints a tabulated summary of a flax module.

  Args:
    model: the flax module to summarize.
    fake_inputs: inputs forwarded to the tabulate function, together with
      ``train=False``.
  """
  # Plain-text rendering with a fixed console width.
  console_options = {
      "force_terminal": False,
      "force_jupyter": False,
      "width": 240,
  }
  summarize = nn.tabulate(
      model, jax.random.PRNGKey(0), console_kwargs=console_options)
  summary = summarize(fake_inputs, train=False)
  print(summary)

algoperf/pytorch_utils.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,10 @@
55
import jax
66
import tensorflow as tf
77
import torch
8+
from torch import nn
9+
from torch import Tensor
810
import torch.distributed as dist
11+
import torch.nn.functional as F
912

1013
from algoperf import spec
1114
from algoperf.profiler import Profiler
@@ -77,3 +80,41 @@ def update_batch_norm_fn(module: spec.ParameterContainer,
7780
module.momentum = 0.0
7881
elif hasattr(module, 'momentum_backup'):
7982
module.momentum = module.momentum_backup
83+
84+
85+
class CustomDropout(nn.Module):
  """Dropout layer that receives its drop probability at call time.

  Thin wrapper around ``torch.nn.functional.dropout``.
  """

  def __init__(self):
    super().__init__()
    # Marker read by SequentialWithDropout (via getattr) to decide whether
    # the dropout probability is passed to this module.
    self._supports_custom_dropout = True

  def forward(self, x: Tensor, p: float) -> Tensor:
    """Applies dropout with probability ``p``, active only in training mode.

    Args:
      x: the input tensor.
      p: the probability of zeroing an element.

    Returns:
      The result of ``F.dropout`` on ``x``.
    """
    is_training = self.training
    return F.dropout(x, p=p, training=is_training)
94+
95+
96+
class CustomDropout2d(nn.Module):
  """Channel-wise dropout layer that receives its probability at call time.

  Thin wrapper around ``torch.nn.functional.dropout2d``.
  """

  def __init__(self):
    super().__init__()
    # Marker read by SequentialWithDropout (via getattr) to decide whether
    # the dropout probability is passed to this module.
    self._supports_custom_dropout = True

  def forward(self, x: Tensor, p: float) -> Tensor:
    """Applies 2d dropout with probability ``p``, active only in training mode.

    Args:
      x: the input tensor.
      p: the dropout probability handed to ``F.dropout2d``.

    Returns:
      The result of ``F.dropout2d`` on ``x``.
    """
    is_training = self.training
    return F.dropout2d(x, p=p, training=is_training)
105+
106+
107+
class SequentialWithDropout(nn.Sequential):
  """``nn.Sequential`` variant that threads a dropout rate through its layers.

  Child modules carrying a truthy ``_supports_custom_dropout`` attribute are
  called as ``module(x, p)``; all other modules are called as ``module(x)``.
  """

  def __init__(self, *args, **kwargs):
    super().__init__(*args, **kwargs)
    # Lets an enclosing SequentialWithDropout pass the rate down to this one.
    self._supports_custom_dropout = True

  def forward(self, x: Tensor, p: float) -> Tensor:
    """Runs the child modules in order, passing ``p`` where it is supported.

    Args:
      x: the input tensor.
      p: the dropout probability forwarded to dropout-aware children.

    Returns:
      The output of the last child module.
    """
    for layer in self:
      takes_rate = getattr(layer, '_supports_custom_dropout', False)
      x = layer(x, p) if takes_rate else layer(x)
    return x

algoperf/spec.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -247,7 +247,8 @@ def init_model_fn(self,
247247
# ModelAuxiliaryState,
248248
# ForwardPassMode,
249249
# RandomState,
250-
# bool],
250+
# bool,
251+
# float],
251252
# Tensor]
252253
@abc.abstractmethod
253254
def model_fn(self,
@@ -256,7 +257,8 @@ def model_fn(self,
256257
model_state: ModelAuxiliaryState,
257258
mode: ForwardPassMode,
258259
rng: RandomState,
259-
update_batch_norm: bool) -> Tuple[Tensor, ModelAuxiliaryState]:
260+
update_batch_norm: bool,
261+
dropout_rate: float) -> Tuple[Tensor, ModelAuxiliaryState]:
260262
"""Return logits_batch"""
261263
# Possible side effect of updating BN.
262264

algoperf/workloads/cifar/cifar_jax/workload.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -79,14 +79,8 @@ def sync_batch_stats(
7979
new_model_state['batch_stats'] = avg_fn(model_state['batch_stats'])
8080
return new_model_state
8181

82-
def init_model_fn(
83-
self,
84-
rng: spec.RandomState,
85-
dropout_rate: Optional[float] = None,
86-
aux_dropout_rate: Optional[float] = None) -> spec.ModelInitState:
82+
def init_model_fn(self, rng: spec.RandomState) -> spec.ModelInitState:
8783
"""Dropout is unused."""
88-
del dropout_rate
89-
del aux_dropout_rate
9084
model_cls = getattr(models, 'ResNet18')
9185
model = model_cls(num_classes=self._num_classes, dtype=jnp.float32)
9286
self._model = model

algoperf/workloads/criteo1tb/criteo1tb_jax/models.py

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,14 @@
11
"""A JAX implementation of DLRM-Small."""
2-
32
from typing import Sequence
43

54
import flax.linen as nn
65
from jax import nn as jnn
76
import jax.numpy as jnp
87

8+
from algoperf.jax_utils import Dropout
9+
10+
DROPOUT_RATE = 0.0
11+
912

1013
class DLRMResNet(nn.Module):
1114
"""Define a DLRMResNet model.
@@ -23,12 +26,13 @@ class DLRMResNet(nn.Module):
2326
mlp_bottom_dims: Sequence[int] = (256, 256, 256)
2427
mlp_top_dims: Sequence[int] = (256, 256, 256, 256, 1)
2528
embed_dim: int = 128
26-
dropout_rate: float = 0.0
29+
dropout_rate: float = DROPOUT_RATE
2730
use_layer_norm: bool = False # Unused.
2831
embedding_init_multiplier: float = None # Unused
2932

3033
@nn.compact
31-
def __call__(self, x, train):
34+
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
35+
3236
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
3337
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
3438

@@ -88,8 +92,8 @@ def scaled_init(key, shape, dtype=jnp.float_):
8892
stddev=jnp.sqrt(1.0 / mlp_top_dims[layer_idx])))(
8993
top_mlp_input)
9094
x = nn.relu(x)
91-
if self.dropout_rate and layer_idx == num_layers_top - 2:
92-
x = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(x)
95+
if dropout_rate and layer_idx == num_layers_top - 2:
96+
x = Dropout(dropout_rate, deterministic=not train)(x, rate=dropout_rate)
9397
top_mlp_input += x
9498
# In the DLRM model the last layer width is always 1. We can hardcode that
9599
# below.
@@ -151,7 +155,8 @@ class DlrmSmall(nn.Module):
151155
embedding_init_multiplier: float = None
152156

153157
@nn.compact
154-
def __call__(self, x, train):
158+
def __call__(self, x, train, dropout_rate=DROPOUT_RATE):
159+
155160
bot_mlp_input, cat_features = jnp.split(x, [self.num_dense_features], 1)
156161
cat_features = jnp.asarray(cat_features, dtype=jnp.int32)
157162

@@ -210,10 +215,10 @@ def scaled_init(key, shape, dtype=jnp.float_):
210215
top_mlp_input = nn.relu(top_mlp_input)
211216
if self.use_layer_norm:
212217
top_mlp_input = nn.LayerNorm()(top_mlp_input)
213-
if (self.dropout_rate is not None and self.dropout_rate > 0.0 and
218+
if (dropout_rate is not None and dropout_rate > 0.0 and
214219
layer_idx == num_layers_top - 2):
215-
top_mlp_input = nn.Dropout(
216-
rate=self.dropout_rate, deterministic=not train)(
217-
top_mlp_input)
220+
top_mlp_input = Dropout(
221+
dropout_rate, deterministic=not train)(
222+
top_mlp_input, rate=dropout_rate)
218223
logits = top_mlp_input
219224
return logits

algoperf/workloads/criteo1tb/criteo1tb_jax/workload.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -72,36 +72,34 @@ def loss_fn(
7272
def init_model_fn(
7373
self,
7474
rng: spec.RandomState,
75-
dropout_rate: Optional[float] = None,
76-
aux_dropout_rate: Optional[float] = None,
7775
tabulate: Optional[bool] = False,
7876
) -> spec.ModelInitState:
7977
"""Only dropout is used."""
80-
del aux_dropout_rate
8178
if self.use_resnet:
8279
model_class = models.DLRMResNet
8380
else:
8481
model_class = models.DlrmSmall
82+
8583
self._model = model_class(
8684
vocab_size=self.vocab_size,
8785
num_dense_features=self.num_dense_features,
8886
mlp_bottom_dims=self.mlp_bottom_dims,
8987
mlp_top_dims=self.mlp_top_dims,
9088
embed_dim=self.embed_dim,
91-
dropout_rate=dropout_rate,
9289
use_layer_norm=self.use_layer_norm,
9390
embedding_init_multiplier=self.embedding_init_multiplier)
9491

95-
params_rng, dropout_rng = jax.random.split(rng)
92+
params_rng, _ = jax.random.split(rng)
9693
init_fake_batch_size = 2
9794
num_categorical_features = 26
9895
num_dense_features = 13
9996
input_size = num_dense_features + num_categorical_features
10097
input_shape = (init_fake_batch_size, input_size)
10198
init_fn = functools.partial(self._model.init, train=False)
102-
initial_variables = jax.jit(init_fn)(
103-
{'params': params_rng, 'dropout': dropout_rng},
104-
jnp.ones(input_shape, jnp.float32))
99+
initial_variables = jax.jit(init_fn)({
100+
'params': params_rng,
101+
},
102+
jnp.ones(input_shape, jnp.float32))
105103
initial_params = initial_variables['params']
106104
self._param_shapes = param_utils.jax_param_shapes(initial_params)
107105
self._param_types = param_utils.jax_param_types(self._param_shapes)
@@ -117,14 +115,17 @@ def model_fn(
117115
model_state: spec.ModelAuxiliaryState,
118116
mode: spec.ForwardPassMode,
119117
rng: spec.RandomState,
120-
update_batch_norm: bool) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
118+
update_batch_norm: bool,
119+
dropout_rate: float = models.DROPOUT_RATE
120+
) -> Tuple[spec.Tensor, spec.ModelAuxiliaryState]:
121121
del model_state
122122
del update_batch_norm
123123
inputs = augmented_and_preprocessed_input_batch['inputs']
124124
train = mode == spec.ForwardPassMode.TRAIN
125125
apply_kwargs = {'train': train}
126126
if train:
127127
apply_kwargs['rngs'] = {'dropout': rng}
128+
apply_kwargs['dropout_rate'] = dropout_rate
128129
logits_batch = self._model.apply({'params': params}, inputs, **apply_kwargs)
129130
return logits_batch, None
130131

0 commit comments

Comments
 (0)