 import habana_frameworks.torch as htorch
 import itertools
 from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

 tracer = trace.get_tracer(__name__)

@@ -1357,6 +1358,8 @@ def __init__(
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            self.process_group_cpu = torch.distributed.new_group(backend="gloo")

         device = torch.device("hpu")
         dtype = torch.bfloat16 if dtype is None else dtype
@@ -1453,6 +1456,7 @@ def __init__(
         self.limit_hpu_graph = (
             os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
         )
+        self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
@@ -1521,7 +1525,7 @@ def warmup(
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
-
+        self.graphed_buckets = set()
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ def warmup(
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         cache_block_size = cache_block_size * 2
         total_cache_size = self.num_layers * cache_block_size * dtype_size
-
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+        self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+        graph_reserved_mem = (
+            float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+            if htorch.utils.internal.is_lazy()
+            else 0
+        )
+        mem_used_from_graph = int(
+            (free_memory - self.mem_reserved) * graph_reserved_mem
+        )
+        log_master(
+            logger.info,
+            f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+        )
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1548,15 +1565,6 @@ def warmup(

             num_tokens = batch.to_pb().current_tokens
             synchronize(self.device)
-            free_memory = get_free_memory(
-                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
-            )
-            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            log_master(
-                logger.debug,
-                f"Free memory {free_memory / 1e9:.2f} GB , (real: {real_free_memory / 1e9:.2f} GB",
-            )
-
             _, _batch, _ = self.generate_token([batch])
         except Exception:
             raise RuntimeError(
@@ -1565,8 +1573,9 @@ def warmup(
             )

         synchronize(self.device)
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
-        kv_memory = free_memory
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+        kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
         num_blocks = (
             # Leave 5% for some wiggle room
             int(kv_memory // total_cache_size)
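
For reference, the KV-cache sizing introduced above can be replayed outside the server. The sketch below is illustrative only: the memory fraction, graph ratio, head counts and byte sizes are made-up example values, not the backend's defaults. It shows how free device memory is split into a runtime reserve, an HPU-graph reserve, and the remainder that is turned into KV-cache blocks.

# Standalone sketch of the warmup memory split (example numbers, not real defaults).
BLOCK_SIZE = 128                 # tokens per KV-cache block
MEMORY_FRACTION = 0.8            # assumed share of free memory the server may manage
GRAPH_RESERVED_RATIO = 0.1       # assumed share of managed memory kept for HPU graphs

free_memory = 64 * 1024**3       # pretend 64 GiB is free after loading the model
mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))              # runtime headroom
mem_used_from_graph = int((free_memory - mem_reserved) * GRAPH_RESERVED_RATIO)

# Cost of one KV-cache block: tokens * kv_heads * head_size * 2 (key and value) * dtype bytes
num_kv_heads, head_size, num_layers, dtype_size = 8, 128, 32, 2      # bfloat16 example
cache_block_size = BLOCK_SIZE * num_kv_heads * head_size * 2
total_cache_size = num_layers * cache_block_size * dtype_size

kv_memory = free_memory - mem_reserved - mem_used_from_graph
num_blocks = int(kv_memory // total_cache_size)
print(f"runtime reserve: {mem_reserved}  graph reserve: {mem_used_from_graph}  kv blocks: {num_blocks}")
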
@@ -1583,7 +1592,6 @@ def warmup(

         self.kv_cache = []
         empty_cache()
-
         self.init_kv_cache(
             num_blocks,
             self.num_layers,
@@ -1595,11 +1603,16 @@ def warmup(
         self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
-        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        # need to warmup one more step since block is allocated from 1
+        block_step = os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE)
+        max_total_tokens_aligned = math.ceil(
+            max_total_tokens / BLOCK_SIZE
+        ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
         model_max_length = self.tokenizer.model_max_length
         max_position_embeddings = getattr(
             self.config, "max_position_embeddings", model_max_length
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
@@ -1610,31 +1623,75 @@ def warmup(
             max_input_tokens,
             max_total_tokens_aligned,
         )
-        max_blocks = (
-            max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+        max_blocks = max(
+            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
         self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
-        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+        synchronize(self.device)
+        if self.skip_warmup:
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
                 self.bucketing_ctx.num_hpu_blocks
             )
-            logger.info("skip warmup hpu graph, not recommmended")
+            log_master(
+                logger.info, "skip warmup hpu graph, not recommended, may cause OOM"
+            )
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
-
         self.warmup_hpu_graph(batch)
         del _batch, batch

         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
-        if self.limit_hpu_graph:
-            return prefill
-        else:
-            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+    def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+        phase = "Prompt" if prefilling else "Decode"
+        dim = "seq_len" if prefilling else "num_blocks"
+        graphed_bucket = (batch_size, seq_len, prefilling)
+        bypass = graphed_bucket not in self.graphed_buckets
+        msg = (
+            f"[Warmup][{phase}][{i + 1}/{max_i}] "
+            f"batch_size:{batch_size} "
+            f"{dim}:{seq_len} "
+            f"bypass:{bypass} "
+            f"free_mem:{free_mem}"
+        )
+        log_master(logger.info, msg)
+
+    def use_graphs(self, prefill, seq_len, batch_size):
+        if self.limit_hpu_graph and prefill:
+            return False
+
+        if self.skip_warmup:
+            return True
+
+        return (batch_size, seq_len, prefill) in self.graphed_buckets
+
+    def align_workers(self, value, op):
+        if self.world_size <= 1:
+            return value
+        value_t = torch.tensor(value, device="cpu")
+        torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+        return value_t.item()

     def warmup_hpu_graph(self, batch):
+        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+        decode_available_memory = graph_free_mem - prompt_available_memory
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(prompt_available_memory)} for prompt and "
+            f"{format_bytes(decode_available_memory)} for decode "
+            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
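
The extra gloo process group created in `__init__` and the `align_workers` helper above exist so that per-worker memory measurements, which are plain Python scalars, can be reduced on CPU and every rank agrees on a single number (MIN for the shared graph budget, MAX for per-bucket usage). A stand-alone sketch of that pattern, runnable with e.g. `torchrun --nproc-per-node=2` and using invented values:

import torch
import torch.distributed as dist

def align_workers(value, op, group):
    # Reduce a scalar across all workers so every rank sees the same number.
    if dist.get_world_size(group) <= 1:
        return value
    value_t = torch.tensor(value, device="cpu")
    dist.all_reduce(value_t, op=op, group=group)
    return value_t.item()

if __name__ == "__main__":
    dist.init_process_group(backend="gloo")          # CPU-only, no accelerator required
    cpu_group = dist.new_group(backend="gloo")       # mirrors self.process_group_cpu
    local_free = 1_000_000 * (dist.get_rank() + 1)   # pretend each rank sees different free memory
    agreed = align_workers(local_free, dist.ReduceOp.MIN, cpu_group)
    print(f"rank {dist.get_rank()}: local={local_free} agreed={agreed}")
    dist.destroy_process_group()
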
@@ -1646,15 +1703,34 @@ def ordering_function_min_tokens(b):
         buckets = list(
             sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
         )
-
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = prompt_available_memory
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size * seq_len
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, seq_len, True)
+            if not (
+                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+            ):
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-            for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size, batch)
-            synchronize(self.device)
+            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_prefill(seq_len, batch_size, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+            )
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq

         def ordering_function_max_bs(b):
             return (-b[0], b[1])
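
The prefill loop above decides which buckets to graph-capture with a running proportional estimate: memory measured for previously captured buckets, divided by their accumulated batch_size*seq_len, predicts the cost of the next bucket, and buckets whose prediction no longer fits the prompt budget are still warmed up but not captured. A self-contained toy version of the heuristic follows; the bucket list, budget and "measured" costs are all invented, and the real code additionally skips buckets above max_seq_len_to_capture.

buckets = [(1, 256), (2, 256), (4, 512), (8, 1024), (16, 2048)]   # (batch_size, seq_len)
measured_cost = {b: 3_000 * b[0] * b[1] for b in buckets}         # fake consumed bytes per capture

available_mem = 60_000_000   # pretend prompt-graph budget in bytes
total_batch_seq = 0.001      # avoids division by zero before the first measurement
total_mem = 0
graphed_buckets = set()

for batch_size, seq_len in buckets:
    batch_seq = batch_size * seq_len
    mem_estimate = batch_seq / total_batch_seq * total_mem
    graphed_bucket = (batch_size, seq_len, True)
    if mem_estimate < available_mem:
        graphed_buckets.add(graphed_bucket)
    used_mem = measured_cost[(batch_size, seq_len)]
    if graphed_bucket in graphed_buckets:
        available_mem -= used_mem
    total_mem += used_mem
    total_batch_seq += batch_seq
    print(f"bs={batch_size} seq={seq_len} estimate={mem_estimate:.0f} "
          f"graphed={graphed_bucket in graphed_buckets} budget_left={available_mem}")
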
@@ -1663,16 +1739,34 @@ def ordering_function_max_bs(b):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = free_mem - self.mem_reserved
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ def warmup_prefill(
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                True, input_ids.shape[0]
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                True, prompt_len, batch_size
             )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBat
         slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                False, hpu_attention_meta.block_list.shape[0], batch_size
+            )
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ def forward(

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                batch.prefilling, input_ids.shape[0]
+            batch_size = input_lengths.shape[0]
+            prompt_len = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
+            )
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
             )

         logits, speculative_logits = self.model.forward(
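
At request time the same `(batch_size, seq_len_or_blocks, prefill)` key computed in `forward` is looked up in `graphed_buckets` to decide whether a shape runs inside an HPU graph or bypasses it. A minimal stand-alone illustration of that lookup, using module-level stand-ins rather than the real model state:

graphed_buckets = {(4, 512, True), (8, 64, False)}   # pretend these were captured during warmup
limit_hpu_graph = False
skip_warmup = False

def use_graphs(prefill, seq_or_blocks, batch_size):
    if limit_hpu_graph and prefill:
        return False        # prefill shapes are never captured in this mode
    if skip_warmup:
        return True         # nothing was measured, so optimistically use graphs
    return (batch_size, seq_or_blocks, prefill) in graphed_buckets

print(use_graphs(True, 512, 4))    # True  -> captured prefill bucket, run in a graph
print(use_graphs(False, 100, 8))   # False -> decode bucket was not captured, bypass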