## Implement a Simple Linear Regression in JAX

Your task is to implement a simple linear regression model from scratch using JAX. You'll need to define the model, a loss function (like Mean Squared Error), and a gradient descent update rule...
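The full statement is truncated above, so the following is a minimal sketch of one way to solve it: a scalar model $y = wx + b$, an MSE loss, and a hand-rolled gradient descent loop using `jax.grad`. The synthetic data (slope 3, intercept 2), learning rate, and step count are illustrative choices, not part of the original task.

```python
import jax
import jax.numpy as jnp

def model(params, x):
    w, b = params
    return w * x + b

def mse_loss(params, x, y):
    preds = model(params, x)
    return jnp.mean((preds - y) ** 2)

@jax.jit
def update(params, x, y, lr=0.1):
    # One gradient descent step on the MSE loss
    grads = jax.grad(mse_loss)(params, x, y)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

# Illustrative synthetic data: y = 3x + 2 plus a little noise
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(k1, (100,))
y = 3.0 * x + 2.0 + 0.1 * jax.random.normal(k2, (100,))

params = (jnp.array(0.0), jnp.array(0.0))  # (w, b)
for _ in range(200):
    params = update(params, x, y)
print(params)  # should end up close to (3.0, 2.0)
```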
## Build a Custom ReLU Activation Function

Implement the Rectified Linear Unit (ReLU) activation function in JAX. Then, use `jax.grad` to find its derivative. The ReLU function is defined as:

$\mathrm{ReLU}(x) = \max(0, x)$

**Verification:**
- For...
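A sketch of the core of this exercise: ReLU via `jnp.maximum`, differentiated with `jax.grad`. The checks below follow directly from the definition (derivative 1 for positive inputs, 0 for negative ones); the value exactly at $x = 0$ depends on JAX's subgradient convention for `maximum` and is left aside here.

```python
import jax
import jax.numpy as jnp

def relu(x):
    return jnp.maximum(0.0, x)

# jax.grad differentiates scalar-valued functions; vmap it for batches
relu_grad = jax.grad(relu)

print(relu(-2.0), relu(3.0))  # 0.0 3.0
print(relu_grad(3.0))         # 1.0: positive inputs pass the gradient through
print(relu_grad(-2.0))        # 0.0: negative inputs block it
print(jax.vmap(relu_grad)(jnp.array([-1.0, 0.5, 2.0])))  # [0. 1. 1.]
```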
## Custom Gradient with `jax.custom_vjp`

Implement a function with a custom gradient using `jax.custom_vjp`. This is useful for numerical stability or for defining gradients for non-differentiable operations. A good example is a function...
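The description is cut off, but a classic instance of this pattern (close to the example in the JAX documentation) is `log1pexp`. Its autodiff gradient hits an `inf * 0 = nan` for large inputs, while a hand-written VJP using the sigmoid stays finite. A sketch:

```python
import jax
import jax.numpy as jnp

@jax.custom_vjp
def log1pexp(x):
    return jnp.log(1.0 + jnp.exp(x))

def log1pexp_fwd(x):
    # Save x as the residual for the backward pass
    return log1pexp(x), x

def log1pexp_bwd(x, g):
    # Hand-written derivative: sigmoid(x), which avoids the
    # inf * 0 = nan the autodiff gradient produces for large x
    return (g / (1.0 + jnp.exp(-x)),)

log1pexp.defvjp(log1pexp_fwd, log1pexp_bwd)

print(jax.grad(log1pexp)(100.0))                              # 1.0 with the custom VJP
print(jax.grad(lambda x: jnp.log(1.0 + jnp.exp(x)))(100.0))   # nan without it
```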
## Physics-Informed Neural Network (PINN) for an ODE

### Description

Solve a simple Ordinary Differential Equation (ODE) using a Physics-Informed Neural Network. A PINN is a neural network that is trained to satisfy both the data and the underlying...
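Since the statement is truncated, the ODE below is an assumed example: $y'(x) = -y(x)$ with $y(0) = 1$, whose exact solution is $e^{-x}$. The network size, collocation points, and plain gradient descent training are likewise illustrative choices. The loss combines the ODE residual (via `jax.grad` through the network) with the initial condition:

```python
import jax
import jax.numpy as jnp

def init_mlp(key, sizes):
    params = []
    for k, (m, n) in zip(jax.random.split(key, len(sizes) - 1),
                         zip(sizes[:-1], sizes[1:])):
        params.append((jax.random.normal(k, (m, n)) / jnp.sqrt(m), jnp.zeros(n)))
    return params

def mlp(params, x):
    h = jnp.array([x])
    for w, b in params[:-1]:
        h = jnp.tanh(h @ w + b)
    w, b = params[-1]
    return (h @ w + b)[0]

def residual(params, x):
    # Physics term: enforce y'(x) + y(x) = 0 at a collocation point
    y = mlp(params, x)
    dy = jax.grad(mlp, argnums=1)(params, x)
    return dy + y

def loss(params, xs):
    phys = jnp.mean(jax.vmap(lambda x: residual(params, x))(xs) ** 2)
    ic = (mlp(params, 0.0) - 1.0) ** 2  # initial condition y(0) = 1
    return phys + ic

@jax.jit
def step(params, xs, lr=1e-2):
    grads = jax.grad(loss)(params, xs)
    return jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)

params = init_mlp(jax.random.PRNGKey(0), [1, 16, 16, 1])
xs = jnp.linspace(0.0, 2.0, 50)  # collocation points
for _ in range(5000):
    params = step(params, xs)
print(mlp(params, 1.0), jnp.exp(-1.0))  # should roughly agree
```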
## Gradient Reversal Layer

### Description

Implement a Gradient Reversal Layer (GRL), a key component in Domain-Adversarial Neural Networks (DANNs). [1] The GRL acts as an identity function during the forward pass but...
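One way to express a GRL in JAX is with `jax.custom_vjp`: identity in the forward pass, gradient multiplied by $-\lambda$ in the backward pass. The scaling factor $\lambda$ appears in the usual DANN formulation; treating it here as a static, non-differentiated argument is an implementation choice, not something the original task prescribes.

```python
from functools import partial
import jax
import jax.numpy as jnp

@partial(jax.custom_vjp, nondiff_argnums=(1,))
def grad_reverse(x, lam):
    return x  # identity in the forward pass

def grad_reverse_fwd(x, lam):
    return x, None  # no residuals are needed

def grad_reverse_bwd(lam, res, g):
    return (-lam * g,)  # flip (and scale) the incoming gradient

grad_reverse.defvjp(grad_reverse_fwd, grad_reverse_bwd)

# Forward pass is the identity, backward pass is reversed:
f = lambda x: jnp.sum(grad_reverse(x, 1.0) ** 2)
x = jnp.array([1.0, 2.0])
print(grad_reverse(x, 1.0))  # [1. 2.]
print(jax.grad(f)(x))        # [-2. -4.] instead of [2. 4.]
```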
## Simple Differentiable Renderer

### Description

Modern 3D deep learning often relies on differentiable rendering, allowing gradients to flow from a 2D rendered image back to 3D scene parameters. [1] Your task is to implement a...
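With the task statement truncated, here is an assumed toy setup that captures the idea in 2D: rasterize a single soft-edged disk with a sigmoid falloff (so the image is differentiable with respect to the scene parameters), then recover the disk's center and radius by gradient descent on an image-space MSE. The grid size, learning rate, and iteration count are ad hoc.

```python
import jax
import jax.numpy as jnp

H = W = 32
ys, xs = jnp.meshgrid(jnp.arange(H, dtype=jnp.float32),
                      jnp.arange(W, dtype=jnp.float32), indexing="ij")

def render(params):
    # Soft-edged disk: a sigmoid falloff at the boundary keeps the
    # rendered image differentiable w.r.t. center and radius
    cx, cy, r = params
    dist = jnp.sqrt((xs - cx) ** 2 + (ys - cy) ** 2 + 1e-6)  # eps avoids nan grad at 0
    return jax.nn.sigmoid(r - dist)

target = render(jnp.array([16.0, 16.0, 8.0]))  # "ground truth" image

def loss(params):
    return jnp.mean((render(params) - target) ** 2)

@jax.jit
def step(params, lr=50.0):  # ad hoc learning rate for this toy problem
    return params - lr * jax.grad(loss)(params)

params = jnp.array([10.0, 20.0, 5.0])  # deliberately wrong initial guess
for _ in range(1000):
    params = step(params)
print(params)  # should move toward [16, 16, 8]
```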