@danielvegamyhre danielvegamyhre commented Aug 4, 2025

In pytorch/torchtitan#1503 the default TOKEN_GROUP_ALIGNMENT_SIZE_M was changed from 16 (required for fp8) to 8 (minimum for bf16). See PR description for details.

Thus, in our fp8 training tests, we need to set it to 16. This is required so that each logically distinct gemm in the grouped gemm grad_weight = grad_output_t @ input has its contraction dim divisible by 16: the kernel requires 16-byte alignment along the innermost dim (stride 1), and 16 bytes / 1 byte per fp8 element = 16 elements.
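As an illustrative sketch (the helper below is hypothetical, not the torchao implementation), the alignment arithmetic works out like this: the byte alignment divided by the element size gives the required multiple for each token group's size, so fp8 (1 byte/element) needs multiples of 16 while bf16 (2 bytes/element) only needs multiples of 8.

```python
def align_group_size(num_tokens: int, elem_size_bytes: int, align_bytes: int = 16) -> int:
    """Round a token group size up to the smallest multiple that makes its
    byte size along the contraction dim a multiple of `align_bytes`.
    Hypothetical helper for illustration only."""
    align_elems = align_bytes // elem_size_bytes
    return ((num_tokens + align_elems - 1) // align_elems) * align_elems

# fp8: 16 bytes / 1 byte per element = 16-element alignment
assert align_group_size(100, elem_size_bytes=1) == 112
# bf16: 16 bytes / 2 bytes per element = 8-element alignment
assert align_group_size(100, elem_size_bytes=2) == 104
```

This is why a TOKEN_GROUP_ALIGNMENT_SIZE_M of 8 (sufficient for bf16) trips the kernel's `delta % align == 0` assert once the elements shrink to 1 byte in fp8.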

Test plan

Test: pytest test/prototype/moe_training/test_training.py

Error without change:

E       torch.AcceleratorError: CUDA error: device-side assert triggered
E       Search for `cudaErrorAssert' in https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html for more information.
E       Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

torchao/prototype/moe_training/scaled_grouped_mm.py:259: AcceleratorError
---------------------------------------------------------------------------------- Captured stderr call ----------------------------------------------------------------------------------
/pytorch/aten/src/ATen/native/cuda/GroupMMCommon.cuh:64: prepare_grouped_gemm_data: block: [0,0,0], thread: [0,0,0] Assertion `delta % align == 0 && "expected input tensor dynamic dimension byte size to be non-negative multiple of 16\n"` failed.
(same assertion repeated for threads [1,0,0] through [3,0,0])

With change, tests pass.

pytorch-bot bot commented Aug 4, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/2678

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures, 1 Cancelled Job

As of commit 3a87a56 with merge base 7dbc816 (image):

NEW FAILURES - The following jobs have failed:

CANCELLED JOB - The following job was cancelled. Please retry:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 4, 2025
@danielvegamyhre danielvegamyhre added topic: not user facing Use this tag if you don't want this PR to show up in release notes and removed CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. labels Aug 4, 2025
@meta-cla meta-cla bot added the CLA Signed This label is managed by the Facebook bot. Authors need to sign the CLA before a PR can be reviewed. label Aug 4, 2025
@danielvegamyhre danielvegamyhre merged commit be40518 into main Aug 5, 2025
16 of 20 checks passed