2 changes: 2 additions & 0 deletions python/sglang/srt/configs/model_config.py
@@ -86,6 +86,8 @@ def __init__(
dtype: str = "auto",
quantization: Optional[str] = None,
modelopt_quant: Optional[Union[str, Dict]] = None,
modelopt_checkpoint_restore_path: Optional[str] = None,
modelopt_checkpoint_save_path: Optional[str] = None,
override_config_file: Optional[str] = None,
is_draft_model: bool = False,
hybrid_kvcache_ratio: Optional[float] = None,
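For reference, a minimal sketch of how the two new arguments might be passed when constructing a ModelConfig directly. This is a hypothetical call: the model name and paths are placeholders, and it assumes model_path remains the first positional argument of the constructor.

from sglang.srt.configs.model_config import ModelConfig

config = ModelConfig(
    "meta-llama/Llama-3.1-8B-Instruct",  # placeholder model path
    dtype="auto",
    modelopt_quant="fp8",
    modelopt_checkpoint_restore_path=None,  # point at a saved checkpoint to skip calibration
    modelopt_checkpoint_save_path="/tmp/llama3_fp8_modelopt",  # saved state for reuse on later runs
)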
134 changes: 108 additions & 26 deletions python/sglang/srt/model_loader/loader.py
@@ -18,7 +18,7 @@
import time
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor
from contextlib import contextmanager
from contextlib import contextmanager, suppress
from typing import (
TYPE_CHECKING,
Any,
@@ -30,7 +30,6 @@
Tuple,
cast,
)
from urllib.parse import urlparse

import huggingface_hub
import numpy as np
@@ -52,7 +51,7 @@

from huggingface_hub import HfApi, hf_hub_download
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from sglang.srt.configs.load_config import LoadConfig, LoadFormat
@@ -104,6 +103,7 @@
get_device_capability,
is_npu,
is_pin_memory_available,
rank0_log,
set_weight_attrs,
)

@@ -545,7 +545,7 @@ def _load_modelopt_base_model(self, model_config: ModelConfig) -> nn.Module:
**model_kwargs,
trust_remote_code=True,
)
logger.info(f"ModelOpt quantization requested: {model_config.modelopt_quant}")
rank0_log(f"ModelOpt quantization requested: {model_config.modelopt_quant}")

quant_choice_str = model_config.modelopt_quant
if not isinstance(quant_choice_str, str):
@@ -1764,6 +1764,96 @@ def __init__(self, load_config: LoadConfig):
super().__init__(load_config)
# Any ModelOpt specific initialization if needed

def _setup_modelopt_quantization(
self,
model,
tokenizer,
quant_cfg,
quantized_ckpt_restore_path: str | None = None,
quantized_ckpt_save_path: str | None = None,
) -> None:
"""
Set up ModelOpt quantization for the given model.

Args:
model: The model to quantize
tokenizer: The tokenizer associated with the model
quant_cfg: The quantization configuration
quantized_ckpt_restore_path: Path to restore quantized checkpoint from
quantized_ckpt_save_path: Path to save quantized checkpoint to

Raises:
ImportError: If ModelOpt is not available
RuntimeError: If quantization setup fails
"""
try:
import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.quantization.utils import is_quantized
except ImportError as e:
raise ImportError(
"ModelOpt is not available. Please install modelopt."
) from e

if is_quantized(model):
rank0_log("Model is already quantized, skipping quantization setup.")
return
# Restore from checkpoint if provided
if quantized_ckpt_restore_path:
try:
mto.restore(model, quantized_ckpt_restore_path)
rank0_log(
f"Restored quantized model from {quantized_ckpt_restore_path}"
)
return
except Exception as e:
logger.warning(
f"Failed to restore from {quantized_ckpt_restore_path}: {e}"
)
rank0_log("Proceeding with calibration-based quantization...")

# Set up calibration-based quantization
try:
# Left padding tends to work better for batched generation with decoder-only LMs
with suppress(Exception):
tokenizer.padding_side = "left"

from modelopt.torch.utils.dataset_utils import (
create_forward_loop,
get_dataset_dataloader,
)

# Create calibration dataloader
calib_dataloader = get_dataset_dataloader(
dataset_name="cnn_dailymail", # TODO: Consider making this configurable
tokenizer=tokenizer,
batch_size=36, # TODO: Consider making this configurable
num_samples=512, # TODO: Consider making this configurable
device=model.device,
include_labels=False,
)

calibrate_loop = create_forward_loop(dataloader=calib_dataloader)

# Apply quantization
mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)

if get_tensor_model_parallel_rank() == 0:
mtq.print_quant_summary(model)

# Save checkpoint if path provided
if quantized_ckpt_save_path:
try:
mto.save(model, quantized_ckpt_save_path)
rank0_log(f"Quantized model saved to {quantized_ckpt_save_path}")
except Exception as e:
logger.warning(
f"Failed to save quantized checkpoint to {quantized_ckpt_save_path}: {e}"
)

except Exception as e:
raise RuntimeError(f"Failed to set up ModelOpt quantization: {e}") from e

def load_model(
self,
*,
@@ -1779,7 +1869,6 @@ def load_model(
# Import ModelOpt modules (already done in _load_modelopt_base_model, but needed here for quantization)
try:
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop
except ImportError:
logger.error(
"NVIDIA Model Optimizer (modelopt) library not found. "
@@ -1808,33 +1897,26 @@
"Please verify QUANT_CFG_CHOICES and the ModelOpt library."
)

# For now, assume no calibration. Calibration setup is a separate, more complex step.
use_calibration = False # This would ideally be a configurable parameter
calib_dataloader = None # This would need to be provided/configured

calibrate_loop = (
create_forward_loop(dataloader=calib_dataloader)
if use_calibration
else None
)

if use_calibration and calib_dataloader is None:
logger.warning(
"ModelOpt calibration requested but no calib_dataloader provided. "
"Proceeding without calibration. Quantization accuracy may be affected."
)

logger.info(
f"Quantizing model with ModelOpt using config attribute: mtq.{quant_cfg_name}"
)

quantized_ckpt_restore_path = model_config.modelopt_checkpoint_restore_path
quantized_ckpt_save_path = model_config.modelopt_checkpoint_save_path
tokenizer = AutoTokenizer.from_pretrained(
model_config.model_path, use_fast=True
)
try:
model = mtq.quantize(model, quant_cfg, forward_loop=calibrate_loop)
logger.info("Model successfully quantized with ModelOpt.")
self._setup_modelopt_quantization(
model,
tokenizer,
quant_cfg,
quantized_ckpt_restore_path=quantized_ckpt_restore_path,
quantized_ckpt_save_path=quantized_ckpt_save_path,
)
except Exception as e:
logger.error(f"Error during ModelOpt mtq.quantize call: {e}")
raise
mtq.print_quant_summary(model)
logger.warning(f"ModelOpt quantization failed: {e}")
rank0_log("Proceeding without quantization...")

return model.eval()

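For context, a standalone sketch of the restore → calibrate → save flow that _setup_modelopt_quantization implements, using only the ModelOpt APIs referenced in this diff. The model name, batch size, sample count, and paths are illustrative; mtq.FP8_DEFAULT_CFG stands in for whatever config QUANT_CFG_CHOICES resolves to, and error handling is omitted.

import modelopt.torch.opt as mto
import modelopt.torch.quantization as mtq
from modelopt.torch.utils.dataset_utils import create_forward_loop, get_dataset_dataloader
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", torch_dtype="auto").cuda()
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m", use_fast=True)
tokenizer.padding_side = "left"  # left padding works better for batched decoder-only generation

restore_path = None  # set to a previously saved ModelOpt state to skip calibration
save_path = "/tmp/opt125m_fp8_modelopt"

if restore_path:
    # Re-apply the stored quantizer state without re-running calibration.
    mto.restore(model, restore_path)
else:
    # Calibration data drives the activation statistics used by the quantizer.
    calib_dataloader = get_dataset_dataloader(
        dataset_name="cnn_dailymail",
        tokenizer=tokenizer,
        batch_size=36,
        num_samples=512,
        device=model.device,
        include_labels=False,
    )
    forward_loop = create_forward_loop(dataloader=calib_dataloader)
    mtq.quantize(model, mtq.FP8_DEFAULT_CFG, forward_loop=forward_loop)
    mtq.print_quant_summary(model)
    mto.save(model, save_path)  # reusable on later runs via restore_path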
17 changes: 17 additions & 0 deletions python/sglang/srt/server_args.py
@@ -178,6 +178,8 @@ class ServerArgs:
model_loader_extra_config: str = "{}"
trust_remote_code: bool = False
modelopt_quant: Optional[Union[str, Dict]] = None
modelopt_checkpoint_restore_path: Optional[str] = None
modelopt_checkpoint_save_path: Optional[str] = None
context_length: Optional[int] = None
is_embedding: bool = False
enable_multimodal: Optional[bool] = None
@@ -1504,6 +1506,21 @@ def add_cli_args(parser: argparse.ArgumentParser):
"Supported values: 'fp8', 'int4_awq', 'w4a8_awq', 'nvfp4', 'nvfp4_awq'. "
"This requires the NVIDIA Model Optimizer library to be installed: pip install nvidia-modelopt",
)
parser.add_argument(
"--modelopt-checkpoint-restore-path",
type=str,
default=ServerArgs.modelopt_checkpoint_restore_path,
help="Path to restore a previously saved ModelOpt quantized checkpoint. "
"If provided, the quantization process will be skipped and the model "
"will be loaded from this checkpoint.",
)
parser.add_argument(
"--modelopt-checkpoint-save-path",
type=str,
default=ServerArgs.modelopt_checkpoint_save_path,
help="Path to save the ModelOpt quantized checkpoint after quantization. "
"This allows reusing the quantized model in future runs.",
)
parser.add_argument(
"--kv-cache-dtype",
type=str,
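A usage sketch of the two new flags with SGLang's launch entry point (the model and paths are placeholders): the first run calibrates, quantizes, and saves the ModelOpt state; later runs restore that state and skip calibration.

# First run: FP8 quantization with calibration, then save the ModelOpt state.
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --modelopt-quant fp8 \
  --modelopt-checkpoint-save-path /tmp/llama3_fp8_modelopt

# Later runs: restore the saved state and skip calibration.
python -m sglang.launch_server \
  --model-path meta-llama/Llama-3.1-8B-Instruct \
  --modelopt-quant fp8 \
  --modelopt-checkpoint-restore-path /tmp/llama3_fp8_modelopt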