
Commit 2ade48f

Models/gemma 3 (#47)

* wip: revamp model registration
* fix gemma3 for causal LM
* gemma3
* update dockerfile

1 parent 4349939 · commit 2ade48f


67 files changed: +6354 −2398 lines changed

docker/Dockerfile.x86_64-cuda

Lines changed: 1 addition & 3 deletions
@@ -16,9 +16,7 @@ WORKDIR /scratchpad
 
 COPY . /scratchpad
 
-RUN git clone -b v0.1.6 https://github.com/flashinfer-ai/flashinfer.git --recursive && \
-    cd flashinfer/python && \
-    pip install --no-build-isolation --verbose --editable .
+RUN pip install flashinfer-python==0.2.3 -i https://flashinfer.ai/whl/cu124/torch2.5/
 
 RUN git clone https://github.com/eth-easl/triteia.git && \
     cd triteia && \
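The image now installs a prebuilt flashinfer wheel instead of building v0.1.6 from source. A minimal sanity-check sketch, assuming it is run inside the built image; the distribution name matches the pin above:

# Hypothetical post-build check that the pinned wheel is what got installed.
from importlib.metadata import version

assert version("flashinfer-python") == "0.2.3"
print("flashinfer-python", version("flashinfer-python"))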

meta/requirements.txt

Lines changed: 6 additions & 2 deletions
@@ -6,7 +6,7 @@ requests
 uvicorn
 zmq
 huggingface_hub
-transformers==4.46.3
+transformers
 outlines==0.0.46
 uvloop
 nvidia-ml-py
@@ -23,4 +23,8 @@ orjson
 xgrammar>=0.1.13
 nvidia-cuda-nvrtc-cu12
 cuda-python
-setproctitle
+setproctitle
+torch-memory-saver
+sgl-kernel
+decord
+soundfile

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -13,7 +13,7 @@ max-line-length = 120
 
 [project.scripts]
 sp = "scratchpad.cli.sp:app"
-
+spc = "scratchpad.cli.spc:app"
 [tool.setuptools]
 packages = ["scratchpad"]
 

scratchpad/cli/handlers/chat.py

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,7 @@ def __init__(
         self.cache_prompt = cache_prompt
         self.headers = {"Content-Type": "application/json"}
         self.chat_history = []
-        self.model_name = ""
+        self.model_name = model_name
         self.console = Console()
         self.client = openai.OpenAI(api_key="test", base_url=self.serveraddr + "/v1")
         # TODO: Gracefully handle user input history file.
@@ -77,6 +77,7 @@ def chat_generator(self, prompt):
             "seed": self.seed,
             "model": self.model_name,
         }
+        print(f"Payload: {payload}")
         try:
             response = self.client.chat.completions.create(**payload)
             for chunk in response:
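With this fix the model name passed to ChatHandler reaches the OpenAI-compatible request instead of an empty string, and the new print statement exposes the outgoing payload. A minimal sketch of the payload after the change; the model id and seed below are placeholders, not values from the repo:

# Hypothetical payload mirroring chat_generator(): "model" now carries the
# name given at construction time instead of "".
payload = {
    "messages": [{"role": "user", "content": "hello"}],
    "seed": 42,
    "model": "google/gemma-3-4b-it",  # placeholder model id
}
print(f"Payload: {payload}")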

scratchpad/cli/sp.py

Lines changed: 3 additions & 2 deletions
@@ -1,5 +1,5 @@
 import typer
-from scratchpad.server import dataclass_to_cli, ServerArgs, launch_server
+from scratchpad.server import dataclass_to_cli, ServerArgs
 from .handlers import ChatHandler, benchmark_quality
 
 app = typer.Typer()
@@ -13,6 +13,7 @@ def serve(
 ):
     """Spin up the server"""
     from scratchpad.server.args import global_args
+    from scratchpad.server import launch_server
     import multiprocessing as mp
 
     mp.set_start_method("spawn", force=True)
@@ -30,7 +31,7 @@ def version():
 @app.command()
 def chat(
     model: str,
-    backend: str = "http://localhost:8080",
+    backend: str = "http://localhost:3000",
 ):
     chat_handler = ChatHandler(server_addr=backend, model_name=model)
     chat_handler.chat()
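Moving the launch_server import into the serve body defers the server's import cost until it is actually needed, which keeps lightweight commands such as version and chat fast to start (a plausible reading of the change, not stated in the commit). A self-contained sketch of the pattern with a stand-in module:

import typer

app = typer.Typer()


@app.command()
def serve():
    # Deferred import: the heavy dependency is loaded only when `serve` runs.
    import json as launch_server_standin  # stand-in for scratchpad.server.launch_server

    typer.echo(f"would launch via {launch_server_standin.__name__}")


@app.command()
def version():
    typer.echo("0.1.0")


if __name__ == "__main__":
    app()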

scratchpad/cli/spc.py

Lines changed: 21 additions & 0 deletions
@@ -0,0 +1,21 @@
+import typer
+
+app = typer.Typer()
+
+
+@app.command()
+def chat(
+    model: str,
+    backend: str = "http://localhost:3000",
+):
+    from .handlers import ChatHandler
+
+    print(f"Chatting with model: {model}, backend: {backend}")
+    chat_handler = ChatHandler(server_addr=backend, model_name=model)
+    chat_handler.chat()
+
+
+@app.command()
+def version():
+    """Print the version"""
+    typer.echo("0.1.0")
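Together with the new spc entry under [project.scripts] in pyproject.toml, this file becomes the spc console script. A sketch of exercising a Typer app like this one without installing the script, using typer's test runner; the app below mirrors spc.py but stubs out ChatHandler:

import typer
from typer.testing import CliRunner

app = typer.Typer()


@app.command()
def chat(model: str, backend: str = "http://localhost:3000"):
    # Stub: the real command constructs ChatHandler and starts a session.
    print(f"Chatting with model: {model}, backend: {backend}")


@app.command()
def version():
    """Print the version"""
    typer.echo("0.1.0")


runner = CliRunner()
assert runner.invoke(app, ["version"]).output.strip() == "0.1.0"
assert "backend: http://localhost:3000" in runner.invoke(app, ["chat", "my-model"]).output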

scratchpad/config/model_config.py

Lines changed: 14 additions & 23 deletions
@@ -1,6 +1,6 @@
 from enum import IntEnum, auto
 from typing import Optional
-from typing import List
+from typing import List, Set
 from transformers import PretrainedConfig
 
 from scratchpad.utils import get_config, get_context_length
@@ -70,6 +70,9 @@ def __init__(
             self.hf_config.architectures, is_embedding
         )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_multimodal_gen = False
+        self.is_image_gen = False
+        self.is_audio_model = False
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         if context_length is not None:
             self.context_len = context_length
@@ -82,38 +85,26 @@
             "head_dim",
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )
-
-        # FIXME: temporary special judge for deepseek v2 MLA architecture
-        if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
-            self.head_dim = 256
-            self.attention_arch = AttentionArch.MLA
-            self.kv_lora_rank = self.hf_config.kv_lora_rank
-            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
-        elif "MiniCPM3ForCausalLM" in self.hf_config.architectures:
-            self.head_dim = 128
-            self.attention_arch = AttentionArch.MLA
-            self.kv_lora_rank = self.hf_config.kv_lora_rank
-            self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
-        else:
-            self.attention_arch = AttentionArch.MHA
-
+        self.attention_arch = AttentionArch.MHA
         self.num_attention_heads = self.hf_text_config.num_attention_heads
         self.num_key_value_heads = getattr(
             self.hf_text_config, "num_key_value_heads", None
         )
-
-        # for Dbrx and MPT models
-        if self.hf_config.model_type in ["dbrx", "mpt"]:
-            self.num_key_value_heads = getattr(
-                self.hf_config.attn_config, "kv_n_heads", None
-            )
-
         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
         self.hidden_size = self.hf_text_config.hidden_size
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size
         self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
+        self.hf_eos_token_id = self.get_hf_eos_token_id()
+        self.image_token_id = getattr(self.hf_config, "image_token_id", None)
+
+    def get_hf_eos_token_id(self) -> Optional[Set[int]]:
+        eos_ids = getattr(self.hf_config, "eos_token_id", None)
+        if eos_ids:
+            # it can be either int or list of int
+            eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
+        return eos_ids
 
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
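The new get_hf_eos_token_id() helper normalizes the HF config's eos_token_id, which some models store as a single int and others as a list of ints. A standalone sketch of the same normalization, detached from ModelConfig (normalize_eos_token_id is a hypothetical wrapper, not a name from the repo):

from typing import List, Optional, Set, Union


def normalize_eos_token_id(eos_ids: Union[int, List[int], None]) -> Optional[Set[int]]:
    # An int becomes a one-element set, a list becomes a set, missing stays None.
    if eos_ids:
        eos_ids = {eos_ids} if isinstance(eos_ids, int) else set(eos_ids)
    return eos_ids


assert normalize_eos_token_id(2) == {2}
assert normalize_eos_token_id([1, 106]) == {1, 106}
assert normalize_eos_token_id(None) is None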

scratchpad/config/utils.py

Lines changed: 31 additions & 2 deletions
@@ -26,9 +26,13 @@ def _get_and_verify_dtype(
     dtype = dtype.lower()
     if dtype == "auto":
         if config_dtype == torch.float32:
-            if config.model_type == "gemma2":
+            if config.model_type.startswith("gemma"):
+                if config.model_type == "gemma":
+                    gemma_version = ""
+                else:
+                    gemma_version = config.model_type[5]
                 logger.info(
-                    "For Gemma 2, we downcast float32 to bfloat16 instead "
+                    f"For Gemma {gemma_version}, we downcast float32 to bfloat16 instead "
                     "of float16 by default. Please specify `dtype` if you "
                     "want to use float16."
                 )
@@ -65,6 +69,13 @@ def _get_and_verify_dtype(
     return torch_dtype
 
 
+def get_min_sliding_window(sliding_window: Union[int, list[Optional[int]]]) -> int:
+    if isinstance(sliding_window, list):
+        return min(s for s in sliding_window if s is not None)
+
+    return sliding_window
+
+
 def _get_and_verify_max_len(
     hf_config: PretrainedConfig,
     max_model_len: Optional[int],
@@ -216,3 +227,21 @@ def get_served_model_name(
     if isinstance(served_model_name, list):
         return served_model_name[0]
     return served_model_name
+
+
+multimodal_model_archs = [
+    "Gemma3ForConditionalGeneration",
+    "MllamaForConditionalGeneration",
+    "Qwen2VLForConditionalGeneration",
+    "Qwen2_5_VLForConditionalGeneration",
+]
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if any(
+        multi_model_arch in model_architectures
+        for multi_model_arch in multimodal_model_archs
+    ):
+        return True
+    else:
+        return False
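Two behaviors introduced above are easy to check in isolation: the Gemma version digit is read from model_type by position (index 5 of "gemma2"/"gemma3", empty for plain "gemma"), and is_multimodal_model() is a plain membership test over architecture names. A self-contained sketch (gemma_version is a hypothetical wrapper around the inline logic in the diff):

from typing import List

# Same architecture allowlist as the diff above.
multimodal_model_archs = [
    "Gemma3ForConditionalGeneration",
    "MllamaForConditionalGeneration",
    "Qwen2VLForConditionalGeneration",
    "Qwen2_5_VLForConditionalGeneration",
]


def gemma_version(model_type: str) -> str:
    # "gemma" has no version suffix; "gemma2"/"gemma3" carry the digit at index 5.
    return "" if model_type == "gemma" else model_type[5]


def is_multimodal_model(model_architectures: List[str]) -> bool:
    return any(arch in model_architectures for arch in multimodal_model_archs)


assert gemma_version("gemma") == ""
assert gemma_version("gemma3") == "3"
assert is_multimodal_model(["Gemma3ForConditionalGeneration"])
assert not is_multimodal_model(["LlamaForCausalLM"])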

scratchpad/config/vllm_model_config.py

Lines changed: 9 additions & 19 deletions
@@ -1,17 +1,14 @@
 import enum
-from dataclasses import dataclass, field, fields
-import contextlib
 from typing import (
-    TYPE_CHECKING,
     Any,
     Dict,
     List,
     Mapping,
     Optional,
     Union,
 )
-from pathlib import Path
 import torch
+from pathlib import Path
 from transformers import PretrainedConfig, AutoConfig
 from scratchpad.utils import (
     logger,
@@ -20,12 +17,17 @@
     get_hf_image_processor_config,
 )
 from scratchpad.config.modality_config import MultiModalConfig
-from scratchpad.nn.models import ModelRegistry
-from .utils import _get_and_verify_dtype, _get_and_verify_max_len, get_served_model_name
 import huggingface_hub
 from huggingface_hub import file_exists, try_to_load_from_cache
 from transformers.utils import CONFIG_NAME as HF_CONFIG_NAME
 
+from .utils import (
+    _get_and_verify_dtype,
+    _get_and_verify_max_len,
+    get_served_model_name,
+    is_multimodal_model,
+)
+
 
 class ConfigFormat(str, enum.Enum):
     AUTO = "auto"
@@ -307,15 +309,14 @@ def __init__(
         self._verify_tokenizer_mode()
 
         self.override_neuron_config = None
-        self._verify_embedding_mode()
         self._verify_cuda_graph()
         self._verify_bnb_config()
 
     def _init_multimodal_config(
         self, limit_mm_per_prompt: Optional[Mapping[str, int]]
     ) -> Optional["MultiModalConfig"]:
         architectures = getattr(self.hf_config, "architectures", [])
-        if any(ModelRegistry.is_multimodal_model(arch) for arch in architectures):
+        if is_multimodal_model(architectures):
             return MultiModalConfig(limit_per_prompt=limit_mm_per_prompt or {})
         else:
             if limit_mm_per_prompt:
@@ -333,12 +334,6 @@ def _verify_tokenizer_mode(self) -> None:
             )
         self.tokenizer_mode = tokenizer_mode
 
-    def _verify_embedding_mode(self) -> None:
-        architectures = getattr(self.hf_config, "architectures", [])
-        self.embedding_mode = any(
-            ModelRegistry.is_embedding_model(arch) for arch in architectures
-        )
-
     def _parse_quant_hf_config(self):
         quant_cfg = getattr(self.hf_config, "quantization_config", None)
         if quant_cfg is None:
@@ -526,11 +521,6 @@ def is_encoder_decoder_model(self) -> bool:
             )
         )
 
-    @property
-    def is_embedding_model(self) -> bool:
-        """Extract the embedding model flag."""
-        return self.embedding_mode
-
     @property
     def is_multimodal_model(self) -> bool:
         return self.multimodal_config is not None

scratchpad/constrained/xgrammar_backend.py

Lines changed: 4 additions & 7 deletions
@@ -1,6 +1,4 @@
-import logging
 from typing import List, Tuple
-
 import torch
 from xgrammar import (
     CompiledGrammar,
@@ -13,8 +11,7 @@
 )
 
 from .base_backend import BaseGrammarObject, BaseGrammarBackend
-
-logger = logging.getLogger(__name__)
+from scratchpad.utils import logger
 
 
 MAX_ROLLBACK_TOKENS = 200
@@ -104,23 +101,23 @@ def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
                 else:
                     ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
             except RuntimeError as e:
-                logging.warning(
+                logger.warning(
                     f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
                 return None
         elif key_type == "ebnf":
            try:
                ctx = self.grammar_compiler.compile_grammar(key_string)
            except RuntimeError as e:
-                logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
+                logger.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
                return None
        elif key_type == "regex":
            try:
                ctx = self.grammar_compiler.compile_grammar(
                    Grammar.from_regex(key_string)
                )
            except RuntimeError as e:
-                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+                logger.warning(f"Skip invalid regex: regex={key_string}, {e=}")
                return None
        else:
            raise ValueError(f"Invalid key_type: {key_type}")
