
Commit 000e313

Refine warmup and upgrade to synapse AI 1.21.0 (#3234)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent d658b5d commit 000e313

7 files changed: 278 additions & 76 deletions

Dockerfile_gaudi

Lines changed: 2 additions & 1 deletion
@@ -1,5 +1,5 @@
 # Those arguments are required to build the image
-ARG HABANA_VERSION=1.20.0
+ARG HABANA_VERSION=1.21.0
 ARG PYTORCH_VERSION=2.6.0

 # Rust builder
@@ -62,6 +62,7 @@ ENV PREFIX_CACHING=0
 ENV PREFILL_CHUNKING=0
 ENV PT_HPU_LAZY_MODE=1
 ENV PT_HPU_WEIGHT_SHARING=0
+ENV VLLM_EXPONENTIAL_BUCKETING=true

 # Text Generation Inference base env
 ENV HF_HOME=/data \

backends/gaudi/Makefile

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@ mkfile_path := $(abspath $(lastword $(MAKEFILE_LIST)))
 mkfile_dir := $(dir $(mkfile_path))
 root_dir := ${mkfile_dir}/../..

-HABANA_VERSION := 1.20.0
+HABANA_VERSION := 1.21.0
 PYTORCH_VERSION := 2.6.0

 .PHONY: image run-local-dev-container install-dependencies install-server install-router install-launcher local-dev-install

backends/gaudi/server/text_generation_server/models/flash_causal_lm.py

Lines changed: 142 additions & 40 deletions
@@ -76,6 +76,7 @@
 import habana_frameworks.torch as htorch
 import itertools
 from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

 tracer = trace.get_tracer(__name__)

@@ -1357,6 +1358,8 @@ def __init__(
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            self.process_group_cpu = torch.distributed.new_group(backend="gloo")

         device = torch.device("hpu")
         dtype = torch.bfloat16 if dtype is None else dtype
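The new `process_group_cpu` (a gloo-backed group, created only when `world_size > 1`) exists so that small CPU-side scalars, such as the memory measurements taken during warmup, can be reduced across workers by the `align_workers` helper added later in this diff. A minimal sketch of that pattern, with illustrative helper names and a single-process fallback (not the TGI API itself):

```python
import torch
import torch.distributed as dist


def make_cpu_group():
    # Gloo group for small CPU-side reductions; None when running single-process,
    # mirroring the `world_size > 1` guard in this commit.
    if not (dist.is_available() and dist.is_initialized()):
        return None
    if dist.get_world_size() <= 1:
        return None
    return dist.new_group(backend="gloo")


def align_across_workers(value, op, group):
    # All-reduce a Python scalar on CPU so every rank agrees on the same number.
    if group is None:
        return value
    value_t = torch.tensor(value, device="cpu")
    dist.all_reduce(value_t, op=op, group=group)
    return value_t.item()
```

With `ReduceOp.MIN` this gives every rank the most conservative free-memory figure, and with `ReduceOp.MAX` the largest measured consumption, which is how the warmup code below uses it.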
@@ -1453,6 +1456,7 @@ def __init__(
         self.limit_hpu_graph = (
             os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
         )
+        self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
@@ -1521,7 +1525,7 @@ def warmup(
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
-
+        self.graphed_buckets = set()
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ def warmup(
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         cache_block_size = cache_block_size * 2
         total_cache_size = self.num_layers * cache_block_size * dtype_size
-
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+        self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+        graph_reserved_mem = (
+            float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+            if htorch.utils.internal.is_lazy()
+            else 0
+        )
+        mem_used_from_graph = int(
+            (free_memory - self.mem_reserved) * graph_reserved_mem
+        )
+        log_master(
+            logger.info,
+            f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+        )
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1548,15 +1565,6 @@ def warmup(

             num_tokens = batch.to_pb().current_tokens
             synchronize(self.device)
-            free_memory = get_free_memory(
-                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
-            )
-            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            log_master(
-                logger.debug,
-                f"Free memory {free_memory / 1e9:.2f}GB , (real: {real_free_memory / 1e9:.2f}GB",
-            )
-
             _, _batch, _ = self.generate_token([batch])
         except Exception:
             raise RuntimeError(
@@ -1565,8 +1573,9 @@ def warmup(
             )

         synchronize(self.device)
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
-        kv_memory = free_memory
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+        kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
         num_blocks = (
             # Leave 5% for some wiggle room
             int(kv_memory // total_cache_size)
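Taken together, the warmup now splits free HPU memory three ways: a runtime reserve of `free_memory * (1 - MEMORY_FRACTION)`, an HPU-graph budget of `TGI_GRAPH_RESERVED_MEM` (default 0.1) of what remains under lazy mode, and the rest for the KV cache. A rough sketch of the arithmetic with made-up shapes and sizes (the real values come from `get_free_memory`, the model config, and the environment):

```python
def kv_cache_blocks(
    free_memory, memory_fraction, graph_ratio,
    num_layers, num_kv_heads, head_size, block_size, dtype_size,
):
    # Runtime head-room kept out of the KV cache (self.mem_reserved in the diff).
    mem_reserved = int(free_memory * (1 - memory_fraction))
    # Share of the remainder set aside for HPU-graph capture (TGI_GRAPH_RESERVED_MEM).
    mem_used_from_graph = int((free_memory - mem_reserved) * graph_ratio)
    kv_memory = free_memory - mem_reserved - mem_used_from_graph

    # Per-block KV footprint across all layers: 2 (K and V) * block_size * heads * head_dim.
    cache_block_size = block_size * num_kv_heads * head_size * 2
    total_cache_size = num_layers * cache_block_size * dtype_size
    return int(kv_memory // total_cache_size)


# Illustrative numbers only: ~96 GB free, 90% usable, 10% of that kept for graphs,
# Llama-8B-like shapes, bf16 KV cache.
print(kv_cache_blocks(96e9, 0.9, 0.1, 32, 8, 128, 128, 2))
```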
@@ -1583,7 +1592,6 @@ def warmup(

         self.kv_cache = []
         empty_cache()
-
         self.init_kv_cache(
             num_blocks,
             self.num_layers,
@@ -1595,11 +1603,16 @@ def warmup(
         self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
-        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        # need to warmup one more step since block is allocated from 1
+        block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
+        max_total_tokens_aligned = math.ceil(
+            max_total_tokens / BLOCK_SIZE
+        ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
         model_max_length = self.tokenizer.model_max_length
         max_position_embeddings = getattr(
             self.config, "max_position_embeddings", model_max_length
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
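The alignment of `max_total_tokens` changes here: besides rounding up to a whole block, it now pads by roughly one decode block-bucket step spread across the batch, per the new comment that one extra warmup step is needed because block allocation starts from 1. A sketch of the formula, assuming a BLOCK_SIZE of 128 for illustration:

```python
import math
import os

BLOCK_SIZE = 128  # illustrative; the real constant comes from the attention backend


def aligned_total_tokens(max_total_tokens, max_num_seqs):
    # Round up to a whole block, then add one decode block-bucket step's worth of
    # tokens divided over the batch so the extra block still lands in a warmed bucket.
    block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
    return math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE + math.ceil(
        block_step * BLOCK_SIZE / max_num_seqs
    )


print(aligned_total_tokens(max_total_tokens=4096, max_num_seqs=64))  # 4096 + 256 = 4352
```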
@@ -1610,31 +1623,75 @@ def warmup(
             max_input_tokens,
             max_total_tokens_aligned,
         )
-        max_blocks = (
-            max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+        max_blocks = max(
+            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
         self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
-        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+        synchronize(self.device)
+        if self.skip_warmup:
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
                 self.bucketing_ctx.num_hpu_blocks
             )
-            logger.info("skip warmup hpu graph, not recommmended")
+            log_master(
+                logger.info, "skip warmup hpu graph, not recommmended, may cause OOM"
+            )
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
-
         self.warmup_hpu_graph(batch)
         del _batch, batch

         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
-        if self.limit_hpu_graph:
-            return prefill
-        else:
-            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+    def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+        phase = "Prompt" if prefilling else "Decode"
+        dim = "seq_len" if prefilling else "num_blocks"
+        graphed_bucket = (batch_size, seq_len, prefilling)
+        bypass = graphed_bucket not in self.graphed_buckets
+        msg = (
+            f"[Warmup][{phase}][{i+1}/{max_i}] "
+            f"batch_size:{batch_size} "
+            f"{dim}:{seq_len} "
+            f"bypass:{bypass} "
+            f"free_mem:{free_mem}"
+        )
+        log_master(logger.info, msg)
+
+    def use_graphs(self, prefill, seq_len, batch_size):
+        if self.limit_hpu_graph and prefill:
+            return False
+
+        if self.skip_warmup:
+            return True
+
+        return (batch_size, seq_len, prefill) in self.graphed_buckets
+
+    def align_workers(self, value, op):
+        if self.world_size <= 1:
+            return value
+        value_t = torch.tensor(value, device="cpu")
+        torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+        return value_t.item()

     def warmup_hpu_graph(self, batch):
+        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+        decode_available_memory = graph_free_mem - prompt_available_memory
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(prompt_available_memory)} for prompt and "
+            f"{format_bytes(decode_available_memory)} for decode "
+            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
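`bypass_hpu_graphs` is replaced by `use_graphs`: warmup records every bucket it actually captured in `self.graphed_buckets` as `(batch_size, seq_len_or_block_num, is_prefill)` tuples, and at run time a shape goes through HPU graphs only if its bucket was captured (or warmup was skipped entirely). A standalone restatement of that decision, with the state passed in explicitly for illustration:

```python
def use_graphs(graphed_buckets, prefill, dim, batch_size,
               limit_hpu_graph=False, skip_warmup=False):
    # `dim` is seq_len for prefill buckets and block_num for decode buckets.
    if limit_hpu_graph and prefill:
        return False  # prefill graphs disabled by LIMIT_HPU_GRAPH
    if skip_warmup:
        return True   # nothing was recorded, so optimistically use graphs
    return (batch_size, dim, prefill) in graphed_buckets


captured = {(8, 1024, True), (32, 64, False)}
print(use_graphs(captured, prefill=True, dim=1024, batch_size=8))  # True -> run under graphs
print(use_graphs(captured, prefill=True, dim=2048, batch_size=8))  # False -> bypass
```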
@@ -1646,15 +1703,34 @@ def ordering_function_min_tokens(b):
         buckets = list(
             sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
         )
-
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = prompt_available_memory
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size * seq_len
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, seq_len, True)
+            if not (
+                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+            ):
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-            for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size, batch)
-                synchronize(self.device)
+            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_prefill(seq_len, batch_size, batch)
+                    synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+            )
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+                total_mem += used_mem
+                total_batch_seq += batch_seq

         def ordering_function_max_bs(b):
             return (-b[0], b[1])
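Both warmup loops now budget HPU-graph memory with the same proportional estimate: a bucket's expected graph cost is extrapolated from what previously captured buckets of the same phase actually consumed (`batch_seq / total_batch_seq * total_mem`), and the bucket is added to `graphed_buckets` only while the running budget allows it; every shape is still warmed up either way. A toy version of that loop (the prefill form, where `batch_seq = batch_size * seq_len`; the decode loop uses `batch_size` alone), with a fake `measure` callback standing in for the profiled warmup run:

```python
def plan_graph_capture(buckets, budget, measure):
    graphed = set()
    total_batch_seq = 0.001  # avoids division by zero before the first measurement
    total_mem = 0
    available = budget
    for batch_size, seq_len in buckets:
        batch_seq = batch_size * seq_len
        # Assume graph memory scales with the batch*seq product measured so far.
        mem_estimate = batch_seq / total_batch_seq * total_mem
        if mem_estimate < available:
            graphed.add((batch_size, seq_len))
        used = measure((batch_size, seq_len))  # every shape is warmed up, graphed or not
        if (batch_size, seq_len) in graphed:
            available -= used
            total_mem += used
            total_batch_seq += batch_seq
    return graphed


# Toy run: pretend each captured shape costs ~1 KB per unit of batch*seq and
# give the planner a 6 MB budget.
shapes = [(1, 1024), (2, 1024), (4, 1024), (8, 1024)]
print(plan_graph_capture(shapes, budget=6e6, measure=lambda s: s[0] * s[1] * 1e3))
```

In the diff the measurement is `mem_prof.consumed_device_memory`, reduced across workers with `ReduceOp.MAX`, so the most pessimistic rank decides how much of the budget each captured bucket consumes.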
@@ -1663,16 +1739,34 @@ def ordering_function_max_bs(b):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = free_mem - self.mem_reserved
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                    synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-                synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+                total_mem += used_mem
+                total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ def warmup_prefill(
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                True, input_ids.shape[0]
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                True, prompt_len, batch_size
             )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBat
         slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                False, hpu_attention_meta.block_list.shape[0], batch_size
+            )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ def forward(

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                batch.prefilling, input_ids.shape[0]
+            batch_size = input_lengths.shape[0]
+            prompt_len = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
+            )
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
             )

         logits, speculative_logits = self.model.forward(
