@@ -40,6 +40,7 @@ def benchmark_config(
     use_fp8_w8a8: bool,
     use_int8_w8a16: bool,
     num_iters: int = 100,
+    block_quant_shape: List[int] = None,
 ) -> float:
     init_dtype = torch.float16 if use_fp8_w8a8 else dtype
     x = torch.randn(num_tokens, hidden_size, dtype=dtype)
@@ -81,8 +82,24 @@ def benchmark_config(
                                dtype=torch.float32)
         w2_scale = torch.randn((hidden_size, num_experts), dtype=torch.float32)
     if use_fp8_w8a8:
-        w1_scale = torch.randn(num_experts, dtype=torch.float32)
-        w2_scale = torch.randn(num_experts, dtype=torch.float32)
+        if block_quant_shape:
+            block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+            E = num_experts
+            N = shard_intermediate_size // 2
+            K = hidden_size
+            factor_for_scale = 1e-2
+            n_tiles_w1 = (2 * N + block_n - 1) // block_n
+            n_tiles_w2 = (K + block_n - 1) // block_n
+            k_tiles_w1 = (K + block_k - 1) // block_k
+            k_tiles_w2 = (N + block_k - 1) // block_k
+            w1_scale = torch.rand((E, n_tiles_w1, k_tiles_w1),
+                                  dtype=torch.float32) * factor_for_scale
+            w2_scale = torch.rand((E, n_tiles_w2, k_tiles_w2),
+                                  dtype=torch.float32) * factor_for_scale
+        else:
+            w1_scale = torch.randn(num_experts, dtype=torch.float32)
+            w2_scale = torch.randn(num_experts, dtype=torch.float32)
+
     a1_scale = torch.randn(1, dtype=torch.float32)
     a2_scale = torch.randn(1, dtype=torch.float32)
 
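For orientation, here is a standalone sketch (not part of the diff) of how the per-block scale shapes produced above work out. All concrete sizes are assumptions chosen to resemble a DeepSeek-V3-like shard with block_quant_shape = [128, 128]; they are illustrative only.

# Standalone sketch of the per-block scale shapes generated in the hunk above.
# Every concrete value here is an assumption for illustration only.
block_n, block_k = 128, 128        # assumed block_quant_shape
E = 256                            # assumed num_experts
shard_intermediate_size = 4096     # assumed 2 * moe_intermediate_size // tp_size
hidden_size = 7168                 # assumed model hidden size

N = shard_intermediate_size // 2
K = hidden_size
n_tiles_w1 = (2 * N + block_n - 1) // block_n  # w1 output rows per expert, in block_n tiles
k_tiles_w1 = (K + block_k - 1) // block_k
n_tiles_w2 = (K + block_n - 1) // block_n
k_tiles_w2 = (N + block_k - 1) // block_k

# One fp32 scale per (block_n, block_k) weight tile, per expert:
print((E, n_tiles_w1, k_tiles_w1))  # w1_scale shape -> (256, 32, 56)
print((E, n_tiles_w2, k_tiles_w2))  # w2_scale shape -> (256, 56, 16)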
@@ -111,6 +128,7 @@ def run():
             w2_scale=w2_scale,
             a1_scale=a1_scale,
             a2_scale=a2_scale,
+            block_shape=block_quant_shape,
         )
 
     # JIT compilation & warmup
@@ -175,7 +193,8 @@ def get_rocm_tuning_space(use_fp16):
     return param_ranges
 
 
-def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
+def get_configs_compute_bound(use_fp16,
+                              block_quant_shape) -> list[dict[str, int]]:
     configs: list[BenchmarkConfig] = []
 
     if current_platform.is_rocm():
@@ -204,17 +223,27 @@ def get_configs_compute_bound(use_fp16) -> list[dict[str, int]]:
     for config_values in product(*values):
         config = dict(zip(keys, config_values))
         configs.append(config)
+
+    # Remove configs that are not compatible with fp8 block quantization
+    # BLOCK_SIZE_K must be a multiple of block_k
+    # BLOCK_SIZE_N must be a multiple of block_n
+    if block_quant_shape is not None and not use_fp16:
+        block_n, block_k = block_quant_shape[0], block_quant_shape[1]
+        for config in configs[:]:
+            if config["BLOCK_SIZE_K"] % block_k != 0 or config[
+                    "BLOCK_SIZE_N"] % block_n != 0:
+                configs.remove(config)
     return configs
 
 
 def prune_rocm_search_space(num_tokens, shard_intermediate_size, hidden_size,
-                            search_space, is_fp16):
+                            search_space, is_fp16, topk):
     N1, K1 = shard_intermediate_size, hidden_size
     N2, K2 = hidden_size, shard_intermediate_size // 2
-    pruned_space_1 = prune_rocm_configs(num_tokens * 2, N1, K1, search_space,
-                                        is_fp16)
-    pruned_space_2 = prune_rocm_configs(num_tokens * 2, N2, K2, search_space,
-                                        is_fp16)
+    pruned_space_1 = prune_rocm_configs(num_tokens * topk, N1, K1,
+                                        search_space, is_fp16)
+    pruned_space_2 = prune_rocm_configs(num_tokens * topk, N2, K2,
+                                        search_space, is_fp16)
     search_space = merge_unique_dicts(pruned_space_1, pruned_space_2)
     return search_space
 
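As a quick illustration of the new compatibility filter (a standalone sketch; the candidate tile sizes below are made up, not the script's real search space): with block_quant_shape = [128, 128], only configs whose BLOCK_SIZE_N and BLOCK_SIZE_K are multiples of 128 survive.

# Standalone illustration of the BLOCK_SIZE_N / BLOCK_SIZE_K filter above.
# Candidate values are examples only, not the actual tuning space.
from itertools import product

block_n, block_k = 128, 128  # assumed block_quant_shape
configs = [dict(zip(("BLOCK_SIZE_N", "BLOCK_SIZE_K"), vals))
           for vals in product([32, 64, 128, 256], [64, 128, 256])]

for config in configs[:]:  # iterate over a copy while removing from the original
    if config["BLOCK_SIZE_K"] % block_k != 0 or config["BLOCK_SIZE_N"] % block_n != 0:
        configs.remove(config)

print(configs)
# Keeps only BLOCK_SIZE_N in {128, 256} paired with BLOCK_SIZE_K in {128, 256}:
# 4 of the 12 candidates survive.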
@@ -372,6 +401,7 @@ def tune(
         use_fp8_w8a8: bool,
         use_int8_w8a16: bool,
         search_space: list[dict[str, int]],
+        block_quant_shape: list[int],
     ) -> dict[str, int]:
         best_config = None
         best_time = float("inf")
@@ -380,21 +410,23 @@ def tune(
             search_space = prune_rocm_search_space(num_tokens,
                                                    shard_intermediate_size,
                                                    hidden_size, search_space,
-                                                   is_fp16)
+                                                   is_fp16, topk)
 
         with torch.cuda.device(self.device_id):
             for config in tqdm(search_space):
                 try:
-                    kernel_time = benchmark_config(config,
-                                                   num_tokens,
-                                                   num_experts,
-                                                   shard_intermediate_size,
-                                                   hidden_size,
-                                                   topk,
-                                                   dtype,
-                                                   use_fp8_w8a8,
-                                                   use_int8_w8a16,
-                                                   num_iters=20)
+                    kernel_time = benchmark_config(
+                        config,
+                        num_tokens,
+                        num_experts,
+                        shard_intermediate_size,
+                        hidden_size,
+                        topk,
+                        dtype,
+                        use_fp8_w8a8,
+                        use_int8_w8a16,
+                        num_iters=20,
+                        block_quant_shape=block_quant_shape)
                 except triton.runtime.autotuner.OutOfResources:
                     # Some configurations may be invalid and fail to compile.
                     continue
@@ -436,16 +468,16 @@ def sort_config(config: BenchmarkConfig) -> BenchmarkConfig:
436
468
437
469
def save_configs (configs : dict [int , BenchmarkConfig ], num_experts : int ,
438
470
shard_intermediate_size : int , hidden_size : int , topk : int ,
439
- dtype : torch .dtype , use_fp8_w8a8 : bool ,
440
- use_int8_w8a16 : bool ) -> None :
471
+ dtype : torch .dtype , use_fp8_w8a8 : bool , use_int8_w8a16 : bool ,
472
+ block_quant_shape : List [ int ] ) -> None :
441
473
dtype_str = get_config_dtype_str (dtype ,
442
474
use_int8_w8a16 = use_int8_w8a16 ,
443
475
use_fp8_w8a8 = use_fp8_w8a8 )
444
476
445
477
# NOTE(woosuk): The current naming convention uses w2.shape[2], which
446
478
# is the intermediate size after silu_and_mul.
447
479
filename = get_config_file_name (num_experts , shard_intermediate_size // 2 ,
448
- dtype_str )
480
+ dtype_str , block_quant_shape )
449
481
450
482
print (f"Writing best config to { filename } ..." )
451
483
with open (filename , "w" ) as f :
@@ -455,7 +487,7 @@ def save_configs(configs: dict[int, BenchmarkConfig], num_experts: int,
 
 def main(args: argparse.Namespace):
     print(args)
-
+    block_quant_shape = None
     config = AutoConfig.from_pretrained(
         args.model, trust_remote_code=args.trust_remote_code)
     if config.architectures[0] == "DbrxForCausalLM":
@@ -474,6 +506,7 @@ def main(args: argparse.Namespace):
         topk = config.num_experts_per_tok
         intermediate_size = config.moe_intermediate_size
         shard_intermediate_size = 2 * intermediate_size // args.tp_size
+        block_quant_shape = config.quantization_config['weight_block_size']
     else:
         # Default: Mixtral.
         E = config.num_local_experts
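For context, `weight_block_size` is read from the checkpoint's Hugging Face `quantization_config`. The sketch below shows roughly what that structure looks like for a DeepSeek-style FP8 block-quantized model; the field values are illustrative assumptions, not read from any specific checkpoint.

# Rough shape of the HF quantization_config the branch above reads from.
# Values are illustrative; the real ones come from the model's config.json.
quantization_config = {
    "quant_method": "fp8",
    "weight_block_size": [128, 128],  # [block_n, block_k]
}
block_quant_shape = quantization_config["weight_block_size"]
print(block_quant_shape)  # [128, 128]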
@@ -511,27 +544,30 @@ def _distribute(method: str, inputs: list[Any]) -> list[Any]:
 
     if args.tune:
         is_fp16 = not (use_fp8_w8a8 or use_int8_w8a16)
-        search_space = get_configs_compute_bound(is_fp16)
+        search_space = get_configs_compute_bound(is_fp16, block_quant_shape)
         print(f"Start tuning over {len(search_space)} configurations...")
 
         start = time.time()
         configs = _distribute(
-            "tune", [(batch_size, E, shard_intermediate_size, hidden_size,
-                      topk, dtype, use_fp8_w8a8, use_int8_w8a16, search_space)
-                     for batch_size in batch_sizes])
+            "tune",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, search_space, block_quant_shape)
+             for batch_size in batch_sizes])
         best_configs = {
             M: sort_config(config)
             for M, config in zip(batch_sizes, configs)
         }
         save_configs(best_configs, E, shard_intermediate_size, hidden_size,
-                     topk, dtype, use_fp8_w8a8, use_int8_w8a16)
+                     topk, dtype, use_fp8_w8a8, use_int8_w8a16,
+                     block_quant_shape)
         end = time.time()
         print(f"Tuning took {end - start:.2f} seconds")
     else:
         outputs = _distribute(
-            "benchmark", [(batch_size, E, shard_intermediate_size, hidden_size,
-                           topk, dtype, use_fp8_w8a8, use_int8_w8a16)
-                          for batch_size in batch_sizes])
+            "benchmark",
+            [(batch_size, E, shard_intermediate_size, hidden_size, topk, dtype,
+              use_fp8_w8a8, use_int8_w8a16, block_quant_shape)
+             for batch_size in batch_sizes])
 
         for batch_size, (config, kernel_time) in zip(batch_sizes, outputs):
            print(f"Batch size: {batch_size}, config: {config}")