-
NumPy to JAX: The Basics
This first exercise is a straightforward warm-up that familiarizes you with the JAX NumPy API. 1. Create a function that takes two `jnp.ndarray`s, `W` and `x`, as well as a `jnp.ndarray` `b`,...
-
The JAX approach to PRNG
JAX handles pseudo-random number generation (PRNG) differently from NumPy. NumPy relies on a global state, whereas JAX makes the state of the PRNG explicit. This is a design choice that...
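As a minimal sketch of what explicit PRNG state looks like in practice, the snippet below creates a key, draws from it, and splits it to obtain a fresh key. The seed `0`, the shape `(3,)`, and the variable names are illustrative assumptions, not taken from the exercise itself.

```python
import jax
import jax.numpy as jnp

# The PRNG state is an explicit value (a "key") built from a seed.
# Seed 0 is an arbitrary choice for illustration.
key = jax.random.PRNGKey(0)

# Every sampling function takes the key as an argument.
# Passing the same key twice yields the same numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# To get different numbers, derive new keys by splitting the old one.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print(jnp.array_equal(a, b))  # same key, same draw
print(jnp.array_equal(a, c))  # different key, different draw
```

Because nothing is hidden in a global state, a function that needs randomness has to receive a key from its caller, which keeps results reproducible under transformations such as `jax.jit` and `jax.vmap`.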