Commit e68642a
Make token group alignment size configurable (pytorch#1503)
## Summary
- For mxfp8, each token group's size must be a multiple of `block_size`.
In the backward pass, `grad_weight = grad_output_t @ input` contracts
over the "M" (token) dimension, and each token group is a logically
distinct subtensor that is scaled separately. The contracting dimension
of every token group must therefore be divisible by the mxfp8
block_size (default 32). Here is a diagram showing the problem:
https://www.internalfb.com/excalidraw/EX521879
- To solve this, this PR makes the token group M alignment configurable.
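
As a rough illustration of the divisibility requirement, the sketch below rounds each token group's size up to the next multiple of a configurable alignment. The names `round_up` and `align_token_group_sizes` are hypothetical and are not taken from this PR; the actual implementation may pad, clamp, or handle empty groups differently.

```python
from typing import List


def round_up(n: int, multiple: int) -> int:
    """Round n up to the nearest multiple of `multiple` (hypothetical helper)."""
    if multiple <= 0:
        raise ValueError("multiple must be positive")
    return ((n + multiple - 1) // multiple) * multiple


def align_token_group_sizes(group_sizes: List[int], alignment: int = 32) -> List[int]:
    """Return padded group sizes so each one is divisible by `alignment`.

    Assumption: the default of 32 mirrors the mxfp8 block_size mentioned
    above; a different recipe could pass a different alignment.
    """
    return [round_up(size, alignment) for size in group_sizes]


# Token counts routed to each group before alignment.
raw_sizes = [5, 32, 40]
padded = align_token_group_sizes(raw_sizes, alignment=32)

# Every padded group can now be split into whole blocks of 32 along M.
assert all(size % 32 == 0 for size in padded)
```

Making the alignment a parameter rather than a constant is what lets a smaller value be used when the block-size constraint does not apply.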
## Test plan
- Integration test with torchao passes:
pytorch/ao#2642
- Did manual test run with llama4 debug model using bf16

Parent commit: 9136225
4 files changed (+56, -5):
- torchtitan
  - components/quantization
  - config
  - experiments/llama4/infra