14 changes: 11 additions & 3 deletions lmdeploy/cli/serve.py
@@ -165,22 +165,23 @@ def add_parser_api_server():
max_prefill_token_num_act = ArgumentHelper.max_prefill_token_num(pt_group)
quant_policy = ArgumentHelper.quant_policy(pt_group)
model_format = ArgumentHelper.model_format(pt_group)
ArgumentHelper.dp(pt_group)
dp_act = ArgumentHelper.dp(pt_group)
num_nodes_act = ArgumentHelper.num_nodes(pt_group)
ArgumentHelper.ep(pt_group)
ArgumentHelper.enable_microbatch(pt_group)
ArgumentHelper.enable_eplb(pt_group)
ArgumentHelper.enable_metrics(pt_group)
ArgumentHelper.role(pt_group)
ArgumentHelper.migration_backend(pt_group)
# multi-node serving args
ArgumentHelper.node_rank(parser)
ArgumentHelper.num_nodes(parser)
node_rank_act = ArgumentHelper.node_rank(pt_group)

# turbomind args
tb_group = parser.add_argument_group('TurboMind engine arguments')
# common engine args
tb_group._group_actions.append(dtype_act)
tb_group._group_actions.append(tp_act)
tb_group._group_actions.append(dp_act)
tb_group._group_actions.append(session_len_act)
tb_group._group_actions.append(max_batch_size_act)
tb_group._group_actions.append(cache_max_entry_act)
@@ -189,10 +190,13 @@ def add_parser_api_server():
tb_group._group_actions.append(max_prefill_token_num_act)
tb_group._group_actions.append(quant_policy)
tb_group._group_actions.append(model_format)
tb_group._group_actions.append(num_nodes_act)
tb_group._group_actions.append(node_rank_act)
ArgumentHelper.rope_scaling_factor(tb_group)
ArgumentHelper.num_tokens_per_iter(tb_group)
ArgumentHelper.max_prefill_iters(tb_group)
ArgumentHelper.communicator(tb_group)
ArgumentHelper.ngpus_per_node(tb_group)

# vlm args
vision_group = parser.add_argument_group('Vision model arguments')
@@ -342,6 +346,10 @@ def api_server(args):
from lmdeploy.messages import TurbomindEngineConfig
backend_config = TurbomindEngineConfig(dtype=args.dtype,
tp=args.tp,
dp=args.dp,
nnodes=args.nnodes,
ngpus_per_node=args.ngpus_per_node,
node_rank=args.node_rank,
max_batch_size=max_batch_size,
session_len=args.session_len,
model_format=args.model_format,
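As a usage sketch only (the model path is omitted and the parallelism numbers are made up), the new flags added above would map onto the TurboMind engine config roughly as follows for node 0 of a hypothetical 2-node, 8-GPU-per-node deployment; the authoritative wiring is the api_server hunk above.

from lmdeploy.messages import TurbomindEngineConfig

# Hypothetical values mirroring --tp/--dp/--nnodes/--ngpus-per-node/--node-rank
# for node 0 of a 2-node deployment with 8 GPUs per node.
backend_config = TurbomindEngineConfig(
    tp=16,               # tensor parallelism spanning all 16 GPUs
    dp=1,                # data parallelism, now also exposed to the TurboMind group
    nnodes=2,            # total number of nodes
    ngpus_per_node=8,    # GPUs contributed by each node
    node_rank=0,         # rank of this node
)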
6 changes: 6 additions & 0 deletions lmdeploy/cli/utils.py
@@ -192,6 +192,12 @@ def num_nodes(parser):

return parser.add_argument('--nnodes', type=int, default=1, help='The total node nums')

@staticmethod
def ngpus_per_node(parser):
"""Add argument ngpus_per_node to parser."""

return parser.add_argument('--ngpus-per-node', type=int, default=None, help='The number of GPUs per node')

@staticmethod
def session_id(parser):
"""Add argument session_id to parser."""
4 changes: 4 additions & 0 deletions lmdeploy/messages.py
@@ -235,6 +235,10 @@ class TurbomindEngineConfig:
mlp_tp_size: int = None
mlp_dp_size: int = None
outer_dp_size: int = None
nnodes: int = 1
node_rank: int = 0
ngpus_per_node: Optional[int] = None
devices: List[int] = None
session_len: Optional[int] = None
max_batch_size: int = None
cache_max_entry_count: float = 0.8
38 changes: 27 additions & 11 deletions lmdeploy/turbomind/turbomind.py
@@ -5,6 +5,7 @@
import copy
import json
import math
import os
import os.path as osp
import sys
from collections import defaultdict
@@ -84,6 +85,12 @@ def complete_parallel_config(cfg: TurbomindEngineConfig):


def update_parallel_config(cfg: TurbomindEngineConfig):
if cfg.nnodes > 1:
assert cfg.ngpus_per_node is not None or cfg.devices is not None
cfg.devices = cfg.devices or list(range(cfg.ngpus_per_node))
cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices)
cfg.device_num = cfg.device_num or len(cfg.devices) * cfg.nnodes

if not complete_parallel_config(cfg):
total = cfg.dp * cfg.tp
if not cfg.device_num:
@@ -105,6 +112,13 @@ def update_parallel_config(cfg: TurbomindEngineConfig):
assert cfg.attn_dp_size * cfg.attn_tp_size * cfg.outer_dp_size == cfg.device_num
cfg.devices = cfg.devices or list(range(cfg.device_num))

# update devices
if cfg.nnodes == 1:
cfg.devices = cfg.devices if cfg.devices else list(range(cfg.device_num))
cfg.ngpus_per_node = cfg.ngpus_per_node or len(cfg.devices)
# for simplicity, each node has dp
assert cfg.outer_dp_size * cfg.attn_dp_size % cfg.nnodes == 0
Review comment (Collaborator):
It is possible that the model does not fit into a single node
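
For illustration only, a hedged restatement of the multi-node bookkeeping introduced above, with made-up numbers: devices and the total device count are inferred from --ngpus-per-node, and the new assertion requires the data-parallel ranks to split evenly across nodes.

# Illustrative only: hypothetical values for a 2-node, 8-GPU-per-node setup.
nnodes, ngpus_per_node = 2, 8
devices = list(range(ngpus_per_node))      # [0 .. 7] on each node
device_num = len(devices) * nnodes         # 16 GPUs in total

# The added assertion: dp ranks must divide evenly across nodes.
outer_dp_size, attn_dp_size = 1, 2         # hypothetical split
assert (outer_dp_size * attn_dp_size) % nnodes == 0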



class TurboMind:
"""LMDeploy's inference engine.
@@ -141,8 +155,15 @@ def __init__(self,
f' greater than 0, but got {_engine_config.max_batch_size}'

update_parallel_config(_engine_config)

self.gpu_count = _engine_config.device_num
if _engine_config.nnodes > 1 and _engine_config.node_rank == 0:
from torch.distributed import TCPStore
master_addr = os.environ.get('LMDEPLOY_DP_MASTER_ADDR')
master_port = os.environ.get('LMDEPLOY_DP_MASTER_PORT')
assert master_addr is not None and master_port is not None, \
'LMDEPLOY_DP_MASTER_ADDR and LMDEPLOY_DP_MASTER_PORT should be set when using multi-node'
self.store = TCPStore(host_name=master_addr, port=int(master_port), is_master=True)

self.gpu_count = len(_engine_config.devices)
self.devices = _engine_config.devices

self.tokenizer = tokenizer
@@ -196,10 +217,8 @@ def _create_engine(self):
def _create_weight(self, model_comm):
"""Allocate weight buffer, load params if from_workspace."""

# TODO: support mpi
self.node_id = 0
self.node_num = 1
torch.cuda.synchronize()
engine_cfg = self.config_dict['engine_config']
self.node_id = engine_cfg['node_rank']

# create weight
def _create_weight_func(device_id):
@@ -394,6 +413,8 @@ def close(self):
del self._export_iter
if self.model_comm is not None:
self.model_comm = None
if hasattr(self, 'store'):
del self.store

def create_instance(self, cuda_stream_id=0):
"""Create a turbomind instance.
@@ -500,11 +521,6 @@ def __init__(self, tm_model: TurboMind, config: TurbomindModelConfig, cuda_strea
self.tm_model = tm_model
self.cuda_stream_id = cuda_stream_id

self.node_id = tm_model.node_id
self.gpu_count = tm_model.gpu_count

self.session_len = tm_model.session_len

# create model instances
lazy_init = self.tm_model.config_dict['engine_config'].get('empty_init', False)
self._model_inst = None if lazy_init else self._create_model_instance(0)
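For context, a minimal sketch of the TCPStore rendezvous implied by the constructor change above, assuming the two LMDEPLOY_DP_MASTER_* environment variables are exported on every node; the diff only shows node_rank == 0 creating the store, so the client side here is an illustrative assumption.

import os

from torch.distributed import TCPStore

# Both variables are required by the constructor above, e.g.
#   LMDEPLOY_DP_MASTER_ADDR=<address of node 0>, LMDEPLOY_DP_MASTER_PORT=<free port>
addr = os.environ['LMDEPLOY_DP_MASTER_ADDR']
port = int(os.environ['LMDEPLOY_DP_MASTER_PORT'])

# node_rank == 0 hosts the store, as in the diff ...
store = TCPStore(host_name=addr, port=port, is_master=True)
# ... while a non-master node would presumably connect as a client (assumption):
# client_store = TCPStore(host_name=addr, port=port, is_master=False)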
8 changes: 8 additions & 0 deletions src/turbomind/comm/CMakeLists.txt
@@ -20,6 +20,14 @@ if (BUILD_MULTI_GPU)
target_link_libraries(device_comm INTERFACE nccl_comm)
endif ()

add_subdirectory(gloo)
target_link_libraries(host_comm INTERFACE gloo_comm)

add_library(serialize STATIC serialize.cc)
target_link_libraries(serialize PRIVATE core)
set_property(TARGET serialize PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(host_comm INTERFACE serialize)

if (BUILD_TEST)
add_executable(test_comm test_comm.cu)
target_link_libraries(test_comm PRIVATE device_comm host_comm core pthread nvtx_utils)
29 changes: 29 additions & 0 deletions src/turbomind/comm/gloo/CMakeLists.txt
@@ -0,0 +1,29 @@
# Copyright (c) OpenMMLab. All rights reserved.
cmake_minimum_required(VERSION 3.8)

include(FetchContent)
FetchContent_Declare(
gloo
GIT_REPOSITORY https://github.com/pytorch/gloo.git
GIT_TAG c7b7b022c124d9643957d9bd55f57ac59fce8fa2 # pytorch-v2.8.0-rc4
)

# some settings of gloo,
set(GLOO_INSTALL OFF CACHE BOOL "" FORCE)
set(GLOO_STATIC_OR_SHARED STATIC CACHE STRING "" FORCE)
set(USE_NCCL OFF)
set(BUILD_TEST OFF)
FetchContent_MakeAvailable(gloo)

# gloo build doesn't add include directories as a target property...
target_include_directories(gloo PUBLIC
$<BUILD_INTERFACE:${gloo_SOURCE_DIR}>
$<BUILD_INTERFACE:${gloo_BINARY_DIR}> # config.h generated at cmake config time
)

add_library(gloo_comm STATIC
gloo_comm.cc
tcp_store.cc
)
set_property(TARGET gloo_comm PROPERTY POSITION_INDEPENDENT_CODE ON)
target_link_libraries(gloo_comm PUBLIC gloo logger)