Sharding Strategies (intro)

Splitting computation and data across devices reduces the FLOPs and the memory required per device. There are many ways to make this split, and this chapter explores common sharding strategies.

Pseudo API

We will illustrate this chapter using a fake distributed API over NumPy. This API contains all the building blocks required to enable distributed computations. Conceptually, each code snippet runs on all devices in parallel, and each device is assigned a device_id to tell them apart.

The API revolves around an abstract class that we inherit from. We have to implement three methods ourselves: load_checkpoint, forward and backward.

The API provides the pre-implemented properties device_id and num_devices, and the pre-implemented methods barrier, send, receive, send_async, all_gather, all_reduce and all_to_all.

We also have an inference_loop method that simply loops over a stream of inputs and writes each output to an output stream.

We do not provide the loss function and simply assume it is separately provided by the optimizer API.

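from abc import ABC, abstractmethod
from concurrent.futures import Future
from typing import Iterable, Protocol, TypeVar

import numpy as np
import numpy.typing as npt

T_contra = TypeVar("T_contra", contravariant=True)

# Minimal stream types for inference_loop. Like the rest of this fake API,
# they are stand-ins rather than part of a real library.
Reader = Iterable  # any iterable that yields inputs

class Writer(Protocol[T_contra]):
    def write(self, item: T_contra) -> None:
        ...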

class ShardedEngine(ABC):
    
    @property
    def device_id(self) -> int:
        """The rank of the current device."""
        ...

    @property
    def num_devices(self) -> int:
        """Total number of devices in the cluster (World Size)."""
        ...

    def barrier(self) -> None:
        """Blocks until all devices reach this line."""
        pass
    
    # --- Point-to-Point Communication ---
    def send(self, dest_id: int, arr: npt.ArrayLike) -> None:
        ...
    
    def receive(self, src_id: int) -> npt.ArrayLike:
        ...

    def send_async(self, src: npt.ArrayLike, dst: npt.ArrayLike, target_device_id: int) -> Future:
        ...

    # --- Collective Communication ---
    def all_gather(self, arr: npt.ArrayLike, axis: int = 0) -> npt.ArrayLike:
        """Concatenates arrays from all devices along the specified axis."""
        ...

    def all_reduce(self, arr: npt.ArrayLike, op: str = 'sum') -> npt.ArrayLike:
        """Reduces arrays from all devices (e.g., sum) and broadcasts the result."""
        ...

    def all_to_all(self, arr: npt.ArrayLike, axis: int = 0) -> npt.ArrayLike:
        """Splits the array into num_devices chunks along the axis, sends chunk i
        to device i, and concatenates the chunks received from all devices."""
        ...

    # --- Model Lifecycle ---
    @abstractmethod
    def load_checkpoint(self, params: dict[str, npt.ArrayLike]) -> None:
        ...

    @abstractmethod
    def forward(self, x: npt.ArrayLike) -> npt.ArrayLike:
        ...

    @abstractmethod
    def backward(self, grads: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
        ...
    
    def inference_loop(self, input_stream: Reader[npt.ArrayLike], output_stream: Writer[npt.ArrayLike]) -> None:
        for x in input_stream:
            output_stream.write(self.forward(x))
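
To make the collective operations concrete, here is a minimal single-process sketch that simulates them on a plain list holding one array per device. The helper names (simulate_all_gather, simulate_all_reduce, simulate_all_to_all) are hypothetical and not part of the API above. The all_to_all behavior shown, splitting and concatenating along the same axis, is an assumption about the intended semantics.

```python
import numpy as np

def simulate_all_gather(shards, axis=0):
    """Every device ends up with the concatenation of all shards."""
    gathered = np.concatenate(shards, axis=axis)
    return [gathered.copy() for _ in shards]

def simulate_all_reduce(shards, op="sum"):
    """Every device ends up with the elementwise reduction of all shards."""
    if op != "sum":
        raise NotImplementedError(op)  # only 'sum' is sketched here
    total = np.sum(np.stack(shards), axis=0)
    return [total.copy() for _ in shards]

def simulate_all_to_all(shards, axis=0):
    """Device i splits its array into n chunks and sends chunk j to device j.
    Device j concatenates what it receives, ordered by sender (assumed)."""
    n = len(shards)
    # chunks[i][j] is the chunk travelling from device i to device j
    chunks = [np.split(s, n, axis=axis) for s in shards]
    return [
        np.concatenate([chunks[i][j] for i in range(n)], axis=axis)
        for j in range(n)
    ]

# Two simulated devices, each holding a (2, 2) array.
shards = [np.array([[0, 1], [2, 3]]), np.array([[10, 11], [12, 13]])]

gathered = simulate_all_gather(shards)   # each device: shape (4, 2)
reduced = simulate_all_reduce(shards)    # each device: [[10, 12], [14, 16]]
exchanged = simulate_all_to_all(shards)  # device 0: [[0, 1], [10, 11]]
```

Each helper returns a list with one entry per simulated device, so the result for device i is at index i. In a real system each device would only ever see its own entry.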

Unsharded Example

Let's start with an unsharded example on a single device. We will use a two-layer model with a ReLU activation in between:

\[\text{ReLU}(x W_0) W_1\]
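
The gradients computed in the backward method follow from the chain rule. As a sketch, using notation introduced here for illustration only (\(z\), \(h\), \(y\) for the intermediate values, \(g\) for the incoming gradient of a scalar loss \(L\), and \(\odot\) for the elementwise product):

\[
\begin{aligned}
z &= x W_0, \quad h = \text{ReLU}(z), \quad y = h W_1, \quad g = \frac{\partial L}{\partial y} \\
\frac{\partial L}{\partial W_1} &= h^\top g \\
\frac{\partial L}{\partial h} &= g W_1^\top \\
\frac{\partial L}{\partial z} &= \frac{\partial L}{\partial h} \odot \mathbb{1}[z > 0] \\
\frac{\partial L}{\partial W_0} &= x^\top \frac{\partial L}{\partial z}
\end{aligned}
\]

With \(x\) of shape (Batch, Model), \(W_0\) of shape (Model, Hidden) and \(W_1\) of shape (Hidden, Model), each gradient has the same shape as the quantity it is taken with respect to.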


def relu(x):
    return x * (x > 0)

class SingleDevice(ShardedEngine):

    def __init__(self, model_dim: int, hidden_dim: int):
        self.w0 = np.zeros((model_dim, hidden_dim), dtype=np.float32)
        self.w1 = np.zeros((hidden_dim, model_dim), dtype=np.float32)
        # Context tape to store activations for backward pass
        self.activations = []


    def load_checkpoint(self, params: dict[str, npt.ArrayLike]) -> None:
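        # Load weights into local memory. np.asarray also accepts
        # array-likes (e.g. nested lists) that do not support [...] indexing.
        self.w0[...] = np.asarray(params['layer_0/weights'])
        self.w1[...] = np.asarray(params['layer_1/weights'])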

    def forward(self, x: npt.ArrayLike) -> npt.ArrayLike:
        # 1. Save Input
        self.activations.append(x)
        
        # 2. Linear Layer 0
        # (Batch, Model) @ (Model, Hidden) -> (Batch, Hidden)
        z = np.einsum('bd,df->bf', x, self.w0)
        
        # 3. Activation
        self.activations.append(z) # Save pre-activation for the backward pass
        x = relu(z)
        
        # 4. Linear Layer 1
        # (Batch, Hidden) @ (Hidden, Model) -> (Batch, Model)
        out = np.einsum('bf,fd->bd', x, self.w1)
        return out

    def backward(self, grads: npt.ArrayLike) -> dict[str, npt.ArrayLike]:
        """
        grads: Incoming gradient dL/d(Output) of shape (Batch, Model_Dim)
        """
        # --- Backprop Layer 1 ---
        # Retrieve the pre-activation z saved in forward and recompute the
        # input to Layer 1 (the output of ReLU).
        # Shape: (Batch, Hidden)
        z = self.activations.pop()
        h_relu = relu(z)
        
        # dL/dW1 = h_relu.T @ grads
        w1_grad = np.einsum('bf,bd->fd', h_relu, grads)
        
        # Propagate gradient to h_relu: dL/dh = grads @ W1.T
        grads = np.einsum('bd,fd->bf', grads, self.w1)

        # --- Backprop ReLU ---
        # Apply derivative of ReLU: 1 if z > 0 else 0
        grads = grads * (z > 0)

        # --- Backprop Layer 0 ---
        # Retrieve input to Layer 0 (Original X)
        # Shape: (Batch, Model)
        x_input = self.activations.pop()
        
        # dL/dW0 = x_input.T @ grads
        w0_grad = np.einsum('bd,bf->df', x_input, grads)
        
        return {'layer_0/weights': w0_grad, 'layer_1/weights': w1_grad}
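
One way to gain confidence in a hand-written backward pass is a finite-difference gradient check. The sketch below restates the same forward and backward einsums as standalone functions so that it runs on its own. The loss (half the sum of squared outputs) is a hypothetical choice made only for this check, since the chapter assumes the loss comes from the optimizer API. It uses float64 rather than float32 because finite differences are too noisy at lower precision.

```python
import numpy as np

def relu(x):
    return x * (x > 0)

def forward(x, w0, w1):
    # Same contractions as SingleDevice.forward; also returns the pre-activation z.
    z = np.einsum('bd,df->bf', x, w0)
    return np.einsum('bf,fd->bd', relu(z), w1), z

def backward(x, z, w1, grads):
    # Same contractions as SingleDevice.backward.
    w1_grad = np.einsum('bf,bd->fd', relu(z), grads)
    dz = np.einsum('bd,fd->bf', grads, w1) * (z > 0)
    w0_grad = np.einsum('bd,bf->df', x, dz)
    return w0_grad, w1_grad

def loss_fn(out):
    # Hypothetical loss for the check: L = 0.5 * sum(out^2), so dL/d(out) = out.
    return 0.5 * np.sum(out ** 2)

def numerical_grad(w, compute_loss, eps=1e-6):
    # Central differences, perturbing one entry of w in place at a time.
    grad = np.zeros_like(w)
    for idx in np.ndindex(w.shape):
        original = w[idx]
        w[idx] = original + eps
        loss_plus = compute_loss()
        w[idx] = original - eps
        loss_minus = compute_loss()
        w[idx] = original  # restore the entry
        grad[idx] = (loss_plus - loss_minus) / (2 * eps)
    return grad

rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))   # (Batch, Model)
w0 = rng.normal(size=(3, 5))  # (Model, Hidden)
w1 = rng.normal(size=(5, 3))  # (Hidden, Model)

out, z = forward(x, w0, w1)
w0_grad, w1_grad = backward(x, z, w1, grads=out)

current_loss = lambda: loss_fn(forward(x, w0, w1)[0])
w0_num = numerical_grad(w0, current_loss)
w1_num = numerical_grad(w1, current_loss)

max_err = max(np.abs(w0_grad - w0_num).max(), np.abs(w1_grad - w1_num).max())
```

If the analytic and numerical gradients agree, max_err should be small relative to the gradient magnitudes. A large value usually points to a transposed contraction or a mask applied to the wrong tensor.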