
Commit 3d059f9

Authored by yuanwu
Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE (#3131)
* Gaudi: Use exponential growth to replace BATCH_BUCKET_SIZE

Signed-off-by: yuanwu <[email protected]>

* Remove debug modifications

Signed-off-by: yuanwu <[email protected]>

---------

Signed-off-by: yuanwu <[email protected]>
1 parent 0142550 commit 3d059f9

File tree

1 file changed: +20 -18 lines changed

  • backends/gaudi/server/text_generation_server/models/causal_lm.py

backends/gaudi/server/text_generation_server/models/causal_lm.py

Lines changed: 20 additions & 18 deletions
@@ -57,8 +57,7 @@
 PAD_SEQUENCE_TO_MULTIPLE_OF = int(os.environ.get("PAD_SEQUENCE_TO_MULTIPLE_OF", 256))
 CHUNK_SIZES = [1, 2, 4, 8, 16, 32, 64, 128, 256, 512, 1024, 2048]
 LAZY_MODE = int(os.environ.get("PT_HPU_LAZY_MODE", 1))
-BATCH_BUCKET_SIZE = int(os.environ.get("BATCH_BUCKET_SIZE", 8))
-PREFILL_BATCH_BUCKET_SIZE = int(os.environ.get("PREFILL_BATCH_BUCKET_SIZE", 2))
+BATCH_SIZE_EXPONENT_BASE = int(os.environ.get("BATCH_SIZE_EXPONENT_BASE", 2))
 MAX_BATCH_SIZE = (
     int(os.environ.get("MAX_BATCH_SIZE"))
     if os.environ.get("MAX_BATCH_SIZE") is not None
@@ -74,10 +73,16 @@ def torch_compile_for_eager(func):
     )


-def round_up(number, k):
+def round_up_seq(number, k):
     return (number + k - 1) // k * k


+def round_up_batch(number):
+    return BATCH_SIZE_EXPONENT_BASE ** (
+        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
+    )
+
+
 def to_tensor_indices(indices, device):
     return torch.tensor(indices, dtype=torch.long, device=device)
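Taken together, round_up_seq keeps the old multiple-of-k padding for sequence lengths, while the new round_up_batch snaps a batch size to the next power of BATCH_SIZE_EXPONENT_BASE. Below is a minimal standalone sketch of the two helpers, assuming the default base of 2 and pad multiple of 256 shown in this diff; the sample values are illustrative only.

```python
import math

# Defaults from the diff; in the server these are read from the environment.
BATCH_SIZE_EXPONENT_BASE = 2
PAD_SEQUENCE_TO_MULTIPLE_OF = 256


def round_up_seq(number, k):
    # Unchanged behaviour, only renamed: round up to the next multiple of k.
    return (number + k - 1) // k * k


def round_up_batch(number):
    # New: round up to the next power of BATCH_SIZE_EXPONENT_BASE.
    return BATCH_SIZE_EXPONENT_BASE ** (
        math.ceil(math.log(number, BATCH_SIZE_EXPONENT_BASE))
    )


if __name__ == "__main__":
    for n in (1, 3, 5, 9, 17, 100):
        print(n, "->", round_up_batch(n))  # 1, 4, 8, 16, 32, 128
    # The old round_up(n, BATCH_BUCKET_SIZE) with the default bucket of 8
    # would have produced 8, 8, 8, 16, 24, 104 for the same inputs.
    print(round_up_seq(300, PAD_SEQUENCE_TO_MULTIPLE_OF))  # 512
```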

@@ -399,7 +404,7 @@ def recombine(

         total_requests = sum(len(b) for b in batches)
         new_bs = total_requests
-        new_bs = round_up(total_requests, BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(total_requests)

         batch_id = batches[0].batch_id
         device = batches[0].input_ids.device
@@ -540,7 +545,7 @@ def from_pb(
         # TODO: by tokenizing all inputs at once we loose information on actual input lengths
         # this means that we cannot shift inputs to the left after a long input sequence
         # was filtered out
-        new_bs = round_up(len(requests), PREFILL_BATCH_BUCKET_SIZE)
+        new_bs = round_up_batch(len(requests))
         missing_inputs = new_bs - len(inputs)
         dummy_inputs = ["?"] * missing_inputs
         parameters = [r.parameters for r in pb.requests]
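In from_pb, the bucketed batch size then determines how many "?" dummy prompts get appended, just as before, except the bucket is now a power of the base rather than a multiple of PREFILL_BATCH_BUCKET_SIZE. A rough sketch of that padding step, with hypothetical prompts and round_up_batch inlined so it runs on its own:

```python
import math

BATCH_SIZE_EXPONENT_BASE = 2


def round_up_batch(number):
    return BATCH_SIZE_EXPONENT_BASE ** math.ceil(
        math.log(number, BATCH_SIZE_EXPONENT_BASE)
    )


# Three real requests (hypothetical prompts).
inputs = ["Hello", "What is deep learning?", "Translate: bonjour"]

new_bs = round_up_batch(len(inputs))   # 4: next power of 2 at or above 3
missing_inputs = new_bs - len(inputs)  # 1
dummy_inputs = ["?"] * missing_inputs
inputs = inputs + dummy_inputs         # padded to the bucketed batch size
print(inputs)  # ['Hello', 'What is deep learning?', 'Translate: bonjour', '?']
```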
@@ -572,7 +577,7 @@ def from_pb(
         assert (
             PAD_SEQUENCE_TO_MULTIPLE_OF <= max_input_length
         ), "PAD_SEQUENCE_TO_MULTIPLE_OF cannot be higher than max_input_length"
-        rounded_seq_len = round_up(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
+        rounded_seq_len = round_up_seq(input_len + 1, PAD_SEQUENCE_TO_MULTIPLE_OF)
         if rounded_seq_len <= max_input_length:
             bucket_size = rounded_seq_len - 1
         else:
@@ -1068,10 +1073,10 @@ def generate_token(
         if (
             self.enable_hpu_graph
             and self.limit_hpu_graph
-            and round_up(batch.batch_size, BATCH_BUCKET_SIZE) != self.prev_bs
+            and round_up_batch(batch.batch_size) != self.prev_bs
         ):
             self.model.clear_cache()
-            self.prev_bs = round_up(batch.batch_size, BATCH_BUCKET_SIZE)
+            self.prev_bs = round_up_batch(batch.batch_size)
         dbg_trace(
             scenario,
             f"bs:{batch.batch_size} num_reqs:{len(batch.requests)} seq_len:{batch.seq_length} padding:{batch.right_padding}",
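When limit_hpu_graph is set, the cached HPU graphs are now cleared only when the bucketed batch size changes between steps, so consecutive batches that land in the same power-of-base bucket keep reusing graphs. A toy sketch of just that condition (not the real model class; the per-step batch sizes are hypothetical):

```python
import math

BATCH_SIZE_EXPONENT_BASE = 2


def round_up_batch(number):
    return BATCH_SIZE_EXPONENT_BASE ** math.ceil(
        math.log(number, BATCH_SIZE_EXPONENT_BASE)
    )


prev_bs = 0
for batch_size in (3, 4, 5, 7, 9):  # hypothetical per-step batch sizes
    if round_up_batch(batch_size) != prev_bs:
        prev_bs = round_up_batch(batch_size)
        print(f"bs={batch_size}: new bucket {prev_bs}, clear HPU graph cache")
    else:
        print(f"bs={batch_size}: bucket {prev_bs} unchanged, reuse cached graphs")
```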
@@ -1325,15 +1330,14 @@ def warmup(

         # Warmup prefill batch_size
         max_input_tokens = request.max_input_tokens
+        max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
         prefill_batch_size_list = [
-            batch
-            for batch in range(
-                PREFILL_BATCH_BUCKET_SIZE,
-                max_prefill_batch_size,
-                PREFILL_BATCH_BUCKET_SIZE,
+            BATCH_SIZE_EXPONENT_BASE**exp
+            for exp in range(
+                0,
+                max_exp + 1,
             )
         ]
-        prefill_batch_size_list.append(max_prefill_batch_size)
         prefill_seqlen_list = [
             seq
             for seq in range(
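The prefill warmup list now enumerates powers of BATCH_SIZE_EXPONENT_BASE from 1 up to the first power at or above max_prefill_batch_size, instead of multiples of PREFILL_BATCH_BUCKET_SIZE. A small sketch of what this produces, assuming base 2 and a hypothetical max_prefill_batch_size of 6:

```python
import math

BATCH_SIZE_EXPONENT_BASE = 2
max_prefill_batch_size = 6  # hypothetical value for illustration

max_exp = math.ceil(math.log(max_prefill_batch_size, BATCH_SIZE_EXPONENT_BASE))
prefill_batch_size_list = [
    BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
print(prefill_batch_size_list)  # [1, 2, 4, 8]
# Old behaviour for the same maximum: multiples of PREFILL_BATCH_BUCKET_SIZE
# (default 2) plus the maximum itself, i.e. [2, 4, 6], which grows linearly
# rather than logarithmically with the maximum batch size.
```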
@@ -1370,12 +1374,10 @@ def warmup(
         )

         max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)
-        max_decode_batch_size = round_up(max_decode_batch_size, BATCH_BUCKET_SIZE)
+        max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))
         decode_batch_size_list = [
-            i
-            for i in range(BATCH_BUCKET_SIZE, max_decode_batch_size, BATCH_BUCKET_SIZE)
+            BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
         ]
-        decode_batch_size_list.append(max_decode_batch_size)
         decode_batch_size_list.sort(reverse=True)

         try:
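The decode warmup list is built the same way, with the maximum batch size first derived from the token budget. A sketch with hypothetical token limits, assuming base 2:

```python
import math

BATCH_SIZE_EXPONENT_BASE = 2
MAX_BATCH_TOTAL_TOKENS = 49152  # hypothetical budget for the whole batch
MAX_TOTAL_TOKENS = 2048         # hypothetical per-request token limit

max_decode_batch_size = math.floor(MAX_BATCH_TOTAL_TOKENS / MAX_TOTAL_TOKENS)   # 24
max_exp = math.ceil(math.log(max_decode_batch_size, BATCH_SIZE_EXPONENT_BASE))  # 5
decode_batch_size_list = [
    BATCH_SIZE_EXPONENT_BASE**exp for exp in range(0, max_exp + 1)
]
decode_batch_size_list.sort(reverse=True)
print(decode_batch_size_list)  # [32, 16, 8, 4, 2, 1]
```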
