A8w8 asm codegen and tune #1161
base: main
Conversation
Pull Request Overview
This PR adds ASM A8W8 bpreshuffle int8 codegen and integrates it into the tuning system. The changes extend the existing A8W8 bpreshuffle GEMM implementation to support int8 quantization alongside fp8, introducing new ASM kernels and updating the tuning framework to handle multiple quantization data types.
- Adds ASM int8 kernel configuration and codegen for A8W8 bpreshuffle GEMM
- Refactors the tuning framework to support both fp8 and int8 quantization via a q_dtype_w parameter (see the sketch after this list)
- Updates kernel selection logic and API signatures to support the new ASM kernels
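As a rough illustration of what dispatching on q_dtype_w could look like in the tuner, here is a minimal Python sketch. The helper and token names below are placeholders, not the identifiers this PR actually uses; only the idea of branching on the weight quantization dtype comes from the change.

```python
# Hypothetical sketch: generating tuning instances per q_dtype_w.
# gen_fp8_instances / gen_int8_instances are illustrative stand-ins.
from dataclasses import dataclass

@dataclass
class TuneTask:
    M: int
    N: int
    K: int
    q_dtype_w: str  # e.g. "dtypes.fp8" or the int8 token

def gen_fp8_instances(task):   # stand-in for the existing fp8 path
    return [f"fp8_kernel_{task.M}x{task.N}x{task.K}"]

def gen_int8_instances(task):  # stand-in for the new int8 path
    return [f"int8_kernel_{task.M}x{task.N}x{task.K}"]

def gen_instances(task: TuneTask):
    if task.q_dtype_w == "dtypes.fp8":
        return gen_fp8_instances(task)
    if task.q_dtype_w in ("dtypes.i8", "dtypes.int8"):  # token name assumed
        return gen_int8_instances(task)
    raise ValueError(f"unsupported q_dtype_w: {task.q_dtype_w}")
```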
Reviewed Changes
Copilot reviewed 14 out of 17 changed files in this pull request and generated 4 comments.
| File | Description |
|---|---|
| hsa/gfx942/i8gemm/i8gemm_bf16_perTokenI8.csv | New kernel configuration for int8 ASM kernels |
| hsa/gfx942/i8gemm/codegen.py | Code generator for ASM i8gemm kernel configurations |
| csrc/py_itfs_cu/asm_gemm_a8w8.cu | Major refactoring of ASM GEMM interface with kernel selection logic |
| csrc/include/rocm_ops.hpp | Updated Python binding parameters for new ASM interface |
| csrc/include/asm_gemm_a8w8.h | Updated function signature for new parameters |
| csrc/ck_gemm_a8w8_bpreshuffle/gen_instances.py | Added filtering for int8 dtype in tuning |
| csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py | Major refactoring to support both fp8 and int8 tuning |
| csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.cu | Updated to support BFloat16 output |
| csrc/ck_gemm_a8w8_bpreshuffle/README.md | Documentation updates for new q_dtype_w parameter |
| aiter/utility/base_tuner.py | Base tuner improvements for result handling |
| aiter/ops/gemm_op_a8w8.py | Updated GEMM operations to use new configuration system |
| aiter/jit/optCompilerConfig.json | Added blob generation command for i8gemm |
| aiter/configs/asm_a8w8_gemm.csv | Updated ASM kernel configurations |
| aiter/configs/a8w8_bpreshuffle_untuned_gemm.csv | Added q_dtype_w column and int8 test cases |
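To illustrate the last row of the table: after this PR, each untuned-GEMM entry carries a fourth column selecting the weight quantization dtype. A sketch of the expected format, where the exact int8 token (`dtypes.i8`) is an assumption not confirmed on this page:

```csv
M,N,K,q_dtype_w
16384,7168,2304,dtypes.fp8
16384,7168,2304,dtypes.i8
```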
```csv
32768,7168,2304
32768,512,7168
32768,7168,256
M,N,K,,q_dtype_w
```

Copilot (AI) commented on Oct 11, 2025:
There's an extra comma between K and q_dtype_w in the CSV header.
Suggested change:

```diff
-M,N,K,,q_dtype_w
+M,N,K,q_dtype_w
```
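The doubled comma is not just cosmetic: parsed naively, it yields an empty field between K and q_dtype_w. An illustrative snippet, not part of the PR:

```python
# Illustrative only: the doubled comma creates an empty column name.
header = "M,N,K,,q_dtype_w"
print(header.split(","))  # ['M', 'N', 'K', '', 'q_dtype_w']
```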
```csv
16384,7168,2304,dtypes.fp8
16384,512,7168,dtypes.fp8
16384,7168,256,dtypes.fp8
32768,4096,512, dtypes.fp8
```

Copilot (AI) commented on Oct 11, 2025:
There's an extra space before 'dtypes.fp8' that makes the format inconsistent with other rows.
Suggested change:

```diff
-32768,4096,512, dtypes.fp8
+32768,4096,512,dtypes.fp8
```
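Beyond style, the leading space can break exact string matching when the dtype column is read back. A minimal illustration, not the PR's actual parsing code:

```python
# Illustrative only: a naive split keeps the leading space, so the
# token no longer equals "dtypes.fp8" exactly.
row = "32768,4096,512, dtypes.fp8"
token = row.split(",")[-1]
print(repr(token))                    # ' dtypes.fp8'
print(token == "dtypes.fp8")          # False
print(token.strip() == "dtypes.fp8")  # True once stripped
```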
csrc/py_itfs_cu/asm_gemm_a8w8.cu (Outdated)
```cpp
if(config_map->empty())
{
    TORCH_CHECK(false, __func__, " no kernel support a4w4 for this gpu arch");
```

Copilot (AI) commented on Oct 11, 2025:
Error message mentions 'a4w4' but should be 'a8w8' since this is the A8W8 GEMM function.
Suggested change:

```diff
-    TORCH_CHECK(false, __func__, " no kernel support a4w4 for this gpu arch");
+    TORCH_CHECK(false, __func__, " no kernel support a8w8 for this gpu arch");
```
csrc/include/rocm_ops.hpp (Outdated)
m.def("gemm_a8w8_asm", \ | ||
&gemm_a8w8_asm, \ | ||
"Asm gemm a8w8 , weight should be shuffle to layout(32,16)", \ | ||
"Asm gemm a8w8 , weight should be shuffle to layout(16,16)", \ |
Copilot
AI
Oct 11, 2025
Comment incorrectly states layout(16,16) but based on the code changes, it should be layout(32,16) for int8 kernels.
"Asm gemm a8w8 , weight should be shuffle to layout(16,16)", \ | |
"Asm gemm a8w8 , weight should be shuffle to layout(32,16)", \ |
Motivation
Update the A8W8 bpreshuffle ASM code and add it to the tuning system.
Technical Details
Test Plan
```
python op_tests/test_gemm_a8w8.py
aiter/csrc/ck_gemm_a8w8_bpreshuffle/gemm_a8w8_bpreshuffle_tune.py
```
Test Result
Submission Checklist