
Commit fa5a78e

divakar-amd authored and shreyankg committed
[core] moe fp8 block quant tuning support (vllm-project#14068)
Signed-off-by: Divakar Verma <[email protected]>
1 parent 6daa8d2 commit fa5a78e

2 files changed: +129 -57 lines changed


benchmarks/kernels/benchmark_moe.py

Lines changed: 67 additions & 31 deletions
@@ -40,6 +40,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
+    block_quant_shape: List[int] = None,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
                                dtype=torch.float32)
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        if block_quant_shape:
+            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+            E = num_experts
+            N = shard_intermediate_size // 2
+            K = hidden_size
+            factor_for_scale = 1e-2
+            n_tiles_w1 = (2 * N + block_n - 1) // block_n
+            n_tiles_w2 = (K + block_n - 1) // block_n
+            k_tiles_w1 = (K + block_k - 1) // block_k
+            k_tiles_w2 = (N + block_k - 1) // block_k
+            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+                                  dtype=torch.float32) * factor_for_scale
+            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+                                  dtype=torch.float32) * factor_for_scale
+        else:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+
     a1_scale = torch.randn(1, dtype=torch.float32)
     a2_scale = torch.randn(1, dtype=torch.float32)
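Note: with block quantization enabled, the fp8 weight scales above are no longer one scalar per expert but one value per (block_n x block_k) tile of each expert weight. A standalone sketch of the resulting shapes, using made-up layer sizes rather than anything from this commit:

# Standalone sketch (not part of the diff): per-block fp8 scale shapes for one
# MoE layer. The sizes below are made-up example values.
import torch

num_experts = 8                 # E
hidden_size = 4096              # K
shard_intermediate_size = 2048  # 2 * N per tensor-parallel shard
block_n, block_k = 128, 128     # block_quant_shape

N = shard_intermediate_size // 2
K = hidden_size

# Ceiling division: one scale per (block_n x block_k) tile of each weight.
n_tiles_w1 = (2 * N + block_n - 1) // block_n   # w1 has 2*N output rows
n_tiles_w2 = (K + block_n - 1) // block_n       # w2 has K output rows
k_tiles_w1 = (K + block_k - 1) // block_k
k_tiles_w2 = (N + block_k - 1) // block_k

w1_scale = torch.rand((num_experts, n_tiles_w1, k_tiles_w1),
                      dtype=torch.float32) * 1e-2
w2_scale = torch.rand((num_experts, n_tiles_w2, k_tiles_w2),
                      dtype=torch.float32) * 1e-2

print(w1_scale.shape)  # torch.Size([8, 16, 32])
print(w2_scale.shape)  # torch.Size([8, 32, 8])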

@@ -111,6 +128,7 @@ def run():
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_quant_shape,
         )
 
     # JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
     return param_ranges
 
 
-def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16,
+                              block_quant_shape) -> list[dict[str, int]]:
     configs: list[BenchmarkConfig] = []
 
     if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
     for config_values in product(*values):
         config = dict(zip(keys, config_values))
         configs.append(config)
+
+    # Remove configs that are not compatible with fp8 block quantization
+    # BLOCK_SIZE_K must be a multiple of block_k
+    # BLOCK_SIZE_N must be a multiple of block_n
+    if block_quant_shape is not None and not use_fp16:
+        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+        for config in configs[:]:
+            if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+                    "BLOCK_SIZE_N"] % block_n != 0:
+                configs.remove(config)
     return configs
 
 
 def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
-                            search_space, is_fp16):
+                            search_space, is_fp16, topk):
     N1, K1 = shard_intermediate_size, hidden_size
     N2, K2 = hidden_size, shard_intermediate_size // 2
-    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
-                                        is_fp16)
-    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
-                                        is_fp16)
+    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+                                        search_space, is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+                                        search_space, is_fp16)
     search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
     return search_space
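Note: the filter added above keeps only Triton tile sizes that line up with the quantization blocks, so each GEMM tile covers whole scale blocks. A standalone sketch of the same check on a toy candidate list (the candidate values are illustrative, not the tuner's real search space):

# Standalone sketch of the block-quant compatibility filter above.
# The candidate tile sizes are illustrative only.
from itertools import product

block_n, block_k = 128, 128  # weight_block_size of the quantized checkpoint

candidates = [
    dict(zip(("BLOCK_SIZE_M", "BLOCK_SIZE_N", "BLOCK_SIZE_K"), vals))
    for vals in product((16, 64), (32, 64, 128, 256), (64, 128, 256))
]

# Keep configs whose BLOCK_SIZE_N / BLOCK_SIZE_K are multiples of the
# quantization block, so a tile never straddles two scale blocks.
compatible = [
    cfg for cfg in candidates
    if cfg["BLOCK_SIZE_N"] % block_n == 0 and cfg["BLOCK_SIZE_K"] % block_k == 0
]

print(len(candidates), "->", len(compatible))  # 24 -> 8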

@@ -372,6 +401,7 @@ def tune(
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
+        block_quant_shape: list[int],
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -380,21 +410,23 @@ def tune(
             search_space = prune_rocm_search_space(num_tokens,
                                                    shard_intermediate_size,
                                                    hidden_size, search_space,
-                                                   is_fp16)
+                                                   is_fp16, topk)
 
         with torch.cuda.device(self.device_id):
             for config in tqdm(search_space):
                 try:
-                    kernel_time = benchmark_config(config,
-                                                   num_tokens,
-                                                   num_experts,
-                                                   shard_intermediate_size,
-                                                   hidden_size,
-                                                   topk,
-                                                   dtype,
-                                                   use_fp8_w8a8,
-                                                   use_int8_w8a16,
-                                                   num_iters=20)
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a16,
+                        num_iters=20,
+                        block_quant_shape=block_quant_shape)
                 except triton.runtime.autotuner.OutOfResources:
                     # Some configurations may be invalid and fail to compile.
                     continue
@@ -436,16 +468,16 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
 
 def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
                  shard_intermediate_size: int, hidden_size: int, topk: int,
-                 dtype: torch.dtype, use_fp8_w8a8: bool,
-                 use_int8_w8a16: bool) -> None:
+                 dtype: torch.dtype, use_fp8_w8a8: bool, use_int8_w8a16: bool,
+                 block_quant_shape: List[int]) -> None:
     dtype_str = get_config_dtype_str(dtype,
                                      use_int8_w8a16=use_int8_w8a16,
                                      use_fp8_w8a8=use_fp8_w8a8)
 
     # NOTE(woosuk): The current naming convention uses w2.shape[2], which
     # is the intermediate size after silu_and_mul.
     filename = get_config_file_name(num_experts, shard_intermediate_size // 2,
-                                    dtype_str)
+                                    dtype_str, block_quant_shape)
 
     print(f"Writing best config to {filename}...")
     with open(filename, "w") as f:
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
 
 def main(args: argparse.Namespace):
     print(args)
-
+    block_quant_shape = None
     config = AutoConfig.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code)
     if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        block_quant_shape = config.quantization_config['weight_block_size']
     else:
         # Default: Mixtral.
         E = config.num_local_experts
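Note: the new line above reads the block size straight from the checkpoint's quantization config. A minimal sketch of that lookup against a hand-written config dict (a real run gets the config via AutoConfig.from_pretrained; the [128, 128] value simply mirrors the config file added in this commit):

# Sketch of where block_quant_shape comes from in main(). The dict below is
# hand-written for illustration, not loaded from a real checkpoint.
hf_config = {
    "quantization_config": {
        "quant_method": "fp8",
        "weight_block_size": [128, 128],  # [block_n, block_k]
    }
}

block_quant_shape = None
quant_cfg = hf_config.get("quantization_config")
if quant_cfg is not None and "weight_block_size" in quant_cfg:
    block_quant_shape = quant_cfg["weight_block_size"]

print(block_quant_shape)  # [128, 128]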
@@ -511,27 +544,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
 
     if args.tune:
         is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16)
+        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
         print(f"Start tuning over {len(search_space)} configurations...")
 
         start = time.time()
         configs = _distribute(
-            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
-                     for batch_size in batch_sizes])
+            "tune",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+             for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+                     block_quant_shape)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute(
-            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
-                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
-                          for batch_size in batch_sizes])
+            "benchmark",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+             for batch_size in batch_sizes])
 
     for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
         print(f"Batch size: {batch_size}, config: {config}")

vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json

Lines changed: 62 additions & 26 deletions
@@ -1,28 +1,28 @@
 {
     "1": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "2": {
-        "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 16,
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "4": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 256,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
@@ -31,15 +31,15 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "16": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 1,
         "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
@@ -49,13 +49,13 @@
         "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
-        "num_warps": 4,
+        "num_warps": 2,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "32": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -64,7 +64,7 @@
     },
     "48": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 4,
         "num_warps": 2,
@@ -73,7 +73,7 @@
     },
     "64": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 1,
         "num_warps": 2,
@@ -82,46 +82,82 @@
     },
     "96": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
-        "num_warps": 4,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 8,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "128": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
-        "BLOCK_SIZE_K": 256,
-        "GROUP_SIZE_M": 1,
-        "num_warps": 2,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "256": {
         "BLOCK_SIZE_M": 16,
-        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
-        "GROUP_SIZE_M": 4,
+        "GROUP_SIZE_M": 8,
         "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "512": {
         "BLOCK_SIZE_M": 32,
-        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_N": 128,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     },
     "1024": {
         "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 2,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
         "BLOCK_SIZE_N": 256,
         "BLOCK_SIZE_K": 128,
         "GROUP_SIZE_M": 8,
-        "num_warps": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 8,
+        "num_warps": 4,
+        "num_stages": 2,
+        "waves_per_eu": 0
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 4,
+        "num_warps": 4,
         "num_stages": 2,
         "waves_per_eu": 0
     }
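Note: at run time the fused MoE kernel picks an entry from a file like this by batch size. A simplified stand-in for that lookup, choosing the key closest to the current token count (vLLM's real lookup has additional fallbacks; this sketch assumes the JSON above sits in the working directory):

# Simplified stand-in for looking up a tuned config at run time: pick the
# entry whose batch-size key is closest to the current token count.
import json

with open("E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,"
          "block_shape=[128,128].json") as f:
    configs = {int(k): v for k, v in json.load(f).items()}

def pick_config(num_tokens: int) -> dict:
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]

print(pick_config(3000))  # closest key is 3072 after this commit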

0 commit comments