Backward Pass
ML frameworks implement automatic differentiation for us, so we usually do not need to implement the backward pass ourselves. Nonetheless, it is important to understand how gradients are propagated in order to reason about the memory usage and computational overhead of backpropagation.
What are the Forward and the Backward Passes?
The Forward pass is the "main formula" of the model. It is the primary computation of the model, transforming inputs to outputs. It is executed during both inference (to get predictions) and training (to compute the loss).
The Backward Pass (Backpropagation) is the chain rule of calculus applied in reverse. It computes the gradient of the loss function with respect to every parameter in the model. These gradients indicate the direction and magnitude to adjust each parameter to minimize error.
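As a minimal illustration of "the chain rule applied in reverse": for a composition of functions, the gradient with respect to the input is the product of the local derivatives, accumulated from the output backward:
\[L = f(g(h(x))) \quad\Rightarrow\quad \frac{\partial L}{\partial x} = \frac{\partial f}{\partial g}\cdot\frac{\partial g}{\partial h}\cdot\frac{\partial h}{\partial x}\]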
The forward pass produces predictions. During training, these predictions are compared to a ground truth value to calculate a loss. The backward pass produces gradients, or directions for each parameter in the model. The gradients are consumed by the optimizer to update the weights used in both the forward and the backward pass.
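As a concrete (and deliberately simplified) example, a plain stochastic gradient descent optimizer would consume the gradients as
\[\theta \leftarrow \theta - \eta\,\frac{\partial L}{\partial \theta}\]
for every parameter \(\theta\), where \(\eta\) is the learning rate; optimizers like Adam keep extra state but follow the same pattern.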
How is it computed?
Most functions in ML are differentiable (or have defined subgradients at points like x=0 for ReLU). Therefore, when ML framework developers implement a function, they implement both its forward and its backward method.
A basic example using jax.grad: the derivative of the sine function sin is simply the cosine cos. When we apply jax.grad to jnp.sin, JAX internally calls the backward rule registered for sin (the exact internal name differs), which evaluates jnp.cos.
import jax
import jax.numpy as jnp
grad_sin = jax.grad(jnp.sin)
grad_sin(0.2) == jnp.cos(0.2)
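To make the paired forward/backward methods more concrete, here is a small sketch using jax.custom_vjp to register our own differentiation rule for a sine wrapper. This is not how jnp.sin is implemented internally; it just mirrors the forward/backward structure described above:
import jax
import jax.numpy as jnp

@jax.custom_vjp
def my_sin(x):
    return jnp.sin(x)

def my_sin_fwd(x):
    # Forward: return the output plus the residuals needed by the backward pass.
    return jnp.sin(x), x

def my_sin_bwd(x, g):
    # Backward: multiply the incoming gradient by the local derivative cos(x).
    return (g * jnp.cos(x),)

my_sin.defvjp(my_sin_fwd, my_sin_bwd)

jax.grad(my_sin)(0.2)  # matches jnp.cos(0.2)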
Let's take a look at a more complex example:
import jax
import jax.numpy as jnp
key0, key1, key2 = jax.random.split(jax.random.key(0), 3)
b, d, f = 16, 64, 32
x = jax.random.normal(key0, (b, d))
w_0 = jax.random.normal(key1, (d, f))
w_1 = jax.random.normal(key2, (f, d))
def mlp(args):
    x, w_0, w_1 = args
    z = x @ w_0
    z_relu = jax.nn.relu(z)
    out = z_relu @ w_1
    return 0.5 * jnp.sum(out ** 2)
grad_mlp = jax.grad(mlp)
# by default, jax.grad differentiates only with respect to the first argument,
# so we pass all our arguments as a single tuple
grad_mlp((x, w_0, w_1))
This is a classic two-layer MLP with a ReLU activation in between: \[\text{ReLU}(x W_0)\, W_1\] The 0.5 * jnp.sum(out ** 2) term on top is just a simple dummy loss so that we have a scalar to differentiate.
- The Forward Pass simply executes the code we wrote.
- The Backward Pass takes the output of the Forward Pass and executes the backward methods in reverse order, propagating gradients backward.
- Some derivatives require the original activations from the Forward Pass, so we need to store them individually during the forward call.
Walking through the backward pass
1. 0.5 * jnp.sum(out ** 2): the derivative is simply out.
2. out = z_relu @ w_1: here we need to compute the gradients of w_1, which will be used by the optimizer to update w_1, and the gradients of z_relu, which will be backpropagated.
   - dL/dW1 = z_relu.T @ grads ((b, f).T @ (b, d) -> (f, d))
   - dL/dZ_relu = grads @ w_1.T ((b, d) @ (f, d).T -> (b, f))
3. z_relu = jax.nn.relu(z): ReLU is defined as relu(x) = max(0, x), so its derivative is d_relu(x) = 0 if x <= 0 else 1. We then multiply this derivative with the incoming gradients.
   - Performance Note: storing values in HBM (High Bandwidth Memory) is expensive. For element-wise operations like ReLU, it is often faster to recompute the activation during the backward pass from the cached input (z) rather than storing the output (z_relu) and reading it back. This is known as activation recomputation or rematerialization (see the jax.checkpoint sketch after this list).
4. z = x @ w_0: just like for layer 1:
   - dL/dW0 = x.T @ grads_z ((b, d).T @ (b, f) -> (d, f))
   - dL/dx = grads_z @ w_0.T ((b, f) @ (d, f).T -> (b, d))
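The rematerialization mentioned in the performance note above is something JAX can do for us: jax.checkpoint (also available as jax.remat) marks a function so that its intermediates are recomputed during the backward pass instead of being stored. A minimal sketch, reusing the mlp, x, w_0 and w_1 defined earlier:
# Recompute mlp's intermediate activations in the backward pass instead of caching them.
mlp_remat = jax.checkpoint(mlp)  # jax.remat is an alias

grads_remat = jax.grad(mlp_remat)((x, w_0, w_1))
In practice this trades extra FLOPs for lower activation memory, and jax.checkpoint accepts policies to control which intermediates are saved versus recomputed.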
First let's rewrite the MLP implementation to cache the intermediate activations:
def mlp_activations(x, w_0, w_1):
    activations = [x]
    z = x @ w_0
    activations.append(z)
    z_relu = jax.nn.relu(z)
    activations.append(z_relu)
    out = z_relu @ w_1
    activations.append(out)
    return 0.5 * jnp.sum(out ** 2), activations
Now let's implement the backward pass:
def manual_mlp_grad(x, w_0, w_1):
    # Forward
    _, activations = mlp_activations(x, w_0, w_1)
    # Pop out, shape (b, d)
    out = activations.pop()
    # 1. Derivative dL/dOut = out
    grads = out
    # 2. Derivative of Layer 1
    z_relu = activations.pop()
    # dL/dW1
    grads_w_1 = z_relu.T @ grads
    # dL/dZ_relu
    grads_z_relu = grads @ w_1.T
    # 3. Derivative of ReLU
    z = activations.pop()
    grads_z = jnp.where(z > 0, 1, 0) * grads_z_relu
    # 4. Derivative of Layer 0
    x = activations.pop()
    # dL/dW0
    grads_w_0 = x.T @ grads_z
    # dL/dx
    grads_x = grads_z @ w_0.T
    return grads_x, grads_w_0, grads_w_1
Correctness check:
import numpy as np
manual_out = manual_mlp_grad(x, w_0, w_1)
jax_out = grad_mlp((x, w_0, w_1))
for manual, autograd in zip(manual_out, jax_out):
    np.testing.assert_allclose(manual, autograd)
Performance Implications
FLOPs
As we have seen, each matrix multiplication in the forward pass requires two matrix multiplications in the backward pass: one for the gradients of the weights and one for the gradients of the inputs. Hence, the number of FLOPs in the backward pass can be approximated as roughly twice the FLOPs of the forward pass.
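As a back-of-the-envelope check with the shapes used above (counting only matrix-multiplication FLOPs, where an (m, k) @ (k, n) matmul costs roughly 2*m*k*n FLOPs):
b, d, f = 16, 64, 32

# Forward: two matmuls, x @ w_0 and z_relu @ w_1.
fwd_flops = 2 * b * d * f + 2 * b * f * d

# Backward: each forward matmul produces two matmuls
# (gradients w.r.t. its weights and w.r.t. its inputs).
bwd_flops = 2 * (2 * b * d * f) + 2 * (2 * b * f * d)

print(bwd_flops / fwd_flops)  # 2.0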
Memory Usage
Since we need to store the intermediate activations during the forward pass, our model requires much more memory during training than during inference. A common rule of thumb is that training memory is ~3-4x inference memory for the same batch size, primarily due to these cached activations. Furthermore, constantly writing and reading previous activations saturates memory bandwidth, which slows down the prefetching of other parameters.
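For intuition, here is a rough estimate of the activations cached by the toy mlp_activations above, assuming float32 (4 bytes per element); during inference these intermediates could be discarded as soon as the next layer has consumed them:
b, d, f = 16, 64, 32
bytes_per_elem = 4  # float32

# Cached activations: x (b, d), z (b, f), z_relu (b, f), out (b, d)
activation_bytes = bytes_per_elem * (b * d + b * f + b * f + b * d)
print(activation_bytes)  # 12288 bytes here; this grows with batch size and depth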
It is crucial to study which activations are being cached and to actively look for recomputation opportunities when appropriate, either to free up memory or to speed up the step time.