
Commit 03998b7

Make token group alignment size configurable (#1503)
## Summary

- For mxfp8, token group sizes must be multiples of "block_size" because in the backward pass for `grad_weight = grad_output_t @ input`, the "M" (token) dimension is the contracting dimension, and each token group is a logically distinct subtensor, so we scale them separately. This means each token group's contracting dimension must be divisible by the mxfp8 block_size (default 32). Here is a diagram showing the problem: https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable.

## Test plan

- Integration test with torchao passes: pytorch/ao#2642
- Did a manual test run with the llama4 debug model using bf16.
1 parent cf30b29 commit 03998b7
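
Before the per-file diffs, a minimal sketch (not part of the commit) of how the alignment values used below follow from the element size and, for mxfp8, the scaling block size. The 16-byte figure and the per-dtype values come from the comments added in this commit; the helper name is hypothetical.

# Illustrative only: derive the token group alignment (in elements) per recipe.
def required_alignment_m(bytes_per_elem: int, mx_block_size: int | None = None) -> int:
    byte_alignment = 16                       # groups must start on 16-byte boundaries
    align = byte_alignment // bytes_per_elem  # 16 bytes expressed in elements
    if mx_block_size is not None:
        # mxfp8 additionally scales each group in (1 x block_size) blocks along
        # the contracting (token) dim, so the group size must also be a
        # multiple of the block size.
        align = max(align, mx_block_size)
    return align

assert required_alignment_m(2) == 8          # bf16
assert required_alignment_m(1) == 16         # fp8
assert required_alignment_m(1, 32) == 32     # mxfp8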

5 files changed (+75, -76 lines)


torchtitan/components/quantization/float8.py

Lines changed: 7 additions & 0 deletions
@@ -10,6 +10,9 @@
 
 from torchtitan.config.job_config import Float8, JobConfig
 from torchtitan.distributed import ParallelDims
+from torchtitan.experiments.llama4.infra.expert_parallel import (
+    set_token_group_alignment_size_m,
+)
 from torchtitan.protocols.model_converter import (
     ModelConverter,
     register_model_converter,
@@ -66,6 +69,10 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             job_config.parallelism.context_parallel_degree == 1
         ), "Float8 MoE training prototype does not yet support context parallelism"
 
+        # For fp8 grouped GEMM, token group sizes must be multiples of 16
+        # (16 byte alignment / 1 byte per elem = 16 elements)
+        set_token_group_alignment_size_m(16)
+
         if float8_config.recipe_name is not None:
             assert not float8_config.enable_fsdp_float8_all_gather, (
                 "using `float8_config.enable_fsdp_float8_all_gather` together "

torchtitan/components/quantization/mx.py

Lines changed: 10 additions & 0 deletions
@@ -59,6 +59,16 @@ def __init__(self, job_config: JobConfig, parallel_dims: ParallelDims):
             and job_config.parallelism.tensor_parallel_degree > 1
         ), "TP not yet supported with torch.compile for mxfp8"
 
+        # For MoE training with mxfp8, token group sizes must be multiples of 32
+        if job_config.mx.moe_fqns_prototype:
+            from torchtitan.experiments.llama4.infra.expert_parallel import (
+                set_token_group_alignment_size_m,
+            )
+
+            mxfp8_block_size = 32
+            set_token_group_alignment_size_m(mxfp8_block_size)
+            logger.info(f"Setting token group alignment size to {mxfp8_block_size}")
+
         # Configure MXFP8
         from torchao.prototype.mx_formats.config import (
             MXFP8Dim1CastKernelChoice,
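
For the requirement mentioned in the comment above, a minimal shape sketch (illustrative only, not from the commit) of the backward grouped GEMM that drives the 32-element alignment: token groups partition the M (token) dimension, which is the contracting dimension of grad_weight = grad_output_t @ input.

# Illustrative only: each token group becomes its own (N, g) @ (g, K) GEMM in
# the grad_weight computation, so g must be a multiple of the mxfp8 scaling
# block size (32) for the (1 x 32) per-group scaling to tile cleanly.
import torch

M, K, N = 96, 16, 8                  # tokens, in_features, out_features
group_sizes = [32, 64]               # token groups, each a multiple of 32
grad_output_t = torch.randn(N, M)
inp = torch.randn(M, K)

start = 0
for g in group_sizes:
    grad_w_g = grad_output_t[:, start : start + g] @ inp[start : start + g]
    assert g % 32 == 0 and grad_w_g.shape == (N, K)
    start += g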

torchtitan/config/job_config.py

Lines changed: 8 additions & 1 deletion
@@ -567,12 +567,19 @@ class MX:
 
     filter_fqns: list[str] = field(default_factory=lambda: ["output"])
     """
-    Comma-separated list of fully qualified names of modules to skip applying mxfloat8 training to.
+    Comma-separated list of fully qualified names of modules to skip applying mxfp8 training to.
     nn.Linear modules with any dim size not divisible by 16 are also always skipped due to hardware requirements.
     By default we always skip the output layer.
     Example: --mx.filter_fqns "attention.wq,attention.wk,attention.wv,output"
     """
 
+    moe_fqns_prototype: list[str] | str = field(default_factory=list)
+    """
+    Comma-separated list of fully qualified names of MoE modules to apply mxfp8 training to.
+    This is a prototype feature that requires the torchao nightly build.
+    Example: --mx.moe_fqns_prototype="experts"
+    """
+
 
 @dataclass
 class Comm:
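
To illustrate how a comma-separated FQN filter like `moe_fqns_prototype` is typically consumed, here is a small sketch; the helper below is hypothetical and not part of this commit, it only shows the kind of substring matching such filters usually imply.

# Hypothetical helper, for illustration only.
def matches_moe_fqns(module_fqn: str, moe_fqns: list[str] | str) -> bool:
    # Accept either the parsed list form or the raw comma-separated string form.
    if isinstance(moe_fqns, str):
        moe_fqns = [fqn.strip() for fqn in moe_fqns.split(",") if fqn.strip()]
    return any(target in module_fqn for target in moe_fqns)

assert matches_moe_fqns("layers.3.moe.experts", ["experts"])
assert matches_moe_fqns("layers.3.moe.experts", "experts,shared_expert")
assert not matches_moe_fqns("layers.3.attention.wq", ["experts"])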

torchtitan/experiments/llama4/infra/expert_parallel.py

Lines changed: 36 additions & 11 deletions
@@ -6,7 +6,7 @@
 
 
 from functools import partial
-from typing import Callable
+from typing import Callable, Literal
 
 import torch
 import torch.distributed as dist
@@ -24,6 +24,33 @@
 from torch.distributed.tensor.placement_types import Placement
 
 
+TOKEN_GROUP_ALIGN_SIZE_M = 8
+ValidTokenGroupAlignmentSize = Literal[8, 16, 32]
+
+
+def set_token_group_alignment_size_m(
+    alignment_size: ValidTokenGroupAlignmentSize,
+) -> None:
+    """
+    Set the token group alignment size for token groups in MoE. This is implemented by
+    padding each token group size to the next multiple of TOKEN_GROUP_ALIGN_SIZE_M.
+
+    Valid values are: 8, 16, or 32.
+    Different values are needed for different cases:
+
+    * For bf16, 8 is enough (16 byte alignment / 2 bytes per elem = 8 elements).
+    * For fp8, 16 byte alignment / 1 byte per elem = 16 elements.
+    * For mxfp8, we need 32 (or block_size) because scaling block size is (1 x 32),
+      so when doing per-token-group quantization on each logically distinct subtensor,
+      we need to ensure the contracting dim is divisible by block_size.
+      In the backward pass, grad_weight = (grad_output_t @ input).t() has gemm dims
+      of (N, M) @ (M, K) so M is the contracting dim, and group offsets are along M,
+      so we need 32 element alignment.
+    """
+    global TOKEN_GROUP_ALIGN_SIZE_M
+    TOKEN_GROUP_ALIGN_SIZE_M = alignment_size
+
+
 # implementation of Tensor Parallel for the GroupedExperts in MoE
 class TensorParallel(ParallelStyle):
     def _partition_fn(self, name, module, device_mesh):
@@ -245,26 +272,24 @@ def expert_parallel(func: Callable) -> Callable:
     """
 
     def wrapper(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
    ) -> torch.Tensor:
-        if isinstance(w1, DTensor):
-            w1 = w1.to_local()
+        global TOKEN_GROUP_ALIGN_SIZE_M
+        if isinstance(w13, DTensor):
+            w13 = w13.to_local()
             w2 = w2.to_local()
-            w3 = w3.to_local()
 
        if num_tokens_per_expert is not None:
            from torchtitan.experiments.kernels.moe.indices import (
                generate_permute_indices,
            )
 
-            experts_per_ep_rank = w1.shape[0]
+            experts_per_ep_rank = w13.shape[0]
            num_ep_ranks = num_tokens_per_expert.shape[0] // experts_per_ep_rank
 
-            ALIGN_SIZE_M = 16
            with torch.no_grad():
                (
                    permuted_indices,
@@ -274,15 +299,15 @@ def wrapper(
                    num_tokens_per_expert,
                    experts_per_ep_rank,
                    num_ep_ranks,
-                    x.shape[0] + experts_per_ep_rank * ALIGN_SIZE_M,
-                    ALIGN_SIZE_M,
+                    x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M,
+                    TOKEN_GROUP_ALIGN_SIZE_M,
                )
 
            x = torch.vstack((x, x.new_zeros((x.shape[-1]))))
            input_shape = x.shape
            x = x[permuted_indices, :]
 
-        out = func(w1, w2, w3, x, num_tokens_per_expert)
+        out = func(w13, w2, x, num_tokens_per_expert)
 
        if num_tokens_per_expert is not None:
            out_unpermuted = out.new_empty(input_shape)
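
The padding behaviour that `TOKEN_GROUP_ALIGN_SIZE_M` controls can be sketched in host-side arithmetic (illustrative only; `generate_permute_indices` does the real work on device and is not reproduced here).

# Illustrative only: round each token group size up to the configured alignment.
def pad_group_sizes(num_tokens_per_expert: list[int], align_m: int) -> list[int]:
    return [((n + align_m - 1) // align_m) * align_m for n in num_tokens_per_expert]

groups = [5, 18, 0, 40]
padded = pad_group_sizes(groups, align_m=32)     # mxfp8 alignment
print(padded)                                    # [32, 32, 0, 64]

# The permuted buffer in `wrapper` is allocated with
# x.shape[0] + experts_per_ep_rank * TOKEN_GROUP_ALIGN_SIZE_M rows, which
# upper-bounds the extra rows this per-group rounding can introduce.
assert sum(padded) <= sum(groups) + len(groups) * 32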

torchtitan/experiments/llama4/model/moe.py

Lines changed: 14 additions & 64 deletions
@@ -23,76 +23,27 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = num_experts
-        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
-        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        # Combine w1 and w3 into a single tensor so we can combine
+        # `x @ w1` and `x @ w3` into a single grouped mm.
+        self.w13 = nn.Parameter(torch.empty(num_experts, hidden_dim * 2, dim))
+        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm
 
     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-        else:
-            return GroupedExperts._run_experts_for_loop(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
+        return GroupedExperts._run_experts_grouped_mm(
+            self.w13, self.w2, x, num_tokens_per_expert
+        )
 
-        return out
 
     @expert_parallel
     @staticmethod
     def _run_experts_grouped_mm(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
@@ -105,16 +56,14 @@ def _run_experts_grouped_mm(
             # fall back to regular bmm between 3D tensors
             assert x.dim() == 3
 
-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
+        x1, x3 = torch._grouped_mm(x, w13.transpose(-2, -1), offs=offsets).chunk(2, dim=-1)
+        y = F.silu(x1) * x3
+        out = torch._grouped_mm(y, w2.transpose(-2, -1), offs=offsets).type_as(x)
         return out
 
     def init_weights(self, init_std: float):
-        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w13, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)
 
 
 class TokenChoiceTopKRouter(nn.Module):
@@ -299,7 +248,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
 
         # shared expert
         if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
+            out = self.shared_expert(x.reshape(1, bs * slen, dim))
+            out = out.reshape(
                 bs * slen, dim
             )
         else:
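
The w1/w3 fusion above can be checked with a small per-expert sketch (illustrative only; plain matmul stands in for torch._grouped_mm, and the shapes assume the (out_features, in_features) layout used in this diff).

# Illustrative only: fusing w1 and w3 into w13 and chunking the result gives
# the same SwiGLU activation as two separate projections.
import torch
import torch.nn.functional as F

tokens, dim, hidden_dim = 64, 32, 128
x = torch.randn(tokens, dim)
w1 = torch.randn(hidden_dim, dim)
w3 = torch.randn(hidden_dim, dim)
w2 = torch.randn(dim, hidden_dim)

# Old path: two separate projections.
ref = (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

# New path: one projection against the stacked weight, then chunk.
w13 = torch.cat([w1, w3], dim=0)                 # (hidden_dim * 2, dim)
x1, x3 = (x @ w13.T).chunk(2, dim=-1)
fused = (F.silu(x1) * x3) @ w2.T

torch.testing.assert_close(ref, fused)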
