 import habana_frameworks.torch as htorch
 import itertools
 from vllm_hpu_extension.bucketing.common import get_bucketing_context
+from vllm_hpu_extension.profiler import HabanaMemoryProfiler, format_bytes

 tracer = trace.get_tracer(__name__)

@@ -1357,6 +1358,8 @@ def __init__(
     ):
         self.quantize = quantize
         self.process_group, rank, world_size = initialize_torch_distributed()
+        if world_size > 1:
+            self.process_group_cpu = torch.distributed.new_group(backend="gloo")
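+            # CPU (gloo) group used by align_workers() to all-reduce scalar
+            # memory statistics across ranks during warmup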

         device = torch.device("hpu")
         dtype = torch.bfloat16 if dtype is None else dtype
@@ -1453,6 +1456,7 @@ def __init__(
         self.limit_hpu_graph = (
             os.environ.get("LIMIT_HPU_GRAPH", "false").lower() == "true"
         )
+        self.skip_warmup = os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true"
         self.max_seq_len_to_capture = 8192
         super().__init__(
             model_id=model_id,
@@ -1521,7 +1525,7 @@ def warmup(
         # The warmup batch is the biggest batch we could ever receive
         self.kv_cache = []
         empty_cache()
-
+        self.graphed_buckets = set()
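+        # buckets captured as HPU graphs, stored as
+        # (batch_size, seq_len or num_blocks, is_prefill) tuples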
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
         dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
@@ -1533,7 +1537,20 @@ def warmup(
         cache_block_size = BLOCK_SIZE * self.num_kv_heads * self.head_size
         cache_block_size = cache_block_size * 2
         total_cache_size = self.num_layers * cache_block_size * dtype_size
-
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+        self.mem_reserved = int(free_memory * (1 - MEMORY_FRACTION))
+        graph_reserved_mem = (
+            float(os.environ.get("TGI_GRAPH_RESERVED_MEM", "0.1"))
+            if htorch.utils.internal.is_lazy()
+            else 0
+        )
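+        # in lazy mode, a fraction of the post-reserve memory (TGI_GRAPH_RESERVED_MEM,
+        # default 0.1) is set aside for HPU graph capture and excluded from the KV cache budget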
+        mem_used_from_graph = int(
+            (free_memory - self.mem_reserved) * graph_reserved_mem
+        )
+        log_master(
+            logger.info,
+            f"Free memory on device {self.device}: {format_bytes(free_memory)} used_for_graph: {format_bytes(mem_used_from_graph)} ratio {graph_reserved_mem} reserved_for_runtime: {format_bytes(self.mem_reserved)}",
+        )
         try:
             self.init_kv_cache(
                 batch.num_blocks,
@@ -1548,15 +1565,6 @@ def warmup(

             num_tokens = batch.to_pb().current_tokens
             synchronize(self.device)
-            free_memory = get_free_memory(
-                self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM
-            )
-            real_free_memory = get_free_memory(self.device, MEMORY_FRACTION)
-            log_master(
-                logger.debug,
-                f"Free memory {free_memory / 1e9:.2f}GB, (real: {real_free_memory / 1e9:.2f}GB",
-            )
-
             _, _batch, _ = self.generate_token([batch])
         except Exception:
             raise RuntimeError(
@@ -1565,8 +1573,9 @@ def warmup(
             )

         synchronize(self.device)
-        free_memory = get_free_memory(self.device, MEMORY_FRACTION * TGI_WIGGLE_ROOM)
-        kv_memory = free_memory
+        free_memory = get_free_memory(self.device, TGI_WIGGLE_ROOM)
+
+        kv_memory = free_memory - self.mem_reserved - mem_used_from_graph
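+        # the KV cache gets whatever remains after the runtime reserve and the graph budget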
         num_blocks = (
             # Leave 5% for some wiggle room
             int(kv_memory // total_cache_size)
@@ -1583,7 +1592,6 @@ def warmup(

         self.kv_cache = []
         empty_cache()
-
         self.init_kv_cache(
             num_blocks,
             self.num_layers,
@@ -1595,11 +1603,16 @@ def warmup(
         self.max_batch_prefill_tokens = get_max_prefill_tokens()
         max_num_seqs = int(os.getenv("MAX_BATCH_SIZE"))
         HPUBucketingContext = get_bucketing_context()
-        max_total_tokens_aligned = math.ceil(max_total_tokens / BLOCK_SIZE) * BLOCK_SIZE
+        # need to warmup one more step since block is allocated from 1
+        block_step = int(os.getenv("VLLM_DECODE_BLOCK_BUCKET_STEP", BLOCK_SIZE))
+        max_total_tokens_aligned = math.ceil(
+            max_total_tokens / BLOCK_SIZE
+        ) * BLOCK_SIZE + math.ceil(block_step * BLOCK_SIZE / max_num_seqs)
         model_max_length = self.tokenizer.model_max_length
         max_position_embeddings = getattr(
             self.config, "max_position_embeddings", model_max_length
         )
+
         self.bucketing_ctx = HPUBucketingContext(
             max_num_seqs,
             max_num_seqs,  # self.max_num_prefill_seqs, #TODO
@@ -1610,31 +1623,75 @@ def warmup(
             max_input_tokens,
             max_total_tokens_aligned,
         )
-        max_blocks = (
-            max(BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE) + 1
+        max_blocks = max(
+            BLOCK_SIZE, max_num_seqs * max_total_tokens_aligned // BLOCK_SIZE
         )
         self.bucketing_ctx.num_hpu_blocks = min(max_blocks, num_blocks)
-        if os.getenv("VLLM_SKIP_WARMUP", "false").lower() == "true":
+        synchronize(self.device)
+        if self.skip_warmup:
             self.bucketing_ctx.generate_prompt_buckets()
             self.bucketing_ctx.generate_decode_buckets(
                 self.bucketing_ctx.num_hpu_blocks
             )
-            logger.info("skip warmup hpu graph, not recommmended")
+            log_master(
+                logger.info, "skip warmup hpu graph, not recommended, may cause OOM"
+            )
             del _batch, batch
             return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens
-
         self.warmup_hpu_graph(batch)
         del _batch, batch

         return int(num_blocks * BLOCK_SIZE), max_input_tokens, max_total_tokens

-    def bypass_hpu_graphs(self, prefill, max_seq_len_to_capture):
-        if self.limit_hpu_graph:
-            return prefill
-        else:
-            return prefill and max_seq_len_to_capture > self.max_seq_len_to_capture
+    def log_warmup(self, prefilling, i, max_i, batch_size, seq_len):
+        free_mem = format_bytes(HabanaMemoryProfiler.current_free_device_memory())
+        phase = "Prompt" if prefilling else "Decode"
+        dim = "seq_len" if prefilling else "num_blocks"
+        graphed_bucket = (batch_size, seq_len, prefilling)
+        bypass = graphed_bucket not in self.graphed_buckets
+        msg = (
+            f"[Warmup][{phase}][{i + 1}/{max_i}] "
+            f"batch_size:{batch_size} "
+            f"{dim}:{seq_len} "
+            f"bypass:{bypass} "
+            f"free_mem:{free_mem}"
+        )
+        log_master(logger.info, msg)
+
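+    # a bucket is executed as an HPU graph only if it was captured during warmup
+    # (or if warmup was skipped, in which case graphs are always attempted);
+    # prefill graphs are disabled entirely when LIMIT_HPU_GRAPH is set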
+    def use_graphs(self, prefill, seq_len, batch_size):
+        if self.limit_hpu_graph and prefill:
+            return False
+
+        if self.skip_warmup:
+            return True
+
+        return (batch_size, seq_len, prefill) in self.graphed_buckets
+
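+    # all-reduce a scalar over the gloo CPU group so that every rank sees the same
+    # memory figures and therefore captures the same set of graphs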
+    def align_workers(self, value, op):
+        if self.world_size <= 1:
+            return value
+        value_t = torch.tensor(value, device="cpu")
+        torch.distributed.all_reduce(value_t, op=op, group=self.process_group_cpu)
+        return value_t.item()

     def warmup_hpu_graph(self, batch):
+        prompt_graph_mem_ratio = float(os.environ.get("VLLM_GRAPH_PROMPT_RATIO", "0.3"))
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        graph_free_mem = free_mem - self.mem_reserved
+        graph_free_mem = self.align_workers(
+            graph_free_mem, torch.distributed.ReduceOp.MIN
+        )
+        prompt_available_memory = prompt_graph_mem_ratio * graph_free_mem
+        decode_available_memory = graph_free_mem - prompt_available_memory
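+        # the worker-wide minimum budget is split between prompt and decode graphs
+        # according to VLLM_GRAPH_PROMPT_RATIO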
+        msg = (
+            f"Using {format_bytes(graph_free_mem)}"
+            f"/{format_bytes(free_mem)} "
+            "of free device memory for HPUGraphs, "
+            f"{format_bytes(prompt_available_memory)} for prompt and "
+            f"{format_bytes(decode_available_memory)} for decode "
+            f"(VLLM_GRAPH_PROMPT_RATIO={prompt_graph_mem_ratio})"
+        )
+        log_master(logger.info, msg)
         start_time = time.time()
         warmup_shape_count = 0
         warmup_times = 3
@@ -1646,15 +1703,34 @@ def ordering_function_min_tokens(b):
         buckets = list(
             sorted(self.bucketing_ctx.prompt_buckets, key=ordering_function_min_tokens)
         )
-
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = prompt_available_memory
         for i, (batch_size, seq_len) in enumerate(buckets):
             if batch_size * seq_len > self.max_batch_prefill_tokens:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size * seq_len
+            mem_estimate = batch_seq / total_batch_seq * total_mem
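+            # project this bucket's cost from the average memory per unit of
+            # batch_size * seq_len observed so far (total_batch_seq starts at 0.001,
+            # so the first bucket's estimate is effectively zero)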
+            graphed_bucket = (batch_size, seq_len, True)
+            if not (
+                mem_estimate >= available_mem or batch_seq > self.max_seq_len_to_capture
+            ):
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(logger.info, f"warmup prefill seq {seq_len} bs {batch_size}")
-            for index in range(warmup_times):
-                self.warmup_prefill(seq_len, batch_size, batch)
-            synchronize(self.device)
+            self.log_warmup(True, i, len(buckets), batch_size, seq_len)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_prefill(seq_len, batch_size, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
+            )
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq

         def ordering_function_max_bs(b):
             return (-b[0], b[1])
@@ -1663,16 +1739,34 @@ def ordering_function_max_bs(b):
         buckets = list(
             sorted(self.bucketing_ctx.decode_buckets, key=ordering_function_max_bs)
         )
+        free_mem = HabanaMemoryProfiler.current_free_device_memory()
+        total_batch_seq = 0.001
+        total_mem = 0
+        available_mem = free_mem - self.mem_reserved
         for i, (batch_size, block_num) in enumerate(buckets):
             if batch_size > block_num:
                 continue
+            # Graph memory usage is proportional to seq dimension in a batch
+            batch_seq = batch_size
+            mem_estimate = batch_seq / total_batch_seq * total_mem
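+            # same running estimator as for prompts, but decode graph cost is
+            # assumed to scale with batch_size alone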
+            graphed_bucket = (batch_size, block_num, False)
+            if not mem_estimate >= available_mem:
+                if graphed_bucket not in self.graphed_buckets:
+                    self.graphed_buckets.add(graphed_bucket)
             warmup_shape_count += 1
-            log_master(
-                logger.info, f"warmup decode bs {batch_size} block_num {block_num}"
+            self.log_warmup(False, i, len(buckets), batch_size, block_num)
+            with HabanaMemoryProfiler() as mem_prof:
+                for index in range(warmup_times):
+                    self.warmup_decode(batch_size, block_num, batch)
+                synchronize(self.device)
+            used_mem = self.align_workers(
+                mem_prof.consumed_device_memory, torch.distributed.ReduceOp.MAX
             )
-            for index in range(warmup_times):
-                self.warmup_decode(batch_size, block_num, batch)
-            synchronize(self.device)
+            if graphed_bucket in self.graphed_buckets:
+                available_mem -= used_mem
+            total_mem += used_mem
+            total_batch_seq += batch_seq
+
         log_master(
             logger.info,
             f"warmup hpu graph time {int(time.time() - start_time)}s warmup shape count {warmup_shape_count}",
@@ -1707,8 +1801,8 @@ def warmup_prefill(
         lm_head_indices = input_lengths - 1
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                True, input_ids.shape[0]
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                True, prompt_len, batch_size
             )

         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
@@ -1762,7 +1856,9 @@ def warmup_decode(self, batch_size: int, block_num: int, batch: FlashCausalLMBat
         slots_tensor = torch.tensor(slots, dtype=batch.slots.dtype)
         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = False
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                False, hpu_attention_meta.block_list.shape[0], batch_size
+            )
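+            # decode buckets are keyed on the padded block count, matching the
+            # (batch_size, block_num, False) entries added during warmup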
         # We pass a `cu_seqlen_prefill` in order not to have to deal with paged attention cache allocation/deallocation.
         self.model.forward(
             input_ids=_async_h2d_tensor_copy(input_ids),
@@ -1858,8 +1954,14 @@ def forward(

         kwargs = {}
         if htorch.utils.internal.is_lazy():
-            kwargs["bypass_hpu_graphs"] = self.bypass_hpu_graphs(
-                batch.prefilling, input_ids.shape[0]
+            batch_size = input_lengths.shape[0]
+            prompt_len = (
+                input_ids.shape[0] // batch_size
+                if batch.prefilling
+                else batch.hpu_attn_meta.block_list.shape[0]
+            )
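+            # rebuild the bucket key exactly as warmup did: tokens per sequence for
+            # prefill, padded block count for decode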
+            kwargs["bypass_hpu_graphs"] = not self.use_graphs(
+                batch.prefilling, prompt_len, batch_size
             )

         logits, speculative_logits = self.model.forward(
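
The bucket-selection rule in warmup_hpu_graph above can be read in isolation. The following is a minimal, self-contained Python sketch, not part of the diff: select_prompt_buckets is a hypothetical helper and the bucket shapes and memory figures are invented for illustration. It shows how the running proportional estimator decides which prompt buckets get captured as HPU graphs, given the memory each warmup run consumed.

def select_prompt_buckets(buckets, measured_mem, available_mem, max_seq_len_to_capture=8192):
    # buckets: list of (batch_size, seq_len); measured_mem: bytes consumed while
    # warming each bucket (obtained via HabanaMemoryProfiler in the real code)
    graphed = []
    total_batch_seq = 0.001  # avoids division by zero; makes the first estimate ~0
    total_mem = 0
    for (batch_size, seq_len), used_mem in zip(buckets, measured_mem):
        batch_seq = batch_size * seq_len
        # estimate this bucket's cost from the average bytes per unit of batch*seq so far
        mem_estimate = batch_seq / total_batch_seq * total_mem
        if mem_estimate < available_mem and batch_seq <= max_seq_len_to_capture:
            graphed.append((batch_size, seq_len))
            available_mem -= used_mem  # only captured buckets consume the graph budget
        total_mem += used_mem
        total_batch_seq += batch_seq
    return graphed


# invented example: three prompt buckets and the memory each warmup run consumed
print(select_prompt_buckets(
    buckets=[(1, 512), (2, 512), (4, 1024)],
    measured_mem=[64 << 20, 128 << 20, 512 << 20],  # 64 MiB, 128 MiB, 512 MiB
    available_mem=256 << 20,
))
# -> [(1, 512), (2, 512)]; the (4, 1024) bucket is estimated too large and stays eager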