3 changes: 3 additions & 0 deletions vllm/config.py
@@ -1277,13 +1277,16 @@ class LoadConfig:
ignore_patterns: The list of patterns to ignore when loading the model.
Default to "original/**/*" to avoid repeated loading of llama's
checkpoints.
+        use_tqdm_on_load: Whether to show a tqdm progress bar during
+            loading. Defaults to True.
"""

load_format: Union[str, LoadFormat, "BaseModelLoader"] = LoadFormat.AUTO
download_dir: Optional[str] = None
model_loader_extra_config: Optional[Union[str, dict]] = field(
default_factory=dict)
ignore_patterns: Optional[Union[list[str], str]] = None
+    use_tqdm_on_load: bool = True

def compute_hash(self) -> str:
"""
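A minimal sketch of the new field in use, assuming LoadConfig is constructible with defaults as the dataclass above suggests (hypothetical usage, not part of the PR):

# Construct a LoadConfig with the progress bar disabled; all other
# fields keep their defaults, e.g. load_format=LoadFormat.AUTO.
from vllm.config import LoadConfig

load_config = LoadConfig(use_tqdm_on_load=False)
assert load_config.use_tqdm_on_load is False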
10 changes: 10 additions & 0 deletions vllm/engine/arg_utils.py
@@ -217,6 +217,7 @@ class EngineArgs:
additional_config: Optional[Dict[str, Any]] = None
enable_reasoning: Optional[bool] = None
reasoning_parser: Optional[str] = None
+    use_tqdm_on_load: bool = True

def __post_init__(self):
if not self.tokenizer:
@@ -751,6 +752,14 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser:
default=1,
help=('Maximum number of forward steps per '
'scheduler call.'))
+        parser.add_argument(
+            '--use-tqdm-on-load',
+            dest='use_tqdm_on_load',
+            action=argparse.BooleanOptionalAction,
+            default=EngineArgs.use_tqdm_on_load,
+            help='Whether to enable/disable progress bar '
+            'when loading model weights.',
+        )

parser.add_argument(
'--multi-step-stream-outputs',
@@ -1179,6 +1188,7 @@ def create_load_config(self) -> LoadConfig:
download_dir=self.download_dir,
model_loader_extra_config=self.model_loader_extra_config,
ignore_patterns=self.ignore_patterns,
+            use_tqdm_on_load=self.use_tqdm_on_load,
)

def create_engine_config(self,
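The CLI wiring relies on argparse.BooleanOptionalAction, so the single add_argument call above exposes both --use-tqdm-on-load and an auto-generated --no-use-tqdm-on-load negation. A standalone illustration of that mechanism (plain argparse, not vLLM code; BooleanOptionalAction requires Python 3.9+):

# Demonstrates the flag pair generated by argparse.BooleanOptionalAction.
import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--use-tqdm-on-load",
    action=argparse.BooleanOptionalAction,
    default=True,
)

assert parser.parse_args([]).use_tqdm_on_load is True
assert parser.parse_args(["--no-use-tqdm-on-load"]).use_tqdm_on_load is False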
26 changes: 21 additions & 5 deletions vllm/model_executor/model_loader/loader.py
@@ -354,11 +354,18 @@ def _get_weights_iterator(
self.load_config.download_dir,
hf_folder,
hf_weights_files,
+                self.load_config.use_tqdm_on_load,
)
elif use_safetensors:
-            weights_iterator = safetensors_weights_iterator(hf_weights_files)
+            weights_iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
else:
-            weights_iterator = pt_weights_iterator(hf_weights_files)
+            weights_iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )

if current_platform.is_tpu():
# In PyTorch XLA, we should call `xm.mark_step` frequently so that
@@ -806,9 +813,15 @@ def _prepare_weights(self, model_name_or_path: str,

def _hf_weight_iter(self, hf_weights_files, use_safetensors: bool):
if use_safetensors:
-            iterator = safetensors_weights_iterator(hf_weights_files)
+            iterator = safetensors_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
else:
-            iterator = pt_weights_iterator(hf_weights_files)
+            iterator = pt_weights_iterator(
+                hf_weights_files,
+                self.load_config.use_tqdm_on_load,
+            )
for org_name, param in iterator:
# mapping weight names from transformers to vllm while preserving
# original names.
@@ -1396,7 +1409,10 @@ def _get_weights_iterator(
revision: str) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Get an iterator for the model weights based on the load format."""
hf_weights_files = self._prepare_weights(model_or_path, revision)
-        return runai_safetensors_weights_iterator(hf_weights_files)
+        return runai_safetensors_weights_iterator(
+            hf_weights_files,
+            self.load_config.use_tqdm_on_load,
+        )

def download_model(self, model_config: ModelConfig) -> None:
"""Download model if necessary"""
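All three call sites in this file follow the same pattern: thread self.load_config.use_tqdm_on_load into whichever weight iterator the load format selects. A simplified sketch of that dispatch (hypothetical helper name; the real _get_weights_iterator also handles the np-cache and TPU paths):

# Sketch only: the flag rides along with the file list into the iterator.
from typing import Generator, List, Tuple

import torch

from vllm.model_executor.model_loader.weight_utils import (
    pt_weights_iterator, safetensors_weights_iterator)


def select_weights_iterator(
    hf_weights_files: List[str],
    use_safetensors: bool,
    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
    if use_safetensors:
        return safetensors_weights_iterator(hf_weights_files,
                                            use_tqdm_on_load)
    return pt_weights_iterator(hf_weights_files, use_tqdm_on_load)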
35 changes: 18 additions & 17 deletions vllm/model_executor/model_loader/weight_utils.py
@@ -366,16 +366,22 @@ def filter_files_not_needed_for_inference(
_BAR_FORMAT = "{desc}: {percentage:3.0f}% Completed | {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}]\n" # noqa: E501


+def enable_tqdm(use_tqdm_on_load: bool):
+    return use_tqdm_on_load and (not torch.distributed.is_initialized()
+                                 or torch.distributed.get_rank() == 0)


def np_cache_weights_iterator(
-    model_name_or_path: str, cache_dir: Optional[str], hf_folder: str,
-    hf_weights_files: List[str]
+    model_name_or_path: str,
+    cache_dir: Optional[str],
+    hf_folder: str,
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model np files.

Will dump the model weights to numpy files if they are not already dumped.
"""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
# Convert the model weights from torch tensors to numpy arrays for
# faster loading.
np_folder = os.path.join(hf_folder, "np")
@@ -389,7 +395,7 @@ def np_cache_weights_iterator(
for bin_file in tqdm(
hf_weights_files,
desc="Loading np_cache checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file,
@@ -414,15 +420,14 @@ def np_cache_weights_iterator(


def safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str],
+    use_tqdm_on_load: bool,
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
with safe_open(st_file, framework="pt") as f:
@@ -432,32 +437,28 @@ def safetensors_weights_iterator(


def runai_safetensors_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str], use_tqdm_on_load: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model safetensor files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
with SafetensorsStreamer() as streamer:
for st_file in tqdm(
hf_weights_files,
desc="Loading safetensors using Runai Model Streamer",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
streamer.stream_file(st_file)
yield from streamer.get_tensors()


def pt_weights_iterator(
-    hf_weights_files: List[str]
+    hf_weights_files: List[str], use_tqdm_on_load: bool
) -> Generator[Tuple[str, torch.Tensor], None, None]:
"""Iterate over the weights in the model bin/pt files."""
-    enable_tqdm = not torch.distributed.is_initialized(
-    ) or torch.distributed.get_rank() == 0
for bin_file in tqdm(
hf_weights_files,
desc="Loading pt checkpoint shards",
-            disable=not enable_tqdm,
+            disable=not enable_tqdm(use_tqdm_on_load),
bar_format=_BAR_FORMAT,
):
state = torch.load(bin_file, map_location="cpu", weights_only=True)
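Net effect of the refactor: the rank-0 check that was previously copy-pasted into each iterator now lives in a single enable_tqdm helper, and the user flag short-circuits it. A quick single-process check of the helper's behavior (logic duplicated from the diff; with torch.distributed uninitialized, only the flag matters):

import torch


def enable_tqdm(use_tqdm_on_load: bool):
    # Show the bar only if the user asked for it and this is rank 0
    # (or no torch.distributed process group has been initialized).
    return use_tqdm_on_load and (not torch.distributed.is_initialized()
                                 or torch.distributed.get_rank() == 0)


assert enable_tqdm(True) is True    # flag on, no process group: bar shown
assert enable_tqdm(False) is False  # flag off: bar suppressed on every rank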