|
| 1 | +from collections.abc import Sequence |
| 2 | + |
| 3 | +import flax.linen as nn |
| 4 | +from flax.linen.module import compact |
| 5 | +from flax.linen.module import merge_param |
| 6 | +from flax.linen.module import Module |
| 7 | +from flax.typing import PRNGKey |
| 8 | +import jax |
| 9 | +from jax import lax |
| 10 | +from jax import random |
| 11 | +import jax.numpy as jnp |
| 12 | + |
| 13 | + |
| 14 | +# Custom Layers |
| 15 | +class Dropout(Module): |
| 16 | + # pylint: disable=line-too-long |
| 17 | + """Create a dropout layer. |
| 18 | + Forked from |
| 19 | + https://flax-linen.readthedocs.io/en/latest/_modules/flax/linen/stochastic.html#Dropout. |
| 20 | + The reference dropout implementation is modified support changes |
| 21 | + to dropout rate during training by: |
| 22 | + 1) adding rate argument to the __call__ method. |
| 23 | + 2) removing the if-else condition to check for edge cases, which |
| 24 | + will trigger a recompile for jitted code. |
| 25 | +
|
| 26 | + .. note:: |
| 27 | + When using :meth:`Module.apply() <flax.linen.Module.apply>`, make sure |
| 28 | + to include an RNG seed named ``'dropout'``. Dropout isn't necessary for |
| 29 | + variable initialization. |
| 30 | +
|
| 31 | + Example usage:: |
| 32 | +
|
| 33 | + >>> import flax.linen as nn |
| 34 | + >>> import jax, jax.numpy as jnp |
| 35 | +
|
| 36 | + >>> class MLP(nn.Module): |
| 37 | + ... @nn.compact |
| 38 | + ... def __call__(self, x, train): |
| 39 | + ... x = nn.Dense(4)(x) |
| 40 | + ... x = nn.Dropout(0.5, deterministic=not train)(x) |
| 41 | + ... return x |
| 42 | +
|
| 43 | + >>> model = MLP() |
| 44 | + >>> x = jnp.ones((1, 3)) |
| 45 | + >>> variables = model.init(jax.random.key(0), x, train=False) # don't use dropout |
| 46 | + >>> model.apply(variables, x, train=False) # don't use dropout |
| 47 | + Array([[-0.17875527, 1.6255447 , -1.2431065 , -0.02554005]], dtype=float32) |
| 48 | + >>> model.apply(variables, x, train=True, rngs={'dropout': jax.random.key(1)}) # use dropout |
| 49 | + Array([[-0.35751054, 3.2510893 , 0. , 0. ]], dtype=float32) |
| 50 | +
|
| 51 | + Attributes: |
| 52 | + rate: the dropout probability. (_not_ the keep rate!) |
| 53 | + broadcast_dims: dimensions that will share the same dropout mask |
| 54 | + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` |
| 55 | + and masked, whereas if true, no mask is applied and the inputs are |
| 56 | + returned as is. |
| 57 | + rng_collection: the rng collection name to use when requesting an rng |
| 58 | + key. |
| 59 | + """ |
| 60 | + |
| 61 | + rate: float | None = None |
| 62 | + broadcast_dims: Sequence[int] = () |
| 63 | + deterministic: bool | None = None |
| 64 | + rng_collection: str = "dropout" |
| 65 | + legacy: bool = False |
| 66 | + |
| 67 | + @compact |
| 68 | + def __call__( |
| 69 | + self, |
| 70 | + inputs, |
| 71 | + deterministic: bool | None = None, |
| 72 | + rate: float | None = None, |
| 73 | + rng: PRNGKey | None = None, |
| 74 | + ): |
| 75 | + """Applies a random dropout mask to the input. |
| 76 | +
|
| 77 | + Args: |
| 78 | + inputs: the inputs that should be randomly masked. |
| 79 | + deterministic: if false the inputs are scaled by ``1 / (1 - rate)`` |
| 80 | + and masked, whereas if true, no mask is applied and the inputs are |
| 81 | + returned as is. |
| 82 | + rate: the dropout probability. (_not_ the keep rate!) |
| 83 | + rng: an optional PRNGKey used as the random key, if not specified, |
| 84 | + one will be generated using ``make_rng`` with the |
| 85 | + ``rng_collection`` name. |
| 86 | +
|
| 87 | + Returns: |
| 88 | + The masked inputs reweighted to preserve mean. |
| 89 | + """ |
| 90 | + deterministic = merge_param("deterministic", |
| 91 | + self.deterministic, |
| 92 | + deterministic) |
| 93 | + |
| 94 | + # Override self.rate if rate is passed to __call__ |
| 95 | + if rate is None: |
| 96 | + rate = self.rate |
| 97 | + |
| 98 | + if self.legacy: |
| 99 | + if rate == 0.0: |
| 100 | + return inputs |
| 101 | + |
| 102 | + # Prevent gradient NaNs in 1.0 edge-case. |
| 103 | + if rate == 1.0: |
| 104 | + return jnp.zeros_like(inputs) |
| 105 | + |
| 106 | + if deterministic: |
| 107 | + return inputs |
| 108 | + |
| 109 | + keep_prob = 1.0 - rate |
| 110 | + if rng is None: |
| 111 | + rng = self.make_rng(self.rng_collection) |
| 112 | + broadcast_shape = list(inputs.shape) |
| 113 | + for dim in self.broadcast_dims: |
| 114 | + broadcast_shape[dim] = 1 |
| 115 | + mask = random.bernoulli(rng, p=keep_prob, shape=broadcast_shape) |
| 116 | + mask = jnp.broadcast_to(mask, inputs.shape) |
| 117 | + return lax.select(mask, inputs / keep_prob, jnp.zeros_like(inputs)) |
| 118 | + |
| 119 | + |
| 120 | +# Utilities for debugging |
| 121 | +def print_jax_model_summary(model, fake_inputs): |
| 122 | + """Prints a summary of the jax module.""" |
| 123 | + tabulate_fn = nn.tabulate( |
| 124 | + model, |
| 125 | + jax.random.PRNGKey(0), |
| 126 | + console_kwargs={ |
| 127 | + "force_terminal": False, "force_jupyter": False, "width": 240 |
| 128 | + }, |
| 129 | + ) |
| 130 | + print(tabulate_fn(fake_inputs, train=False)) |
0 commit comments