Practice Questions
A Practical Example
Let's explore a simple example: the post-attention output projection, which maps the per-head attention outputs back to the model dimension.
```python
import torch

batch = 256
length = 1024
d_model = 4096
num_heads = 16
key_size = d_model // num_heads

# x: (B, L, N, K) — per-head attention outputs
x = torch.rand((batch, length, num_heads, key_size), dtype=torch.bfloat16, device='cuda')
# w: (N, K, D) — output projection weights
w = torch.rand((num_heads, key_size, d_model), dtype=torch.bfloat16, device='cuda')
# out: (B, L, D)
out = torch.einsum('blnk,nkd->bld', x, w)
```
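To see what this einsum computes, note that it contracts both the head (`n`) and key (`k`) axes, which is equivalent to flattening them and doing a single matmul. A minimal CPU check with small stand-in shapes (the sizes here are illustrative, not the ones above):

```python
import torch

# Small stand-in shapes so this runs quickly on CPU.
batch, length, d_model, num_heads = 2, 8, 64, 4
key_size = d_model // num_heads

x = torch.rand((batch, length, num_heads, key_size))
w = torch.rand((num_heads, key_size, d_model))

# The projection as written in the example above.
out = torch.einsum('blnk,nkd->bld', x, w)

# Same result as flattening (N, K) into one axis and doing a plain matmul.
out_flat = x.reshape(batch, length, num_heads * key_size) @ w.reshape(num_heads * key_size, d_model)
assert torch.allclose(out, out_flat, atol=1e-5)
```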
- How would pipelining work here (assuming multiple layers)?
- How would data parallelism work?
- In what different ways can we implement tensor parallelism?
- How would we implement FSDP?
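As a sketch of one answer to the tensor-parallelism question: because the einsum sums over heads, we can shard both `x` and `w` along the head axis, let each device compute a partial output, and all-reduce the partials. The sketch below simulates this on a single CPU with a hypothetical `num_devices` count and small stand-in shapes; the final sum stands in for the all-reduce:

```python
import torch

# Small stand-in shapes so this runs quickly on CPU.
batch, length, d_model, num_heads = 2, 8, 64, 4
key_size = d_model // num_heads

x = torch.rand((batch, length, num_heads, key_size))
w = torch.rand((num_heads, key_size, d_model))

# Reference: the full, unsharded projection.
ref = torch.einsum('blnk,nkd->bld', x, w)

# Tensor parallelism over heads: each "device" holds a slice of the head
# axis of both x and w and computes a partial (B, L, D) output.
num_devices = 2  # hypothetical device count for this sketch
heads_per_device = num_heads // num_devices
partials = []
for rank in range(num_devices):
    lo, hi = rank * heads_per_device, (rank + 1) * heads_per_device
    partials.append(torch.einsum('blnk,nkd->bld', x[:, :, lo:hi], w[lo:hi]))

# Summing the partials stands in for the all-reduce across devices.
out = torch.stack(partials).sum(dim=0)
assert torch.allclose(out, ref, atol=1e-5)
```

Sharding along the head axis is natural here because heads are independent inputs to the contraction; the only communication needed is the single all-reduce of the partial outputs.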