Commit da4222b

feat: introduce vLLM AsyncLLM to support multi-turn rollout

1 parent 8719371 commit da4222b

File tree

20 files changed: +871, -50 lines
examples/ppo_trainer/naive_chat_scheduler.py

Lines changed: 72 additions & 0 deletions
@@ -0,0 +1,72 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Dict, List

from omegaconf import DictConfig
from openai.types.chat.chat_completion import ChatCompletion

from verl.protocol import DataProto
from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler


class NaiveChatCompletionScheduler(ChatCompletionScheduler):

    def __init__(self, config: DictConfig, model_path: str, server_addresses: List[str], max_cache_size: int = 10000):
        super().__init__(config, model_path, server_addresses, max_cache_size)

    async def generate_sequences(self, prompts: DataProto, **sampling_params) -> DataProto:
        kwargs = dict(
            n=self.config.n,
            max_completion_tokens=self.config.response_length,
            temperature=self.config.temperature,
            top_p=self.config.top_p,
        )

        do_sample = prompts.meta_info.get('do_sample', True)
        is_validate = prompts.meta_info.get('validate', False)
        if not do_sample or is_validate:
            kwargs["n"] = 1
            kwargs["temperature"] = 0

        kwargs.update(sampling_params)
        print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

        async def callback(completions: ChatCompletion, info: Dict[str, Any]):
            info["all_completions"][info["index"]] = completions

            # NOTE: we can call tools and resubmit chat completions here.
            # call_tools(completions, info)
            # await self.submit_chat_completions(callback2, ...)

        tasks, all_completions = [], [None] * len(prompts)
        for i, prompt in enumerate(prompts.non_tensor_batch["raw_prompt"]):
            # raw_prompt: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
            tasks.append(
                asyncio.create_task(
                    self.submit_chat_completions(
                        callback=callback,
                        callback_additional_info={
                            "all_completions": all_completions,
                            "index": i
                        },
                        model=self.model_name,
                        messages=prompt,
                        **kwargs,
                    )))
        await asyncio.gather(*tasks)

        print("[NaiveChatCompletionScheduler] generate_sequences done")
        # TODO: completions => DataProto
        return all_completions
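
The NOTE inside the callback marks the extension point for tool use. Below is a minimal sketch of what a tool-calling variant of that callback could look like, written as if nested inside generate_sequences so that self and kwargs are in scope; call_tools, the info["messages"] key, and the tool-result message format are hypothetical, and only submit_chat_completions comes from the scheduler API above.

        async def tool_callback(completions: ChatCompletion, info: Dict[str, Any]):
            message = completions.choices[0].message
            if getattr(message, "tool_calls", None):
                # Append the assistant turn plus (hypothetical) tool results,
                # then resubmit the extended conversation for another round.
                info["messages"].append({"role": message.role, "content": message.content})
                info["messages"].extend(call_tools(message.tool_calls))  # hypothetical helper
                await self.submit_chat_completions(
                    callback=tool_callback,
                    callback_additional_info=info,
                    model=self.model_name,
                    messages=info["messages"],
                    **kwargs,
                )
            else:
                info["all_completions"][info["index"]] = completions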

examples/ppo_trainer/run_qwen2-7b_seq_balance.sh

Lines changed: 10 additions & 0 deletions
@@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet
 train_files="['$gsm8k_train_path', '$math_train_path']"
 test_files="['$gsm8k_test_path', '$math_test_path']"

+# For async rollout mode, dataset should return raw chat.
+rollout_mode="sync"
+if [ "$rollout_mode" = "async" ]; then
+    return_raw_chat="True"
+    chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
+fi
+
 python3 -m verl.trainer.main_ppo \
     algorithm.adv_estimator=gae \
     data.train_files="$train_files" \
     data.val_files="$test_files" \
+    data.return_raw_chat=$return_raw_chat \
     data.train_batch_size=4096 \
     data.max_prompt_length=4096 \
     data.max_response_length=4096 \
@@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \
     actor_rollout_ref.actor.use_kl_loss=False \
     actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
     actor_rollout_ref.rollout.name=vllm \
+    actor_rollout_ref.rollout.mode=$rollout_mode \
+    actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
     actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
     actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
     critic.optim.lr=1e-5 \
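
The chat_scheduler value is a dotted "module.ClassName" path. A plausible sketch of how such a path resolves to a class on the trainer side (the actual loader lives elsewhere in verl; the function name here is illustrative):

import importlib

def resolve_chat_scheduler(dotted_path: str):
    # "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler"
    # -> import the module, then fetch the class attribute.
    module_name, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_name), class_name)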

recipe/dapo/src/dapo_ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer):
     Note that this trainer runs on the driver process on a single CPU/GPU node.
     """

-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.

recipe/dapo/src/main_dapo.py

Lines changed: 3 additions & 2 deletions
@@ -75,7 +75,8 @@ def run_ppo(config) -> None:

 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 class TaskRunner:
-    def run(self, config):
+
+    async def run(self, config):
         # print initial config
         from pprint import pprint

@@ -200,7 +201,7 @@ def run(self, config):
             val_reward_fn=val_reward_fn,
         )
         trainer.init_workers()
-        trainer.fit()
+        await trainer.fit()


 if __name__ == "__main__":
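
Making TaskRunner.run a coroutine works because Ray executes async actor methods on the actor's asyncio event loop; the remote call still returns an ObjectRef. A hedged sketch of the driver-side invocation (the exact call site is outside this hunk, and config is obtained via hydra as elsewhere in main_dapo):

import ray

runner = TaskRunner.remote()
# ray.get blocks until the coroutine scheduled on the actor's event loop completes.
ray.get(runner.run.remote(config))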

recipe/prime/main_prime.py

Lines changed: 7 additions & 1 deletion
@@ -29,6 +29,8 @@
 Note that we don't combine the main with ray_trainer as ray_trainer is used by other mains.
 """

+import asyncio
+
 import hydra
 import ray

@@ -52,6 +54,10 @@ def run_prime(config, compute_score=None):

 @ray.remote(num_cpus=1)  # please make sure main_task is not scheduled on head
 def main_task(config, compute_score=None):
+    asyncio.run(_main_task(config, compute_score))
+
+
+async def _main_task(config, compute_score=None):
     # print initial config
     from pprint import pprint

@@ -141,7 +147,7 @@ def main_task(config, compute_score=None):
         val_reward_fn=val_reward_fn,
     )
     trainer.init_workers()
-    trainer.fit()
+    await trainer.fit()


 if __name__ == "__main__":

recipe/prime/prime_ray_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def _load_checkpoint(self):
         if isinstance(self.train_dataloader.dataset, RLHFDataset):
             self.train_dataloader.dataset.resume_dataset_state()

-    def fit(self):
+    async def fit(self):
         """
         The training loop of PPO.
         The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.

tests/rollout/test_vllm_multi_turn.py

Lines changed: 149 additions & 0 deletions
@@ -0,0 +1,149 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from typing import Any, Dict

import ray
from omegaconf import OmegaConf
from openai.types.chat.chat_completion import ChatCompletion

from verl.single_controller.ray import RayClassWithInitArgs, RayWorkerGroup
from verl.single_controller.ray.base import Worker, create_colocated_worker_cls
from verl.trainer.ppo.ray_trainer import ResourcePoolManager, Role
from verl.workers.fsdp_async_workers import AsyncActorRolloutRefWorker, AsyncLLMManager
from verl.workers.rollout.chat_scheduler import ChatCompletionScheduler


async def test_vllm_multi_turn():
    config = OmegaConf.load("verl/trainer/config/ppo_trainer.yaml")
    model_path = "Qwen/Qwen2-7B-Instruct"
    model_name = "/".join(model_path.split("/")[-2:])
    config.actor_rollout_ref.model.path = model_path
    config.actor_rollout_ref.rollout.mode = "async"
    config.actor_rollout_ref.rollout.prompt_length = 4096
    config.actor_rollout_ref.rollout.response_length = 4096

    # =========================== 1. Create hybrid ActorRollout workers ===========================
    ray.init(
        runtime_env={
            'env_vars': {
                'TOKENIZERS_PARALLELISM': 'true',
                'NCCL_DEBUG': 'WARN',
                'VLLM_LOGGING_LEVEL': 'WARN',
                'VLLM_USE_V1': '1',
            }
        })
    role_worker_mapping = {
        Role.ActorRollout: ray.remote(AsyncActorRolloutRefWorker),
    }
    global_pool_id = 'global_pool'
    resource_pool_spec = {
        global_pool_id: [config.trainer.n_gpus_per_node] * config.trainer.nnodes,
    }
    mapping = {
        Role.ActorRollout: global_pool_id,
    }
    resource_pool_manager = ResourcePoolManager(resource_pool_spec=resource_pool_spec, mapping=mapping)
    resource_pool_manager.create_resource_pool()
    resource_pool_to_cls = {pool: {} for pool in resource_pool_manager.resource_pool_dict.values()}

    # create actor and rollout
    resource_pool = resource_pool_manager.get_resource_pool(Role.ActorRollout)
    actor_rollout_cls = RayClassWithInitArgs(cls=role_worker_mapping[Role.ActorRollout],
                                             config=config.actor_rollout_ref,
                                             role='actor_rollout')
    resource_pool_to_cls[resource_pool]['actor_rollout'] = actor_rollout_cls

    all_wg = {}
    wg_dicts = []
    for resource_pool, class_dict in resource_pool_to_cls.items():
        worker_dict_cls = create_colocated_worker_cls(class_dict=class_dict, worker_cls=Worker)
        wg_dict = RayWorkerGroup(resource_pool=resource_pool, ray_cls_with_init=worker_dict_cls)
        spawn_wg = wg_dict.spawn(prefix_set=class_dict.keys())
        all_wg.update(spawn_wg)
        wg_dicts.append(wg_dict)
    actor_rollout_wg = all_wg['actor_rollout']
    actor_rollout_wg.init_model()

    # =========================== 2. Create AsyncLLMManager & ChatScheduler ===========================
    async_rollout_manager = AsyncLLMManager(
        config=config.actor_rollout_ref,
        worker_group=actor_rollout_wg,
    )

    async_chat_scheduler = ChatCompletionScheduler(
        config=config.actor_rollout_ref.rollout,
        model_path=config.actor_rollout_ref.model.path,
        server_addresses=async_rollout_manager.server_addresses,
    )

    # =========================== 3. Multi-turn rollout ===========================
    async def callback(completions: ChatCompletion, info: Dict[str, Any]):
        messages, round = info["messages"], info["round"]
        message = completions.choices[0].message
        messages.append({"role": message.role, "content": message.content})
        print(f"[round={round}] role: {message.role}, content: {message.content}")

        extra_headers = {"x-request-id": completions.id}
        if round == 0:
            messages.append({"role": "user", "content": "What is your name?"})
            await async_chat_scheduler.submit_chat_completions(
                callback=callback,
                callback_additional_info={
                    "messages": messages,
                    "round": 1
                },
                model=model_name,
                messages=messages,
                extra_headers=extra_headers,
            )
        elif round == 1:
            messages.append({"role": "user", "content": "What is your favorite color?"})
            await async_chat_scheduler.submit_chat_completions(
                callback=callback,
                callback_additional_info={
                    "messages": messages,
                    "round": 2
                },
                model=model_name,
                messages=messages,
                extra_headers=extra_headers,
            )
        else:
            print("Done!")

    messages = [{
        "role": "user",
        "content": "Let's play a role playing game. Your name is Bob, your favorite color is red."
    }]
    await async_chat_scheduler.submit_chat_completions(
        callback=callback,
        callback_additional_info={
            "messages": messages,
            "round": 0
        },
        model=model_name,
        messages=messages,
    )
    assert len(messages) == 6
    for round, message in enumerate(messages):
        if round % 2 == 0:
            assert message["role"] == "user"
        else:
            assert message["role"] == "assistant"


if __name__ == "__main__":
    asyncio.run(test_vllm_multi_turn())
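
One detail worth calling out: the callback forwards extra_headers={"x-request-id": completions.id} on every follow-up turn. A plausible reading, not stated anywhere in this diff, is that the scheduler uses this id to route all turns of one conversation to the same vLLM server so its prefix cache can be reused. A hedged sketch of such a routing policy (pick_server and the hash-based scheme are assumptions, not verl code; only server_addresses appears above):

def pick_server(server_addresses: list, request_id: str) -> str:
    # Deterministically map a conversation's request id to one server, so
    # successive turns hit the same instance and share KV-cache prefixes.
    return server_addresses[hash(request_id) % len(server_addresses)]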

verl/single_controller/base/decorator.py

Lines changed: 11 additions & 0 deletions
@@ -37,6 +37,9 @@ class Dispatch(Enum):
     DP_COMPUTE_PROTO_WITH_FUNC = 10
     DP_COMPUTE_METRIC = 11

+    # This is a special dispatch mode for vllm ExternalRayDistributedExecutor
+    DIRECT_ROLLOUT_METHOD = 12
+

 class Execute(Enum):
     ALL = 0
@@ -65,6 +68,10 @@ def dispatch_one_to_all(worker_group, *args, **kwargs):
     return args, kwargs


+def dummy_direct_rollout_call(worker_group, *args, **kwargs):
+    raise NotImplementedError("Direct rollout call is forbidden.")
+
+
 def dispatch_all_to_all(worker_group, *args, **kwargs):
     return args, kwargs

@@ -356,6 +363,10 @@ def get_predefined_dispatch_fn(dispatch_mode):
             "collect_fn": collect_dp_compute_data_proto,
         },
         Dispatch.DP_COMPUTE_METRIC: {"dispatch_fn": dispatch_dp_compute_data_proto, "collect_fn": collect_dp_compute},
+        Dispatch.DIRECT_ROLLOUT_METHOD: {
+            "dispatch_fn": dummy_direct_rollout_call,
+            "collect_fn": dummy_direct_rollout_call,
+        },
     }
     return predefined_dispatch_mode_fn[dispatch_mode]
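
For context, a hedged sketch of how a worker method would opt into the new mode via verl's register decorator; the worker class and method body here are illustrative, and only Dispatch.DIRECT_ROLLOUT_METHOD and register come from this module:

from verl.single_controller.base.decorator import Dispatch, register
from verl.single_controller.base.worker import Worker

class MyAsyncRolloutWorker(Worker):  # hypothetical worker, not from this diff
    # A method tagged DIRECT_ROLLOUT_METHOD is meant to be invoked directly on
    # a single worker (e.g. by vLLM's ExternalRayDistributedExecutor); routing
    # it through the worker group's dispatch machinery raises NotImplementedError.
    @register(dispatch_mode=Dispatch.DIRECT_ROLLOUT_METHOD)
    def execute_method(self, method_name: str, *args, **kwargs):
        return getattr(self, method_name)(*args, **kwargs)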

verl/single_controller/base/register_center/ray.py

Lines changed: 10 additions & 0 deletions
@@ -12,17 +12,27 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from typing import Dict, Tuple
+
 import ray


 @ray.remote
 class WorkerGroupRegisterCenter:
     def __init__(self, rank_zero_info):
         self.rank_zero_info = rank_zero_info
+        # rank -> node_id
+        self.workers_info: Dict[int, str] = {}

     def get_rank_zero_info(self):
         return self.rank_zero_info

+    def set_worker_info(self, rank, node_id) -> None:
+        self.workers_info[rank] = node_id
+
+    def get_worker_info(self) -> Dict[int, str]:
+        return self.workers_info
+

 def create_worker_group_register_center(name, info):
     return WorkerGroupRegisterCenter.options(name=name).remote(info)
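
The rank -> node_id map feeds the "node affinity scheduling" mentioned in worker.py below: a consumer can place a Ray actor on the same node as a given worker. A hedged sketch of such a consumer (the actor names and placement target are illustrative; NodeAffinitySchedulingStrategy is real Ray API):

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

register_center = ray.get_actor("my_worker_group_register_center")  # name is illustrative
workers_info = ray.get(register_center.get_worker_info.remote())
# Pin a new actor to the node hosting worker rank 0.
strategy = NodeAffinitySchedulingStrategy(node_id=workers_info[0], soft=False)
server = MyServerActor.options(scheduling_strategy=strategy).remote()  # hypothetical actor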

verl/single_controller/base/worker.py

Lines changed: 7 additions & 0 deletions
@@ -19,6 +19,8 @@
 import socket
 from dataclasses import dataclass

+import ray
+
 from .decorator import Dispatch, Execute, register


@@ -125,6 +127,11 @@ def _configure_before_init(self, register_center_name: str, rank: int):
             )

             os.environ.update(rank_zero_info)
+        else:
+            self.register_center = ray.get_actor(register_center_name)
+
+        # set worker info for node affinity scheduling
+        ray.get(self.register_center.set_worker_info.remote(rank, ray.get_runtime_context().get_node_id()))

     def __init__(self, cuda_visible_devices=None) -> None:
         # construct a meta from environment variable. Note that the import must be inside the class because it is executed remotely
