Expert Sharding

MoE opens the door to a new sharding strategy: Expert Sharding. We assign experts to different devices. After we compute the expert choices, we send each token to the device that holds its chosen expert, where the computation is performed. We then send the results back to the devices the tokens came from.

It is quite different from other sharding strategies because each device will see a different set of tokens and will process a different number of tokens.

The only difference from the basic single-device implementation is that instead of appending to a local list, we send the tokens to the device that holds the correct experts. The difficulty is that the amount we send to each device is dynamic. Typically, we need to tell every other device in the mesh how many tokens we are sending it. This means we need to All-To-All some metadata about how many tokens were sent from each device, concurrently with the tokens themselves.

image

Let's rewrite the code with the same pseudo-API, following the logic you would find in an MoE kernel but with simplified APIs.

...

def forward(self, x, router_weights, w0, w1, k: int = 2):
    """
    Expert-sharded MoE forward pass, executed by every device on its own
    shard of the batch.

    Each device routes its local tokens, writes every token into a buffer
    living on the device that owns the chosen expert, applies its local
    experts to whatever it received, and sends the results back to the
    devices the tokens came from. The results are then combined with the
    normalized top-k router probabilities and added to the residual input.

    x: (b_per_device, t, d)
    router_weights: (d, e_total)
    w0: (e_per_device, d, f)
    w1: (e_per_device, f, d)
    k: number of experts selected per token

    Returns an array of shape (b_per_device, t, d).

    Global expert ``g`` is assumed to live on device ``g // e_per_device``
    as local expert ``g % e_per_device``.

    Uses the pseudo-API ``self.num_devices()``, ``self.my_id()``,
    ``self.barrier()`` and ``self.send_async(src=, dst=, target_device_id=)``;
    the object returned by ``send_async`` must expose ``wait()``.
    """
    num_experts_total = router_weights.shape[1]
    num_experts_per_device = w0.shape[0]
    num_devices = self.num_devices()
    my_id = self.my_id()

    # -- SAME AS SINGLE DEVICE --
    b, t, d = x.shape
    x = x.reshape(-1, d)
    num_local_tokens = x.shape[0]
    expert_choices = np.einsum('Bd,de->Be', x, router_weights)
    router_probs = softmax(expert_choices)
    top_k_indices = expert_choices.argsort(axis=-1)[..., -k:]
    top_k_probs = np.take_along_axis(router_probs, top_k_indices, axis=-1)
    # keepdims=True so the (bt, k) probabilities are divided by a (bt, 1)
    # column; without it the division only broadcasts when bt == k.
    top_k_probs = top_k_probs / top_k_probs.sum(axis=-1, keepdims=True)
    # ---------------------------

    # Allocate a buffer for the other chips to write to.
    # We allocate for the worst case scenario where each token
    # goes to each expert on the current device.
    # The buffer uses the activation dtype so received tokens are not
    # silently truncated.
    # (e_per_device, bt_global, d)
    target_buffer = np.zeros(
        (num_experts_per_device, num_devices * num_local_tokens, d),
        dtype=x.dtype,
    )
    # How many tokens each device sent us, per local expert.
    # Integer dtype: the counts are used as slice bounds further down.
    # (num_devices, e_per_device)
    token_origin = np.zeros((num_devices, num_experts_per_device), dtype=np.int64)
    futures = []
    # Synchronize all chips to ensure they are at the same point
    # so it's safe to write to the buffer (address is allocated everywhere.)
    self.barrier()

    # How many tokens we are sending to each expert of each device.
    # Integer dtype: the running count is also the write index on the target.
    sending_amount = np.zeros((num_devices, num_experts_per_device), dtype=np.int64)

    # Send the tokens to the other devices.
    # At the same time, collect the metadata to send to all other devices.
    for token_idx, expert_ids in enumerate(top_k_indices):
        for expert_id in expert_ids:
            # Device that owns the expert, and the expert's local index there
            target_device, target_expert_id = divmod(
                int(expert_id), num_experts_per_device
            )
            # Every device writes into its own bt-sized region of the target
            # buffer, so there are no collisions between devices.
            global_token_id = (
                my_id * num_local_tokens
                + sending_amount[target_device, target_expert_id]
            )
            # We write x[token_idx] into
            # target_buffer[target_expert_id, global_token_id, ...]
            # on device target_device.
            futures.append(self.send_async(
                src=x[token_idx][None, ...],
                dst=target_buffer[target_expert_id, global_token_id, ...],
                target_device_id=target_device,
            ))
            sending_amount[target_device, target_expert_id] += 1

    # Communicate to all chips how many tokens we sent them for each expert they hold
    for target_idx in range(num_devices):
        futures.append(self.send_async(
            src=sending_amount[target_idx, :],
            dst=token_origin[my_id, :],
            target_device_id=target_idx,
        ))
    # Wait until all communications are over
    for future in futures:
        future.wait()

    # At this point we have collected all the tokens we need into target_buffer.
    # We also know how many tokens were sent by each device.
    # Most of the buffer is padding, let's extract the tokens next to each other
    # to leverage tensor cores.
    # (e_per_device, bt_global, d)
    ordered_buffer = np.zeros_like(target_buffer)
    current_offset_per_expert = np.zeros((num_experts_per_device,), dtype=np.int64)

    for orig_idx in range(num_devices):
        orig_start_idx = orig_idx * num_local_tokens
        for local_expert_idx in range(num_experts_per_device):
            # How many tokens came from device orig_idx for expert local_expert_idx
            count = token_origin[orig_idx, local_expert_idx]
            offset = current_offset_per_expert[local_expert_idx]
            # Pack this device's tokens right after the ones already copied
            ordered_buffer[local_expert_idx, offset:offset + count] = (
                target_buffer[local_expert_idx, orig_start_idx:orig_start_idx + count]
            )
            current_offset_per_expert[local_expert_idx] += count

    # Placeholder for the per-expert outputs, packed in the same order
    # as ordered_buffer.
    # (e_per_device, bt_global, d)
    expert_outputs = np.zeros((num_experts_per_device, num_local_tokens * num_devices, d))
    # Apply dot products to each expert
    for expert_idx in range(num_experts_per_device):
        # How many tokens for the given expert
        tokens_for_expert = current_offset_per_expert[expert_idx]
        # Slice the tokens so we don't process padding
        tokens = ordered_buffer[expert_idx, :tokens_for_expert]
        # w0
        y = np.einsum('Bd,df->Bf', tokens, w0[expert_idx])
        # relu
        y = np.maximum(y, 0)
        # linear
        y = np.einsum('Bf,fd->Bd', y, w1[expert_idx])
        expert_outputs[expert_idx, :tokens_for_expert] = y

    # Write back the output
    # (e_total, bt_local, d)
    collected_outputs = np.zeros((num_experts_total, num_local_tokens, d))
    futures = []
    # Make sure collected_outputs is allocated everywhere before writing to it
    self.barrier()

    for local_expert_idx in range(num_experts_per_device):
        global_expert_idx = my_id * num_experts_per_device + local_expert_idx
        current_offset = 0
        for orig_idx in range(num_devices):
            count = token_origin[orig_idx, local_expert_idx]
            # The originating device reads its results back from the first
            # `count` rows of this expert's slot, so the destination must be
            # exactly that slice, not the whole (bt_local, d) row.
            futures.append(self.send_async(
                src=expert_outputs[local_expert_idx, current_offset:current_offset + count],
                dst=collected_outputs[global_expert_idx, :count],
                target_device_id=orig_idx,
            ))
            current_offset += count

    for future in futures:
        future.wait()

    # We now have our outputs, but we need to reorder as (bt, k, d).
    # Results for a given expert come back in the order the tokens were sent,
    # so replaying the same (token, k) loop recovers the mapping.
    gathered = np.zeros((num_local_tokens, k, d))
    expert_offsets = np.zeros((num_experts_total,), dtype=np.int64)
    for token_idx in range(num_local_tokens):
        for top_k in range(k):
            expert_idx = top_k_indices[token_idx, top_k]
            gathered[token_idx, top_k] = collected_outputs[expert_idx, expert_offsets[expert_idx]]
            expert_offsets[expert_idx] += 1

    # -- SAME AS SINGLE DEVICE --
    weighted_out = gathered * top_k_probs[:, :, None]
    out = weighted_out.sum(axis=1)
    out = out.reshape(b, t, -1)
    return out + x.reshape(b, t, -1)

Other Ways to Shard

We can shard MoE using different approaches, such as Megatron sharding. We can also use a hybrid approach, like sharding the experts 8 ways and the model 4 ways. This can be useful when we have more devices than experts, forcing us to shard other dimensions.