ML Katas

Implementing Dropout

medium (<30 mins) dropout regularization jax deep-learning
yesterday by E

Implement the dropout regularization technique in JAX. This involves randomly setting a fraction of input units to 0 at each update during training time. Remember that dropout should only be active during training. A key challenge in JAX is managing the random keys, since JAX's random functions are pure and require an explicit key for every call.

Task:
- Write a dropout function that takes an input, a dropout rate, a random key, and a boolean indicating whether it's training time.
- During training, it should randomly zero out elements of the input.
- During inference, it should return the input as is.

Verification:
- During training, some elements of the output should be zero.
- During inference, the output should be identical to the input.
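A minimal sketch of one possible solution, using `jax.random.bernoulli` to draw the keep mask. It assumes the common "inverted dropout" convention, where surviving activations are scaled by 1/(1 - rate) so the expected value is unchanged; the task statement does not require this, so treat the scaling as an optional design choice.

```python
import jax
import jax.numpy as jnp


def dropout(x, rate, key, is_training):
    """Zero out elements of x with probability `rate` during training.

    Uses inverted dropout (an assumed convention): kept elements are
    scaled by 1 / (1 - rate) so no rescaling is needed at inference.
    """
    if not is_training or rate == 0.0:
        # Inference (or no dropout): return the input untouched.
        return x
    keep_prob = 1.0 - rate
    # Bernoulli mask: True where the element is kept.
    mask = jax.random.bernoulli(key, p=keep_prob, shape=x.shape)
    return jnp.where(mask, x / keep_prob, 0.0)


key = jax.random.PRNGKey(0)
x = jnp.ones((8, 8))

train_out = dropout(x, 0.5, key, is_training=True)
infer_out = dropout(x, 0.5, key, is_training=False)
```

Note that the caller is responsible for supplying a fresh key (e.g. via `jax.random.split`) at each update step; reusing the same key would produce the same mask every time.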