
Commit 4349939

Minor fix (#46)
* leftover minor
* minor
* fix dockerfile for x86_64 to include hopper
* fix benchmark script
* minor
1 parent 8e610b7 commit 4349939

File tree

12 files changed: +398 −29 lines


docker/Dockerfile.x86_64-cuda

Lines changed: 2 additions & 2 deletions
@@ -6,8 +6,8 @@ LABEL org.opencontainers.image.licenses=Apache-2.0
 LABEL org.opencontainers.image.architecture=amd64
 
 ENV DEBIAN_FRONTEND=noninteractive
-ENV TRITEIA_COMPUTE_CAP=80
-ENV TORCH_CUDA_ARCH_LIST="8.0"
+ENV TRITEIA_COMPUTE_CAP=90
+ENV TORCH_CUDA_ARCH_LIST="8.0 8.6 9.0 9.0a"
 ENV FLASHINFER_ENABLE_AOT="1"
 
 RUN apt update && apt upgrade -y
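
This is the "fix dockerfile for x86_64 to include hopper" item from the commit message: Hopper GPUs (H100/H200) report compute capability 9.0, which the old "8.0"-only arch list did not cover. A minimal sketch (not part of the commit) for checking whether the GPU visible inside the container is covered by the new list:

import torch

# Hopper reports (9, 0); Ampere reports (8, 0) or (8, 6).
if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability(0)
    covered = (major, minor) in {(8, 0), (8, 6), (9, 0)}
    print(f"compute capability {major}.{minor}, covered by TORCH_CUDA_ARCH_LIST: {covered}")
else:
    print("no CUDA device visible")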

scratchpad/managers/tp_worker.py

Lines changed: 9 additions & 1 deletion
@@ -10,8 +10,13 @@
 from scratchpad.model_executor.model_runner import ModelRunner
 from scratchpad.config.model_config import ModelConfig
 from scratchpad.scheduler.schedule_batch import ModelWorkerBatch
-from scratchpad.memory.het_pool import HeterogeneousMHATokenToKVPool
 from scratchpad.model_executor.forward_info import ForwardBatch
+from scratchpad.memory import (
+    ReqToTokenPool,
+    HeterogeneousMHATokenToKVPool,
+    TokenToKVPoolAllocator,
+)
+
 from .structs import UpdateWeightReqInput
 from typing import Optional
 from scratchpad.server.args import global_args
@@ -27,9 +32,12 @@ def __init__(
         server_args: ServerArgs,
         nccl_port: int,
         dp_rank: Optional[int] = 0,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         logger.info(f"Initalizing model worker on GPU {gpu_id}, tp_rank: {tp_rank}")
+
         self.tp_rank = tp_rank
         self.server_args = server_args
         # Init model and tokenizer

scratchpad/memory/__init__.py

Lines changed: 5 additions & 0 deletions
@@ -0,0 +1,5 @@
+from .pool import *
+from .het_pool import *
+from .radix_cache import *
+from .topping_pool import *
+from .chunk_cache import *

scratchpad/memory/pool.py

Lines changed: 103 additions & 1 deletion
@@ -1,4 +1,4 @@
-from abc import ABC, abstractmethod
+import abc
 from typing import List, Tuple, Union, TYPE_CHECKING
 import torch
 from scratchpad.utils import logger
@@ -144,6 +144,108 @@ def set_kv_buffer(
         raise NotImplementedError()
 
 
+class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def get_key_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_value_buffer(self, layer_id: int) -> torch.Tensor:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_kv_buffer(self, layer_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def set_kv_buffer(
+        self,
+        layer: "RadixAttention",
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+    ) -> None:
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def get_flat_data(self, indices):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer(self, indices, flat_data):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def transfer_per_layer(self, indices, flat_data, layer_id):
+        raise NotImplementedError()
+
+    def register_layer_transfer_counter(self, layer_transfer_counter):
+        self.layer_transfer_counter = layer_transfer_counter
+
+
+class TokenToKVPoolAllocator:
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.dtype = dtype
+        self.device = device
+        self.page_size = 1
+
+        self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+        self.clear()
+
+        self._kvcache = kvcache
+
+    def available_size(self):
+        return len(self.free_slots)
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_slots):
+            return None
+
+        select_index = self.free_slots[:need_size]
+        self.free_slots = self.free_slots[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_slots = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_in_free_group = False
+        self.free_group = []
+
+
 class MHATokenToKVPool(BaseTokenToKVPool):
     def __init__(
         self,
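
The new TokenToKVPoolAllocator only manages free slot indices; the tensors themselves live in whatever KVCache implementation it is constructed with. A minimal sketch of the allocate / grouped-free flow, using a hypothetical stub in place of a real KV cache (the allocator never touches it here):

import torch

from scratchpad.memory import TokenToKVPoolAllocator


class _StubKVCache:
    """Hypothetical placeholder; a real KVCache subclass would hold the K/V tensors."""


allocator = TokenToKVPoolAllocator(
    size=16, dtype=torch.float16, device="cpu", kvcache=_StubKVCache()
)

idx = allocator.alloc(4)           # first 4 free slot indices (slot 0 is reserved for padding)
print(allocator.available_size())  # 12

# Grouped free: indices released between begin/end are collected and
# returned to the free list in a single concatenated free() call.
allocator.free_group_begin()
allocator.free(idx[:2])
allocator.free(idx[2:])
allocator.free_group_end()
print(allocator.available_size())  # 16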

scratchpad/model_executor/model_runner.py

Lines changed: 7 additions & 1 deletion
@@ -29,8 +29,10 @@
     MLATokenToKVPool,
     ReqToTokenPool,
 )
-from scratchpad.memory.het_pool import (
+from scratchpad.memory import (
     HeterogeneousMHATokenToKVPool,
+    ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from scratchpad.model_executor.forward_info import ForwardBatch
 from scratchpad.model_executor.speculative.spec_info import SpeculativeAlgorithm
@@ -59,6 +61,8 @@ def __init__(
         tp_size: int,
         nccl_port: int,
         server_args: ServerArgs,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -72,6 +76,8 @@ def __init__(
         self.is_generation = model_config.is_generation
         self.is_multimodal = model_config.is_multimodal
         self.spec_algorithm = SpeculativeAlgorithm.NONE
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         logger.info(f"model config: {model_config}")
         # Model-specific adjustment
         if (

scratchpad/model_executor/utils.py

Lines changed: 103 additions & 2 deletions
@@ -1,19 +1,27 @@
 import torch
 import glob
 import contextlib
-from typing import List, Generator, Tuple, Type
+from typing import List, Generator, Tuple, Type, Protocol
 from tqdm import tqdm
 import json
 import os
-from scratchpad.utils import snapshot_download, get_lock, DisabledTqdm
+from scratchpad.utils import (
+    snapshot_download,
+    get_lock,
+    DisabledTqdm,
+    is_pin_memory_available,
+)
 from safetensors.torch import safe_open
 from scratchpad.nn.models import ModelRegistry
 from scratchpad.config import ModelConfig, LoadConfig
 from scratchpad.nn.quantization import get_quantization_config, QuantizationConfig
 import huggingface_hub
 from torch import nn
+from torch.func import functional_call
 
 _BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n"  # noqa: E501
+_CPU_OFFLOAD_BYTES = 0
+_CPU_OFFLOAD_MAX_BYTES = 0
 
 
 @contextlib.contextmanager
@@ -208,3 +216,96 @@ def get_quant_config(
         )
 
     return quant_cls.from_config(config)
+
+
+def set_cpu_offload_max_bytes(max_bytes: int) -> None:
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    _CPU_OFFLOAD_BYTES = 0
+    _CPU_OFFLOAD_MAX_BYTES = max_bytes
+
+
+def maybe_offload_to_cpu(module: torch.nn.Module) -> torch.nn.Module:
+    device = next(module.parameters()).device
+
+    if device == torch.device("cpu"):
+        return module
+
+    global _CPU_OFFLOAD_MAX_BYTES, _CPU_OFFLOAD_BYTES
+    if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+        return module
+
+    pin_memory = is_pin_memory_available()
+    # offload parameters to CPU
+    # use pin_memory if possible, which helps cudagraph capture speed
+    offloaded_parameters = False
+    for p in module.parameters():
+        if _CPU_OFFLOAD_BYTES >= _CPU_OFFLOAD_MAX_BYTES:
+            # we use per-parameter offloading
+            # one module might have some parameters offloaded and some not
+            break
+
+        # `torch.empty_like` does not support `pin_memory` argument
+        cpu_data = torch.empty_strided(
+            size=p.data.size(),
+            stride=p.data.stride(),
+            dtype=p.data.dtype,
+            layout=p.data.layout,
+            device="cpu",
+            pin_memory=pin_memory,
+        )
+        cpu_data.copy_(p.data)
+        p.data = cpu_data
+        _CPU_OFFLOAD_BYTES += p.data.numel() * p.data.element_size()
+        offloaded_parameters = True
+
+    if offloaded_parameters:
+        original_forward = module.forward
+
+        def forward(*args, **kwargs):
+            module.forward = original_forward
+            device_state = {
+                # here we blindly call `to(device)`
+                # if the parameter is already on the device, it will be a no-op
+                k: v.to(device, non_blocking=True)
+                for k, v in module.state_dict().items()
+            }
+            output = functional_call(module, device_state, args=args, kwargs=kwargs)
+            module.forward = forward
+            return output
+
+        module.forward = forward
+
+    return module
+
+
+class LayerFn(Protocol):
+    def __call__(self, layer_id: int, prefix: str) -> torch.nn.Module:
+        ...
+
+
+def add_prefix(name: str, prefix: str) -> str:
+    """Add a weight path prefix to a module name.
+
+    Args:
+        name: base module name.
+        prefix: weight prefix str to added to the front of `name` concatenated with `.`.
+
+    Returns:
+        The string `prefix.name` if prefix is non-empty, otherwise just `name`.
+    """
+    return name if not prefix else f"{prefix}.{name}"
+
+
+def make_layers(
+    num_hidden_layers: int,
+    layer_fn: LayerFn,
+    prefix: str = "",
+) -> Tuple[int, int, torch.nn.ModuleList]:
+    """Make a list of layers with the given layer function"""
+    modules = torch.nn.ModuleList(
+        [
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
+            for idx in range(num_hidden_layers)
+        ]
+    )
+    return modules
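
set_cpu_offload_max_bytes sets a global byte budget, and maybe_offload_to_cpu moves parameter data to (pinned, when available) CPU memory until that budget is spent, swapping the module's forward for a wrapper that copies the weights back to the original device per call via torch.func.functional_call. A minimal usage sketch, assuming a CUDA device is present (not part of the commit):

import torch

from scratchpad.model_executor.utils import (
    maybe_offload_to_cpu,
    set_cpu_offload_max_bytes,
)

# Allow up to 2 GiB of parameters to be kept in CPU memory.
set_cpu_offload_max_bytes(2 * 1024**3)

layer = torch.nn.Linear(4096, 4096).cuda()
layer = maybe_offload_to_cpu(layer)   # weight/bias now live in (pinned) CPU memory
print(layer.weight.device)            # cpu

x = torch.randn(8, 4096, device="cuda")
y = layer(x)                          # wrapper copies the params to the GPU for this call only
print(y.shape, y.device)              # torch.Size([8, 4096]) cuda:0

make_layers applies maybe_offload_to_cpu to every layer it builds, so the offload budget is consumed layer by layer as the model is constructed.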

tools/benchmark/bench_perf.py

Lines changed: 4 additions & 3 deletions
@@ -64,7 +64,7 @@ async def run_benchmark(
     goodput_config_dict: Dict[str, float],
     max_concurrency: Optional[int] = None,
 ):
-    system_info = await async_request_sp_sysinfo(args.endpoint)
+    # system_info = await async_request_sp_sysinfo(args.endpoint)
     pbar = tqdm(total=len(input_requests))
     tasks: List[asyncio.Task] = []
     semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
@@ -101,7 +101,7 @@ async def limited_request_func(request_func_input, pbar):
     output_file = write_benchmark(
         metrics,
         args.output,
-        system_info,
+        {},
         args,
         outputs,
     )
@@ -131,6 +131,7 @@ def benchmark(args):
        except Exception as e:
            print("Server is not ready. Please start the server first.")
            time.sleep(5)
+    print(f"Server is ready. Starting benchmark...")
     asyncio.run(
         run_benchmark(
             args,
@@ -209,7 +210,7 @@ def benchmark(args):
        "--wait-until-ready",
        action="store_true",
        help="Wait until the server is ready before starting the benchmark.",
-       default=True,
+       default=False,
    )
    args = parser.parse_args()
    benchmark(args)
