 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
 )
 
 
-def round_up(number, k):
+def round_up_seq(number, k):
     return (number + k - 1) // k * k
 
 
+def round_up_batch(number):
+    return BATCH_SIZE_EXPONENT_BASE ** (
+        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+    )
+
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
 
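Note: round_up_batch replaces the old fixed-stride batch bucketing (multiples of BATCH_BUCKET_SIZE) with geometric bucketing (powers of BATCH_SIZE_EXPONENT_BASE, default 2). A minimal standalone sketch of the difference, not part of the diff; the comparison stride of 8 is the removed BATCH_BUCKET_SIZE default and the sample values are illustrative only:

    import math

    BATCH_SIZE_EXPONENT_BASE = 2  # default of the new env var

    def round_up_seq(number, k):
        # old scheme: round up to the next multiple of k
        return (number + k - 1) // k * k

    def round_up_batch(number):
        # new scheme: round up to the next power of the base
        return BATCH_SIZE_EXPONENT_BASE ** (math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE)))

    for n in (1, 3, 5, 9, 17):
        print(n, round_up_seq(n, 8), round_up_batch(n))
    # prints: 1 8 1 / 3 8 4 / 5 8 8 / 9 16 16 / 17 24 32

Small batches thus land in much tighter buckets (1 pads to 1 instead of 8), at the cost of coarser buckets at the top end.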
@@ -399,7 +404,7 @@ def recombine(
 
         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(total_requests)
 
         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -540,7 +545,7 @@ def from_pb(
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(len(requests))
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
@@ -572,7 +577,7 @@ def from_pb(
             assert (
                 PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
             ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-            rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+            rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
             if rounded_seq_len <= max_input_length:
                 bucket_size = rounded_seq_len - 1
             else:
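Note: round_up_seq keeps the sequence-length bucketing unchanged; only the helper's name differs. A quick worked example with the default PAD_SEQUENCE_TO_MULTIPLE_OF of 256 and a hypothetical prompt length:

    def round_up_seq(number, k):
        return (number + k - 1) // k * k

    PAD_SEQUENCE_TO_MULTIPLE_OF = 256
    input_len = 300  # hypothetical prompt length
    rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
    print(rounded_seq_len)      # 512
    print(rounded_seq_len - 1)  # 511 -> bucket_size when it fits within max_input_length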
@@ -1068,10 +1073,10 @@ def generate_token(
         if (
             self.enable_hpu_graph
             and self.limit_hpu_graph
-            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+            and round_up_batch(batch.batch_size) != self.prev_bs
         ):
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up_batch(batch.batch_size)
         dbg_trace(
             scenario,
             f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
@@ -1325,15 +1330,14 @@ def warmup(
 
         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
+        max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
         prefill_batch_size_list = [
-            batch
-            for batch in range(
-                PREFILL_BATCH_BUCKET_SIZE,
-                max_prefill_batch_size,
-                PREFILL_BATCH_BUCKET_SIZE,
+            BATCH_SIZE_EXPONENT_BASE**exp
+            for exp in range(
+                0,
+                max_exp + 1,
             )
         ]
-        prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
             for seq in range(
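Note: prefill warmup now walks powers of the base instead of fixed strides, so the explicit append of max_prefill_batch_size is no longer needed; the top power already covers it. An illustrative run with a hypothetical max_prefill_batch_size of 10:

    import math

    BATCH_SIZE_EXPONENT_BASE = 2
    max_prefill_batch_size = 10  # hypothetical

    max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
    print([BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)])
    # [1, 2, 4, 8, 16] -- vs. the old [2, 4, 6, 8, 10] with PREFILL_BATCH_BUCKET_SIZE = 2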
@@ -1370,12 +1374,10 @@ def warmup(
         )
 
         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
         decode_batch_size_list = [
-            i
-            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+            BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
        ]
-        decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)
 
         try:
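Note: the decode warmup list is built the same way; its top bucket can exceed the token-derived maximum, which matches what round_up_batch can return at runtime, so an HPU graph exists for every reachable bucket. A sketch with hypothetical token limits:

    import math

    BATCH_SIZE_EXPONENT_BASE = 2
    MAX_BATCH_TOTAL_TOKENS = 81920  # hypothetical
    MAX_TOTAL_TOKENS = 2048         # hypothetical

    max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)  # 40
    max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 6
    decode_batch_size_list = [BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)]
    decode_batch_size_list.sort(reverse=True)
    print(decode_batch_size_list)  # [64, 32, 16, 8, 4, 2, 1]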