 from collections import deque
+from typing import Any, Deque, Dict, List, Optional
+
 from internal.analyzer import ModelAnalyzer
-from .trace import TraceEvent
-from .memory_planner import MemoryPlanner
 from internal.configs.hardware_params import hardware_params
-from typing import List, Deque
+
+from .memory_planner import MemoryPlanner, ParallelConfig
 from .request import GenerationRequest
+from .trace import TraceEvent


 class LLMEngine:
-    def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
+    def __init__(
+        self,
+        engine_id,
+        model_name,
+        hardware_name,
+        w_bit,
+        a_bit,
+        kv_bit,
+        *,
+        parallel_config: Optional[ParallelConfig] = None,
+        memory_override_bytes: Optional[float] = None,
+    ):
         """
         Initialize a single LLM inference engine.

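The constructor now takes keyword-only `parallel_config` and `memory_override_bytes` arguments. A minimal construction sketch, assuming `ParallelConfig` accepts the two fields this diff reads from it later (`tensor_parallel_size`, `pipeline_parallel_size`); the hardware profile key and model id are placeholders:

```python
# Hedged usage sketch; "nvidia_A100" and the model id are placeholders,
# and ParallelConfig's constructor fields are assumed from how this diff uses them.
engine = LLMEngine(
    engine_id=0,
    model_name="meta-llama/Llama-2-7b-hf",   # placeholder model id
    hardware_name="nvidia_A100",             # placeholder profile key
    w_bit=16,
    a_bit=16,
    kv_bit=16,
    parallel_config=ParallelConfig(tensor_parallel_size=2, pipeline_parallel_size=1),
    memory_override_bytes=80 * 1024**3,      # treat each shard as having 80 GiB
)
```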
@@ -19,6 +32,8 @@ def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
             w_bit: Weight precision in bits (e.g., 16 for FP16, 8 for INT8)
             a_bit: Activation precision in bits
             kv_bit: KV cache precision in bits
+            parallel_config: Optional tensor/pipeline parallel configuration
+            memory_override_bytes: Override for device memory (bytes) per shard
         """
         self.engine_id = engine_id
         self.model_name = model_name
@@ -36,14 +51,24 @@ def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
         self.running: Deque[GenerationRequest] = deque()
         self.finished: List[GenerationRequest] = []
         self.failed: List[GenerationRequest] = []
+        base_hardware = hardware_params.get(hardware_name)
+        if base_hardware is None:
+            raise ValueError(f"Unknown hardware profile: {hardware_name}")
+        self.hardware_spec = dict(base_hardware)
+        if memory_override_bytes is not None and memory_override_bytes > 0:
+            self.hardware_spec["vmemory"] = max(
+                memory_override_bytes, self.hardware_spec.get("vmemory", 0)
+            )
+        self.memory_override_bytes = memory_override_bytes
         self.memory_planner = MemoryPlanner(
             self.analyzer.model_params,
-            hardware_params[hardware_name],
+            self.hardware_spec,
             w_bit,
             a_bit,
             kv_bit,
+            parallel_config=parallel_config,
         )
-        self.memory_planner.print_status()
+        self.parallel_config: ParallelConfig = self.memory_planner.parallel_config
         self.finished_requests: int = 0
         self.configure()

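Note that the override is applied with `max`, so it can only raise the planner's memory budget, never shrink it. A standalone illustration of that branch, with made-up numbers:

```python
# Standalone illustration of the override logic above; all values are invented.
profile = {"vmemory": 40 * 1024**3}        # profile ships 40 GiB of device memory
override = 80 * 1024**3                    # caller supplies an 80 GiB override

spec = dict(profile)
if override is not None and override > 0:
    spec["vmemory"] = max(override, spec.get("vmemory", 0))

assert spec["vmemory"] == 80 * 1024**3     # raised to the override
# A 20 GiB override would be a no-op: max(20 GiB, 40 GiB) keeps the profile value.
```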
@@ -54,6 +79,23 @@ def configure(self):
         """
         pass

+    def update_parallel_config(self, parallel_config: ParallelConfig) -> None:
+        """
+        Update the tensor/pipeline parallel configuration and rebuild the memory planner.
+
+        Args:
+            parallel_config: New parallel configuration to apply.
+        """
+        self.parallel_config = parallel_config
+        self.memory_planner = MemoryPlanner(
+            self.analyzer.model_params,
+            self.hardware_spec,
+            self.w_bit,
+            self.a_bit,
+            self.kv_bit,
+            parallel_config=parallel_config,
+        )
+
     def add_request(self, request: GenerationRequest):
         """
         Add a new request to the waiting queue.
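`update_parallel_config` swaps the config and rebuilds the planner from the already-parsed model params. The diff never shows `ParallelConfig` itself; a hedged sketch of its assumed shape (the real class lives in `.memory_planner`) and of how a scheduler might repartition a drained engine:

```python
from dataclasses import dataclass

# Assumed shape only: this diff reads exactly two fields from ParallelConfig,
# so that is all we sketch here. The real definition is in .memory_planner.
@dataclass
class ParallelConfig:
    tensor_parallel_size: int = 1
    pipeline_parallel_size: int = 1

# Repartition only when the engine is drained; unlike reconfigure_model (below),
# update_parallel_config does not guard against in-flight requests.
if engine.empty:
    engine.update_parallel_config(
        ParallelConfig(tensor_parallel_size=4, pipeline_parallel_size=1)
    )
```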
@@ -85,7 +127,6 @@ def _prefill(self, request: GenerationRequest, start_at: float):
         return prefill_time + start_at, [request], memory_event

     def _decode(self, requests: List[GenerationRequest], start_at: float):
-        max_batch_size = len(requests)
         decode_time = []
         finished_requests_in_this_batch = []
         executable_requests = []
@@ -95,12 +136,14 @@ def _decode(self, requests: List[GenerationRequest], start_at: float):
             executable_requests.append(req)
         batch_size = len(executable_requests)
         memory_event = self.memory_event(start_at)
+        if batch_size == 0:
+            return start_at + 0.0001, [], memory_event, []
         for req in executable_requests:
             if start_at < req.arrive_at:
                 start_at = req.arrive_at
             decode_result = self.analyzer.analyze(
                 req.input_length + req.generated_tokens,
-                batchsize=max_batch_size,
+                batchsize=batch_size,
                 w_bit=self.w_bit,
                 a_bit=self.a_bit,
                 kv_bit=self.kv_bit,
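Two behavioral changes land in this hunk: the analyzer is now called with the executable batch size rather than the size of the batch handed in, and an all-blocked batch returns early instead of analyzing a zero-sized batch. A sketch of the early-return contract; the tuple names are illustrative, only the positions visible on the `return` line above are known:

```python
# Hypothetical: no request in the batch is currently executable.
end_at, finished, mem_event, step_events = engine._decode([], start_at=1.0)

assert end_at == 1.0 + 0.0001   # clock nudged by an epsilon so callers keep making progress
assert finished == []           # nothing completed this step
assert step_events == []        # no per-request events emitted
```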
@@ -142,7 +185,6 @@ def step(self, start_at: float):
             - memory_event: Memory usage event for tracing
         """
         # let's assume that processing one request per step is fine in terms of utilization
-        handled_requests = []
         # self.memory_planner.print_status()

         if len(self.waiting) > 0:
@@ -160,7 +202,6 @@ def step(self, start_at: float):
             if allocatable_request:
                 # Remove the request from the queue and process it
                 self.waiting.remove(allocatable_request)
-                handled_requests = [allocatable_request.req_id]
                 prefill_end_at, handled_requests, memory_event = self._prefill(
                     allocatable_request, start_at
                 )
@@ -195,9 +236,8 @@ def step(self, start_at: float):
                 memory_event,
             )
         else:
-            # add a shift to the timer,
-            # since we need to move on
-            return None, [], start_at + 0.0001, None
+            # No work pending; signal that the engine can stay idle until new requests arrive
+            return None, [], None, None

     def create_event(self, phase, handled_requests, start_at, end_at):
         """
@@ -213,18 +253,52 @@ def create_event(self, phase, handled_requests, start_at, end_at):
             List of TraceEvent objects compatible with Chrome tracing format
         """
         complete_events = []
-        handled_requests = [req.req_id for req in handled_requests]
+        start_us = int(max(start_at, 0) * 1_000_000)
+        duration_s = max(end_at - start_at, 0.0)
+        duration_us = max(int(duration_s * 1_000_000), 1)
+
         for req in handled_requests:
-            complete = TraceEvent(
-                name=f"{phase}-{req}",
-                cat=f"{phase,req}",
-                ph="X",
-                pid=self.engine_id,
-                tid=0,
-                ts=int(start_at * 1000 * 1000),  # convert to microseconds
-                dur=int((end_at - start_at) * 1000 * 1000),
+            event_args = {
+                "request_id": req.req_id,
+                "requested_model": req.model,
+                "engine_id": str(self.engine_id),
+                "engine_model": self.model_name,
+                "hardware": self.hardware_name,
+                "phase": phase,
+                "start_time_s": round(start_at, 6),
+                "end_time_s": round(end_at, 6),
+                "duration_s": round(duration_s, 6),
+            }
+
+            if phase == "prefill":
+                event_args.update(
+                    {
+                        "prompt_tokens": req.input_length,
+                        "target_output_tokens": req.output_length,
+                    }
+                )
+            elif phase == "decode":
+                event_args.update(
+                    {
+                        "target_output_tokens": req.output_length,
+                        "generated_tokens_total": req.generated_tokens,
+                        "tokens_emitted_this_step": 1,
+                    }
+                )
+
+            complete_events.append(
+                TraceEvent(
+                    name=f"{phase.upper()[0]}:{req.req_id}",
+                    cat=f"request.{phase}",
+                    ph="X",
+                    pid=str(self.engine_id),
+                    tid=0,
+                    ts=start_us,
+                    dur=duration_us,
+                    args=event_args,
+                )
             )
-            complete_events.append(complete)
+
         return complete_events

     def memory_event(self, start_at):
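Each appended event serializes to Chrome's trace-event format: `ph: "X"` marks a complete (duration) event, and `ts`/`dur` are in microseconds, which is why the code converts from seconds and clamps `dur` to at least 1. An illustrative prefill event as it would appear after serialization; every value below is invented:

```python
# Illustrative shape of one serialized prefill event (all values made up).
example_prefill_event = {
    "name": "P:req-42",        # f"{phase.upper()[0]}:{req.req_id}"
    "cat": "request.prefill",
    "ph": "X",                 # complete event: carries both ts and dur
    "pid": "0",                # engine id groups events per engine in the viewer
    "tid": 0,
    "ts": 1_250_000,           # 1.25 s expressed in microseconds
    "dur": 87_000,             # 0.087 s, clamped to >= 1 microsecond
    "args": {
        "request_id": "req-42",
        "requested_model": "llama-7b",
        "engine_id": "0",
        "engine_model": "llama-7b",
        "hardware": "nvidia_A100",
        "phase": "prefill",
        "start_time_s": 1.25,
        "end_time_s": 1.337,
        "duration_s": 0.087,
        "prompt_tokens": 512,
        "target_output_tokens": 128,
    },
}
```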
@@ -237,20 +311,48 @@ def memory_event(self, start_at):
         Returns:
             TraceEvent representing current memory block usage
         """
+        used_blocks, total_blocks = self.memory_planner.usage()
         return TraceEvent(
-            name="block usage",
+            name="Memory usage",
             ph="C",
-            ts=start_at * 1e6,
-            pid=self.engine_id,
+            ts=int(max(start_at, 0) * 1_000_000),
+            pid=str(self.engine_id),
             tid=0,
-            cat="memory",
+            cat="memory.allocator",
             args={
-                "used": self.memory_planner._allocated_blocks,
-                "free": self.memory_planner._max_num_blocks
-                - self.memory_planner._allocated_blocks,
+                "used_blocks": used_blocks,
+                "free_blocks": total_blocks - used_blocks,
+                "engine_model": self.model_name,
+                "hardware": self.hardware_name,
+                "tensor_parallel_size": self.parallel_config.tensor_parallel_size,
+                "pipeline_parallel_size": self.parallel_config.pipeline_parallel_size,
+                "waiting_requests": [req.req_id for req in self.waiting],
+                "running_requests": [req.req_id for req in self.running],
             },
         )

+    def status_snapshot(self) -> Dict[str, Any]:
+        """Return a concise status summary for periodic logging."""
+        used_blocks, total_blocks = self.memory_planner.usage()
+        if len(self.running) > 0:
+            state = "busy"
+        elif len(self.waiting) > 0:
+            state = "queued"
+        else:
+            state = "idle"
+
+        return {
+            "engine_id": str(self.engine_id),
+            "model": self.model_name,
+            "hardware": self.hardware_name,
+            "tensor_parallel_size": self.parallel_config.tensor_parallel_size,
+            "state": state,
+            "used_blocks": used_blocks,
+            "total_blocks": total_blocks,
+            "waiting": len(self.waiting),
+            "running": len(self.running),
+        }
+
     @property
     def empty(self):
         """
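Every value in the snapshot is a plain string or int, so it serializes directly to JSON. A small sketch of a monitoring loop that could consume it; the logger name and helper are hypothetical:

```python
import json
import logging

logger = logging.getLogger("engine.monitor")

def log_engine_status(engines) -> None:
    # Hypothetical helper: one JSON line per engine, easy to grep or ingest later.
    for engine in engines:
        logger.info(json.dumps(engine.status_snapshot()))
```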
@@ -260,3 +362,39 @@ def empty(self):
             bool: True if both waiting and running queues are empty
         """
         return len(self.waiting) == 0 and len(self.running) == 0
+
+    def reconfigure_model(
+        self,
+        model_name: str,
+        *,
+        w_bit: Optional[int] = None,
+        a_bit: Optional[int] = None,
+        kv_bit: Optional[int] = None,
+    ) -> None:
+        """Retarget this engine to serve a different model."""
+        if self.waiting or self.running:
+            raise RuntimeError("Cannot reconfigure engine while requests are in-flight")
+
+        if w_bit is not None:
+            self.w_bit = w_bit
+        if a_bit is not None:
+            self.a_bit = a_bit
+        if kv_bit is not None:
+            self.kv_bit = kv_bit
+
+        self.model_name = model_name
+        self.analyzer = ModelAnalyzer(
+            model_id=model_name,
+            hardware=self.hardware_name,
+            config_file="internal/configs/llama.py",
+            source="huggingface",
+        )
+        self.memory_planner = MemoryPlanner(
+            self.analyzer.model_params,
+            self.hardware_spec,
+            self.w_bit,
+            self.a_bit,
+            self.kv_bit,
+            parallel_config=self.parallel_config,
+        )
+        self.parallel_config = self.memory_planner.parallel_config
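A hedged sketch of hot-swapping a drained engine onto another model. The model id is a placeholder, and note that the hard-coded `config_file`/`source` arguments above mean the swap target is implicitly assumed to be a Llama-family HuggingFace model:

```python
# Swap is only legal once the queues are drained; reconfigure_model raises otherwise.
if engine.empty:
    engine.reconfigure_model(
        "meta-llama/Llama-2-13b-hf",   # placeholder target model id
        w_bit=8,                       # optionally tighten weight precision
        kv_bit=8,                      # and the KV cache precision
    )
```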