Skip to content

Commit 8b4f968

Browse files
committed
improve simulator
1 parent cad53ea commit 8b4f968

File tree

17 files changed

+1841
-456
lines changed

17 files changed

+1841
-456
lines changed

tools/simulator/AGENTS.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,3 +31,5 @@
3131
## Security & Configuration Tips
3232
- Review `internal/configs/hardware_params.py` and `examples/env.json` before adding hardware profiles; never commit production-specific credentials.
3333
- Treat environment-change JSONL fixtures as append-only—add new files for new scenarios instead of rewriting shared samples.
34+
35+
If you need to use Python, use this interpreter: $HOME/micromamba/envs/pg/bin/python

tools/simulator/cli/run_simulator.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212

1313
def run_simulation(args):
1414
print(args)
15+
print_interval = max(0.0, float(args.print_interval))
16+
1517
workload = load_trace(
1618
args.input,
1719
float(args.arrival_rate),
@@ -36,15 +38,15 @@ def run_simulation(args):
3638
server = NodeGlobalEngine(
3739
environment_config=environment_config,
3840
environment_changes=environment_changes,
39-
print_interval=args.print_interval,
41+
print_interval=print_interval,
4042
)
4143
else:
4244
# Fallback to legacy LLMGlobalEngine for backward compatibility
4345
print("Using Legacy Engine-based Global Engine")
4446
server = LLMGlobalEngine(
4547
environment_config=environment_config,
4648
environment_changes=environment_changes,
47-
print_interval=args.print_interval,
49+
print_interval=print_interval,
4850
)
4951

5052
# If no environment config is provided, use the old method
@@ -130,8 +132,8 @@ def run_simulation(args):
130132
parser.add_argument(
131133
"--print-interval",
132134
type=float,
133-
help="Print interval for progress updates in seconds (default: 0.1)",
134-
default=0.1,
135+
help=("Seconds between progress updates; set to 0 to disable"),
136+
default=0.5,
135137
)
136138
args = parser.parse_args()
137139
run_simulation(args)

tools/simulator/core/engine.py

Lines changed: 168 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,27 @@
11
from collections import deque
2+
from typing import Any, Deque, Dict, List, Optional
3+
24
from internal.analyzer import ModelAnalyzer
3-
from .trace import TraceEvent
4-
from .memory_planner import MemoryPlanner
55
from internal.configs.hardware_params import hardware_params
6-
from typing import List, Deque
6+
7+
from .memory_planner import MemoryPlanner, ParallelConfig
78
from .request import GenerationRequest
9+
from .trace import TraceEvent
810

911

1012
class LLMEngine:
11-
def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
13+
def __init__(
14+
self,
15+
engine_id,
16+
model_name,
17+
hardware_name,
18+
w_bit,
19+
a_bit,
20+
kv_bit,
21+
*,
22+
parallel_config: Optional[ParallelConfig] = None,
23+
memory_override_bytes: Optional[float] = None,
24+
):
1225
"""
1326
Initialize a single LLM inference engine.
1427
@@ -19,6 +32,8 @@ def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
1932
w_bit: Weight precision in bits (e.g., 16 for FP16, 8 for INT8)
2033
a_bit: Activation precision in bits
2134
kv_bit: KV cache precision in bits
35+
parallel_config: Optional tensor/pipeline parallel configuration
36+
memory_override_bytes: Override for device memory (bytes) per shard
2237
"""
2338
self.engine_id = engine_id
2439
self.model_name = model_name
@@ -36,14 +51,24 @@ def __init__(self, engine_id, model_name, hardware_name, w_bit, a_bit, kv_bit):
3651
self.running: Deque[GenerationRequest] = deque()
3752
self.finished: List[GenerationRequest] = []
3853
self.failed: List[GenerationRequest] = []
54+
base_hardware = hardware_params.get(hardware_name)
55+
if base_hardware is None:
56+
raise ValueError(f"Unknown hardware profile: {hardware_name}")
57+
self.hardware_spec = dict(base_hardware)
58+
if memory_override_bytes is not None and memory_override_bytes > 0:
59+
self.hardware_spec["vmemory"] = max(
60+
memory_override_bytes, self.hardware_spec.get("vmemory", 0)
61+
)
62+
self.memory_override_bytes = memory_override_bytes
3963
self.memory_planner = MemoryPlanner(
4064
self.analyzer.model_params,
41-
hardware_params[hardware_name],
65+
self.hardware_spec,
4266
w_bit,
4367
a_bit,
4468
kv_bit,
69+
parallel_config=parallel_config,
4570
)
46-
self.memory_planner.print_status()
71+
self.parallel_config: ParallelConfig = self.memory_planner.parallel_config
4772
self.finished_requests: int = 0
4873
self.configure()
4974

@@ -54,6 +79,23 @@ def configure(self):
5479
"""
5580
pass
5681

82+
def update_parallel_config(self, parallel_config: ParallelConfig) -> None:
83+
"""
84+
Update tensor/pipeline parallel configuration and rebuild memory planner.
85+
86+
Args:
87+
parallel_config: New parallel configuration to apply.
88+
"""
89+
self.parallel_config = parallel_config
90+
self.memory_planner = MemoryPlanner(
91+
self.analyzer.model_params,
92+
self.hardware_spec,
93+
self.w_bit,
94+
self.a_bit,
95+
self.kv_bit,
96+
parallel_config=parallel_config,
97+
)
98+
5799
def add_request(self, request: GenerationRequest):
58100
"""
59101
Add a new request to the waiting queue.
@@ -85,7 +127,6 @@ def _prefill(self, request: GenerationRequest, start_at: float):
85127
return prefill_time + start_at, [request], memory_event
86128

87129
def _decode(self, requests: List[GenerationRequest], start_at: float):
88-
max_batch_size = len(requests)
89130
decode_time = []
90131
finished_requests_in_this_batch = []
91132
executable_requests = []
@@ -95,12 +136,14 @@ def _decode(self, requests: List[GenerationRequest], start_at: float):
95136
executable_requests.append(req)
96137
batch_size = len(executable_requests)
97138
memory_event = self.memory_event(start_at)
139+
if batch_size == 0:
140+
return start_at + 0.0001, [], memory_event, []
98141
for req in executable_requests:
99142
if start_at < req.arrive_at:
100143
start_at = req.arrive_at
101144
decode_result = self.analyzer.analyze(
102145
req.input_length + req.generated_tokens,
103-
batchsize=max_batch_size,
146+
batchsize=batch_size,
104147
w_bit=self.w_bit,
105148
a_bit=self.a_bit,
106149
kv_bit=self.kv_bit,
@@ -142,7 +185,6 @@ def step(self, start_at: float):
142185
- memory_event: Memory usage event for tracing
143186
"""
144187
# let's assume that process one request per step is fine in terms of utilization
145-
handled_requests = []
146188
# self.memory_planner.print_status()
147189

148190
if len(self.waiting) > 0:
@@ -160,7 +202,6 @@ def step(self, start_at: float):
160202
if allocatable_request:
161203
# Remove the request from the queue and process it
162204
self.waiting.remove(allocatable_request)
163-
handled_requests = [allocatable_request.req_id]
164205
prefill_end_at, handled_requests, memory_event = self._prefill(
165206
allocatable_request, start_at
166207
)
@@ -195,9 +236,8 @@ def step(self, start_at: float):
195236
memory_event,
196237
)
197238
else:
198-
# add a shift to the timer,
199-
# since we need to move on
200-
return None, [], start_at + 0.0001, None
239+
# No work pending; signal that the engine can stay idle until new requests arrive
240+
return None, [], None, None
201241

202242
def create_event(self, phase, handled_requests, start_at, end_at):
203243
"""
@@ -213,18 +253,52 @@ def create_event(self, phase, handled_requests, start_at, end_at):
213253
List of TraceEvent objects compatible with Chrome tracing format
214254
"""
215255
complete_events = []
216-
handled_requests = [req.req_id for req in handled_requests]
256+
start_us = int(max(start_at, 0) * 1_000_000)
257+
duration_s = max(end_at - start_at, 0.0)
258+
duration_us = max(int(duration_s * 1_000_000), 1)
259+
217260
for req in handled_requests:
218-
complete = TraceEvent(
219-
name=f"{phase}-{req}",
220-
cat=f"{phase,req}",
221-
ph="X",
222-
pid=self.engine_id,
223-
tid=0,
224-
ts=int(start_at * 1000 * 1000), # convert to microseconds
225-
dur=int((end_at - start_at) * 1000 * 1000),
261+
event_args = {
262+
"request_id": req.req_id,
263+
"requested_model": req.model,
264+
"engine_id": str(self.engine_id),
265+
"engine_model": self.model_name,
266+
"hardware": self.hardware_name,
267+
"phase": phase,
268+
"start_time_s": round(start_at, 6),
269+
"end_time_s": round(end_at, 6),
270+
"duration_s": round(duration_s, 6),
271+
}
272+
273+
if phase == "prefill":
274+
event_args.update(
275+
{
276+
"prompt_tokens": req.input_length,
277+
"target_output_tokens": req.output_length,
278+
}
279+
)
280+
elif phase == "decode":
281+
event_args.update(
282+
{
283+
"target_output_tokens": req.output_length,
284+
"generated_tokens_total": req.generated_tokens,
285+
"tokens_emitted_this_step": 1,
286+
}
287+
)
288+
289+
complete_events.append(
290+
TraceEvent(
291+
name=f"{phase.upper()[0]}:{req.req_id}",
292+
cat=f"request.{phase}",
293+
ph="X",
294+
pid=str(self.engine_id),
295+
tid=0,
296+
ts=start_us,
297+
dur=duration_us,
298+
args=event_args,
299+
)
226300
)
227-
complete_events.append(complete)
301+
228302
return complete_events
229303

230304
def memory_event(self, start_at):
@@ -237,20 +311,48 @@ def memory_event(self, start_at):
237311
Returns:
238312
TraceEvent representing current memory block usage
239313
"""
314+
used_blocks, total_blocks = self.memory_planner.usage()
240315
return TraceEvent(
241-
name="block usage",
316+
name="Memory usage",
242317
ph="C",
243-
ts=start_at * 1e6,
244-
pid=self.engine_id,
318+
ts=int(max(start_at, 0) * 1_000_000),
319+
pid=str(self.engine_id),
245320
tid=0,
246-
cat="memory",
321+
cat="memory.allocator",
247322
args={
248-
"used": self.memory_planner._allocated_blocks,
249-
"free": self.memory_planner._max_num_blocks
250-
- self.memory_planner._allocated_blocks,
323+
"used_blocks": used_blocks,
324+
"free_blocks": total_blocks - used_blocks,
325+
"engine_model": self.model_name,
326+
"hardware": self.hardware_name,
327+
"tensor_parallel_size": self.parallel_config.tensor_parallel_size,
328+
"pipeline_parallel_size": self.parallel_config.pipeline_parallel_size,
329+
"waiting_requests": [req.req_id for req in self.waiting],
330+
"running_requests": [req.req_id for req in self.running],
251331
},
252332
)
253333

334+
def status_snapshot(self) -> Dict[str, Any]:
335+
"""Return a concise status summary for periodic logging."""
336+
used_blocks, total_blocks = self.memory_planner.usage()
337+
if len(self.running) > 0:
338+
state = "busy"
339+
elif len(self.waiting) > 0:
340+
state = "queued"
341+
else:
342+
state = "idle"
343+
344+
return {
345+
"engine_id": str(self.engine_id),
346+
"model": self.model_name,
347+
"hardware": self.hardware_name,
348+
"tensor_parallel_size": self.parallel_config.tensor_parallel_size,
349+
"state": state,
350+
"used_blocks": used_blocks,
351+
"total_blocks": total_blocks,
352+
"waiting": len(self.waiting),
353+
"running": len(self.running),
354+
}
355+
254356
@property
255357
def empty(self):
256358
"""
@@ -260,3 +362,39 @@ def empty(self):
260362
bool: True if both waiting and running queues are empty
261363
"""
262364
return len(self.waiting) == 0 and len(self.running) == 0
365+
366+
def reconfigure_model(
367+
self,
368+
model_name: str,
369+
*,
370+
w_bit: Optional[int] = None,
371+
a_bit: Optional[int] = None,
372+
kv_bit: Optional[int] = None,
373+
) -> None:
374+
"""Retarget this engine to serve a different model."""
375+
if self.waiting or self.running:
376+
raise RuntimeError("Cannot reconfigure engine while requests are in-flight")
377+
378+
if w_bit is not None:
379+
self.w_bit = w_bit
380+
if a_bit is not None:
381+
self.a_bit = a_bit
382+
if kv_bit is not None:
383+
self.kv_bit = kv_bit
384+
385+
self.model_name = model_name
386+
self.analyzer = ModelAnalyzer(
387+
model_id=model_name,
388+
hardware=self.hardware_name,
389+
config_file="internal/configs/llama.py",
390+
source="huggingface",
391+
)
392+
self.memory_planner = MemoryPlanner(
393+
self.analyzer.model_params,
394+
self.hardware_spec,
395+
self.w_bit,
396+
self.a_bit,
397+
self.kv_bit,
398+
parallel_config=self.parallel_config,
399+
)
400+
self.parallel_config = self.memory_planner.parallel_config

0 commit comments

Comments
 (0)