6 changes: 5 additions & 1 deletion .github/workflows/vllm.yml
@@ -84,4 +84,8 @@ jobs:
cd tests/generation
export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet"
MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh
rm -rf "${OUTPUT_PATH}"
rm -rf "${OUTPUT_PATH}"
- name: Running multi-turn rollout tests on 8 L20 GPUs
run: |
pip3 install --upgrade vllm==0.8.3 tensordict==0.7.2
python3 tests/rollout/test_vllm_multi_turn.py
10 changes: 10 additions & 0 deletions examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
@@ -3,10 +3,18 @@ set -x
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS

# For async rollout mode, dataset should return raw chat.
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
return_raw_chat="True"
chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.return_raw_chat=$return_raw_chat \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
@@ -27,6 +35,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=$rollout_mode \
actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
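For reference, here is a minimal sketch (not part of this diff) of how a dotted chat_scheduler path such as examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler could be resolved into a class with importlib; verl's actual loading code in the async rollout server may differ, and the helper name below is hypothetical.

import importlib

def load_chat_scheduler(dotted_path: str):
    # Split "pkg.module.ClassName" into module path and class name,
    # import the module, and return the class object.
    module_path, class_name = dotted_path.rsplit(".", 1)
    return getattr(importlib.import_module(module_path), class_name)

# scheduler_cls = load_chat_scheduler(
#     "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler")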
151 changes: 151 additions & 0 deletions examples/ppo_trainer/naive_chat_scheduler.py
@@ -0,0 +1,151 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Dict, List

import torch
from omegaconf import DictConfig
from openai.types.chat.chat_completion import ChatCompletion
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.workers.rollout.async_server import ChatCompletionScheduler


class NaiveChatCompletionScheduler(ChatCompletionScheduler):
"""
A very naive implementation of ChatCompletionScheduler for demo purposes;
it only does single-turn chat completion.
"""

def __init__(
self,
config: DictConfig,
model_path: str,
server_addresses: List[str],
max_cache_size: int = 10000,
):
super().__init__(config, model_path, server_addresses, max_cache_size)

async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
kwargs = dict(
n=self.config.n,
max_completion_tokens=self.config.response_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
)

do_sample = batch.meta_info.get("do_sample", True)
is_validate = batch.meta_info.get("validate", False)
if not do_sample or is_validate:
kwargs["n"] = 1
kwargs["temperature"] = 0

kwargs.update(sampling_params)
print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
conversation, batch_conversations, batch_index = (
info["conversation"],
info["batch_conversations"],
info["batch_index"],
)

conversations = []
for choice in completions.choices:
chat = conversation.copy()
chat.append({"role": choice.message.role, "content": choice.message.content})
conversations.append(chat)
batch_conversations[batch_index] = conversations

# NOTE: we can call tools and resubmit chat completions here.
# call_tools(completions, info)
# await self.submit_chat_completions(callback2, ...)

tasks, batch_conversations = [], [None] * len(batch)
for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]):
# raw_prompt: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
tasks.append(
asyncio.create_task(
self.submit_chat_completions(
callback=callback,
callback_additional_info={
"batch_conversations": batch_conversations,
"batch_index": batch_index,
"conversation": list(conversation),
},
model=self.model_name,
messages=conversation,
**kwargs,
)
)
)
await asyncio.gather(*tasks)
print("[NaiveChatCompletionScheduler] generate_sequences done")

return self._postprocess(batch, batch_conversations, kwargs["n"])

def _postprocess(
self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int
) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
# input_ids: prompt + response
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

# prompts: [prompt] from input dataset
prompts = [
self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
for prompt in batch.non_tensor_batch["raw_prompt"]
]

# flatten batch_conversations if n > 1
assert len(batch_conversations) == len(prompts)
batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations]
assert len(batch_conversations) == len(prompts) * n

# sequences: [prompt + response]
sequences = [
self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
for conversation in batch_conversations
]

# responses: [response]
# TODO: mask out tools calling tokens?
responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]

prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
if n > 1:
prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)

input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

batch = TensorDict(
{
"prompts": prompts["input_ids"],
"responses": responses["input_ids"],
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
},
batch_size=len(input_ids),
)

return DataProto(batch=batch)
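The layout comments in _postprocess can be checked with a tiny toy example (illustrative only, not verl code): prompts are left-padded, responses right-padded, and position_ids restart at 0 on the first real prompt token.

import torch

prompt_mask = torch.tensor([[0, 0, 1, 1]])    # left-padded prompt, 2 real tokens
response_mask = torch.tensor([[1, 1, 1, 0]])  # right-padded response, 3 real tokens
attention_mask = torch.cat([prompt_mask, response_mask], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask
print(attention_mask)  # tensor([[0, 0, 1, 1, 1, 1, 1, 0]])
print(position_ids)    # tensor([[0, 0, 0, 1, 2, 3, 4, 0]])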
10 changes: 10 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
@@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

# For async rollout mode, dataset should return raw chat.
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
return_raw_chat="True"
chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=$return_raw_chat \
data.train_batch_size=4096 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
@@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=$rollout_mode \
actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
critic.optim.lr=1e-5 \
2 changes: 1 addition & 1 deletion recipe/dapo/src/dapo_ray_trainer.py
@@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer):
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""

def fit(self):
async def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
5 changes: 3 additions & 2 deletions recipe/dapo/src/main_dapo.py
@@ -76,7 +76,8 @@ def run_ppo(config) -> None:

@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):

async def run(self, config):
# print initial config
from pprint import pprint

@@ -201,7 +202,7 @@ def run(self, config):
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
await trainer.fit()


if __name__ == "__main__":
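Making TaskRunner.run a coroutine relies on Ray's support for async actor methods: the actor executes the coroutine on its own event loop while the caller still blocks on ray.get. A minimal sketch under that assumption (names below are hypothetical, not from this PR):

import asyncio
import ray

@ray.remote
class AsyncRunner:
    async def run(self, x: int) -> int:
        await asyncio.sleep(0.1)  # stand-in for awaiting trainer.fit()
        return x * 2

ray.init(ignore_reinit_error=True)
runner = AsyncRunner.remote()
print(ray.get(runner.run.remote(21)))  # 42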
8 changes: 7 additions & 1 deletion recipe/prime/main_prime.py
@@ -29,6 +29,8 @@
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

import asyncio

import hydra
import ray

@@ -53,6 +55,10 @@ def run_prime(config, compute_score=None):

@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
asyncio.run(_main_task(config, compute_score))


async def _main_task(config, compute_score=None):
# print initial config
from pprint import pprint

@@ -142,7 +148,7 @@ def main_task(config, compute_score=None):
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
await trainer.fit()


if __name__ == "__main__":
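main_prime keeps main_task as a plain @ray.remote function, so no event loop is running inside it; wrapping the async body with asyncio.run is what lets it await the now-async trainer.fit(). The same pattern in isolation (hypothetical names, illustrative only):

import asyncio
import ray

async def _train(steps: int) -> int:
    await asyncio.sleep(0)  # stand-in for awaiting an async fit()
    return steps

@ray.remote(num_cpus=1)
def train_task(steps: int) -> int:
    # A synchronous remote function drives the coroutine itself.
    return asyncio.run(_train(steps))

ray.init(ignore_reinit_error=True)
print(ray.get(train_task.remote(3)))  # 3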
2 changes: 1 addition & 1 deletion recipe/prime/prime_ray_trainer.py
@@ -331,7 +331,7 @@ def _load_checkpoint(self):
if isinstance(self.train_dataloader.dataset, RLHFDataset):
self.train_dataloader.dataset.resume_dataset_state()

def fit(self):
async def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
4 changes: 3 additions & 1 deletion tests/ray/test_worker_group_basics.py
@@ -68,7 +68,9 @@ def foo_custom(self, x, y):
@ray.remote(num_gpus=0.1)
def remote_call_wg(worker_names):
class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args)
worker_group = RayWorkerGroup.from_detached(
worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None
)
print(worker_group.worker_names)

output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])