86 changes: 86 additions & 0 deletions areal/api/cli_args.py
@@ -386,6 +386,90 @@ class PPOActorConfig(TrainEngineConfig):
)


@dataclass
class vLLMConfig:
"""Configuration for vLLM runtime."""

model: str = ""
seed: int = 1
skip_tokenizer_init: bool = False
enforce_eager: bool = True
dtype: str = "bfloat16"
    distributed_executor_backend: str = "mp"
# original
max_num_seqs: int = 256
# kv_cache_type: str = "auto"
block_size: int = 16
swap_space: int = 4
cpu_offload_gb: float = 0
max_seq_len_to_capture: int = 32768
disable_sliding_window: bool = True
    # NOTE: Defaults max_model_len to 32k because a larger value
    # will enable chunked prefill in vLLM, which will cause
    # evaluation performance degradation.
max_model_len: int | None = 32768
enable_chunked_prefill: bool = False
    # NOTE: Setting enable_prefix_caching to False because the cache
    # would reuse stale blocks after model weights are updated.
    # Calling reset_prefix_cache (available since vLLM v0.7.2)
    # would fix this issue.
enable_prefix_caching: bool = False
gpu_memory_utilization: float = 0.9
worker_extension_cls: str = (
"areal.thirdparty.vllm.vllm_worker_extension.VLLMWorkerExtension"
)
enable_sleep_mode: bool = False

@staticmethod
def build_args(
vllm_config: "vLLMConfig",
tp_size,
host,
port,
dist_init_addr: str | None = None,
):
args: Dict = conf_as_dict(vllm_config)
args = dict(
host=host,
port=port,
# Model and tokenizer
tokenizer=vllm_config.model,
load_format="auto",
trust_remote_code=True,
tensor_parallel_size=tp_size,
**args,
)
return args

@staticmethod
def build_cmd(
vllm_config: "vLLMConfig",
tp_size,
host,
port,
dist_init_addr: str | None = None,
):
args = vLLMConfig.build_args(
vllm_config=vllm_config,
tp_size=tp_size,
host=host,
port=port,
dist_init_addr=dist_init_addr,
)
# convert to flags
flags = []
for k, v in args.items():
if v is None or v is False or v == "":
continue
if v is True:
flags.append(f"--{k.replace('_','-')}")
elif isinstance(v, list):
flags.append(f"--{k.replace('_','-')} {' '.join(map(str, v))}")
else:
flags.append(f"--{k.replace('_','-')} {v}")
return f"python3 -m areal.thirdparty.vllm.areal_vllm_server {' '.join(flags)}"


@dataclass
class SGLangConfig:
"""Configuration for SGLang runtime. Refer to:
@@ -913,6 +997,7 @@ class BaseExperimentConfig:
default="",
metadata={"help": "Path to the tokenizer."},
)
weight_update_mode: str = field(default="disk")

train_dataset: DatasetConfig = field(default_factory=DatasetConfig)
valid_dataset: DatasetConfig | None = field(default=None)
@@ -923,6 +1008,7 @@ class BaseExperimentConfig:
recover: RecoverConfig = field(default_factory=RecoverConfig)

sglang: SGLangConfig = field(default_factory=SGLangConfig)
vllm: vLLMConfig = field(default_factory=vLLMConfig)
launcher: LauncherConfig = field(default_factory=LauncherConfig)

scheduler: SchedulerConfig = field(default_factory=SchedulerConfig)
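For context, a minimal usage sketch of how `vLLMConfig.build_cmd` flattens the config into server flags; the model name and values below are illustrative, and `conf_as_dict` is assumed to return the dataclass fields as a plain dict:

```python
# Illustrative only: assumes vLLMConfig from areal/api/cli_args.py as added above.
from areal.api.cli_args import vLLMConfig

cfg = vLLMConfig(model="Qwen/Qwen2.5-7B-Instruct", gpu_memory_utilization=0.85)
cmd = vLLMConfig.build_cmd(cfg, tp_size=2, host="127.0.0.1", port=8000)
# True booleans become bare flags (e.g. --enforce-eager --disable-sliding-window),
# None/False/empty values are dropped, and everything else becomes `--key value`,
# so cmd looks like:
# python3 -m areal.thirdparty.vllm.areal_vllm_server --host 127.0.0.1 --port 8000 \
#     --tokenizer Qwen/Qwen2.5-7B-Instruct --tensor-parallel-size 2 ...
```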
2 changes: 1 addition & 1 deletion areal/api/io_struct.py
@@ -123,7 +123,7 @@ def from_disk(
)

@classmethod
def from_fsdp_nccl(
def from_fsdp_xccl(
cls,
allocation_mode: AllocationMode,
fsdp_engine: "TrainEngine",
12 changes: 7 additions & 5 deletions areal/api/reward_api.py
@@ -115,13 +115,15 @@ async def __call__(self, *args, **kwargs) -> float:

loop = asyncio.get_event_loop()
try:
return await asyncio.wait_for(
loop.run_in_executor(
executor,
partial(self.reward_fn, *args, **kwargs),
),
future = loop.run_in_executor(
executor,
partial(self.reward_fn, *args, **kwargs),
)
reward = await asyncio.wait_for(
future,
timeout=self.timeout_seconds,
)
return reward
except asyncio.TimeoutError:
logger.warning(
f"Computing reward timeout after {self.timeout_seconds}s. Set reward to 0."
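The reshaped try block above is the standard pattern of bounding a blocking callable with an executor future plus `asyncio.wait_for`; a self-contained sketch of that pattern (names are illustrative, not the PR's):

```python
import asyncio
import time
from functools import partial

def slow_reward(answer: str) -> float:
    time.sleep(2)  # stand-in for an expensive, blocking reward computation
    return 1.0 if answer else 0.0

async def bounded_reward(answer: str, timeout_seconds: float = 1.0) -> float:
    loop = asyncio.get_running_loop()
    # Run the blocking function in the default thread-pool executor.
    future = loop.run_in_executor(None, partial(slow_reward, answer))
    try:
        return await asyncio.wait_for(future, timeout=timeout_seconds)
    except asyncio.TimeoutError:
        return 0.0  # mirrors the "set reward to 0 on timeout" behavior above

print(asyncio.run(bounded_reward("ok")))  # 0.0: the 2 s call exceeds the 1 s timeout
```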
15 changes: 9 additions & 6 deletions areal/engine/base_hf_engine.py
@@ -55,7 +55,7 @@ def __init__(self, config: TrainEngineConfig):
self.model: torch.nn.Module
self.optimizer: torch.optim.Optimizer
self.tokenizer: PreTrainedTokenizerFast
self.processor: ProcessorMixin | None
self.processor: ProcessorMixin | None = None
# huggingface model config
self.model_config: PretrainedConfig
self._version: int = 0
@@ -116,15 +116,17 @@ def parallelism_group(self) -> dist.ProcessGroup:
return _get_default_group()

def create_process_group(self, parallel_strategy: ParallelStrategy | None = None):
# Required by NCCL weight update group for SGLang
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
backend = current_platform.communication_backend
if current_platform.communication_backend == "nccl":
# Required by NCCL weight update group for SGLang
os.environ["NCCL_CUMEM_ENABLE"] = "0"
os.environ["NCCL_NVLS_ENABLE"] = "0"
if not dist.is_initialized():
# TODO: Handle the condition when WORLD_SIZE and RANK is not set in launcher
# NOTE: device_id **SHOULD NOT** be passed into init_process_group,
# otherwise initializing the NCCL weight update group will be wrong!
dist.init_process_group(
backend=current_platform.communication_backend,
backend=backend,
timeout=NCCL_DEFAULT_TIMEOUT,
)
self.own_global_group = True
@@ -155,7 +157,8 @@ def create_device_model(self):
)

tik = time.perf_counter()
with torch.device(current_platform.device_type):
device = current_platform.device_type
with torch.device(device):
model = AutoModelForImageTextToText.from_pretrained(
pretrained_model_name_or_path=self.config.path,
trust_remote_code=True,
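A reduced sketch of the backend-gated setup this hunk introduces; `current_platform.communication_backend` is AReaL's helper, and the CUDA check below is a simplified stand-in for it:

```python
import os
import torch
import torch.distributed as dist

def init_backend_aware_process_group() -> str:
    # Simplified stand-in for current_platform.communication_backend.
    backend = "nccl" if torch.cuda.is_available() else "gloo"
    if backend == "nccl":
        # Only needed for the NCCL weight-update group used with SGLang;
        # other backends should not inherit these NCCL-specific settings.
        os.environ["NCCL_CUMEM_ENABLE"] = "0"
        os.environ["NCCL_NVLS_ENABLE"] = "0"
    if not dist.is_initialized():
        dist.init_process_group(backend=backend)
    return backend
```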
6 changes: 4 additions & 2 deletions areal/engine/fsdp_engine.py
@@ -225,7 +225,7 @@ def _load_model_from_hf(self, path: str):
)

def upload_weights(self, meta: WeightUpdateMeta):
if meta.type == "nccl":
if meta.type == current_platform.communication_backend:
if not self.weight_update_group_initialized:
self._init_distributed_weight_update(meta)
self._update_weights_from_distributed(meta.nccl_param_specs)
@@ -278,7 +278,9 @@ def _update_weights_from_distributed(
tensor = param.data
if dist.get_rank() == 0:
self.logger.debug(f"Broadcasting {name} with shape {tensor.shape}")
dist.broadcast(tensor, src=0, group=self.weight_update_group)
dist.broadcast(
tensor, src=0, group=self.weight_update_group, async_op=False
)
del tensor
dist.barrier(device_ids=[self.device.index])
current_platform.synchronize()
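For reference, the distributed weight push boils down to a rank-0 broadcast over every parameter; a toy sketch assuming all ranks already hold same-shaped parameters and share a weight-update group:

```python
import torch
import torch.distributed as dist

def broadcast_weights(model: torch.nn.Module, group: dist.ProcessGroup | None = None):
    # Rank 0 holds the freshly trained weights; every other rank receives them in place.
    for param in model.parameters():
        tensor = param.data
        # async_op=False (the default) blocks until the broadcast completes,
        # so each tensor is fully updated before the next one is sent.
        dist.broadcast(tensor, src=0, group=group, async_op=False)
    dist.barrier(group=group)
```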
2 changes: 1 addition & 1 deletion areal/engine/sglang_remote.py
@@ -267,7 +267,7 @@ def update_weights(self, meta: WeightUpdateMeta):
res.raise_for_status()
tik = time.perf_counter()
fut = Future()
if meta.type == "nccl":
if meta.type == current_platform.communication_backend:
fut = self.executor.submit(
update_weights_from_distributed,
meta,