Skip to content

Commit bb61a23

Browse files
committed
support qwen3 on nvidia
1 parent 24c2bff commit bb61a23

File tree

2 files changed: +506 additions, -0 deletions

server/text_generation_server/models/__init__.py

Lines changed: 42 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -152,6 +152,9 @@
152152
from text_generation_server.models.custom_modeling.flash_qwen2_modeling import (
153153
Qwen2ForCausalLM,
154154
)
155+
from text_generation_server.models.custom_modeling.flash_qwen3_modeling import (
156+
Qwen3ForCausalLM,
157+
)
155158
from text_generation_server.models.custom_modeling.flash_mistral_modeling import (
156159
FlashMistralForCausalLM,
157160
)
@@ -348,6 +351,11 @@ class ModelType(enum.Enum):
348351
"name": "Qwen 2",
349352
"url": "https://huggingface.co/collections/Qwen/qwen2-6659360b33528ced941e557f",
350353
}
354+
QWEN3 = {
355+
"type": "qwen3",
356+
"name": "Qwen 3",
357+
"url": "https://huggingface.co/collections/Qwen/qwen3-67c6c6f89c4f76621268bb6d",
358+
}
351359
QWEN2_VL = {
352360
"type": "qwen2_vl",
353361
"name": "Qwen 2 VL",
@@ -1470,6 +1478,40 @@ def get_model(
14701478
trust_remote_code=trust_remote_code,
14711479
)
14721480

1481+
if model_type == QWEN3:
1482+
if FLASH_ATTENTION:
1483+
return FlashCausalLM(
1484+
model_id=model_id,
1485+
model_class=Qwen3ForCausalLM,
1486+
revision=revision,
1487+
quantize=quantize,
1488+
speculator=speculator,
1489+
dtype=dtype,
1490+
kv_cache_dtype=kv_cache_dtype,
1491+
trust_remote_code=trust_remote_code,
1492+
lora_adapter_ids=lora_adapter_ids,
1493+
)
1494+
elif FLASH_TRANSFORMERS_BACKEND:
1495+
return TransformersFlashCausalLM.fallback(
1496+
model_id,
1497+
revision,
1498+
quantize=quantize,
1499+
speculator=speculator,
1500+
dtype=dtype,
1501+
trust_remote_code=trust_remote_code,
1502+
)
1503+
elif sharded:
1504+
raise NotImplementedError(FLASH_ATT_ERROR_MESSAGE.format("Sharded Qwen3"))
1505+
else:
1506+
return CausalLM.fallback(
1507+
model_id,
1508+
revision,
1509+
quantize=quantize,
1510+
speculator=speculator,
1511+
dtype=dtype,
1512+
trust_remote_code=trust_remote_code,
1513+
)
1514+
14731515
if model_type == OPT:
14741516
return CausalLM(
14751517
model_id=model_id,

0 commit comments

Comments
 (0)