6 changes: 5 additions & 1 deletion .github/workflows/vllm.yml
@@ -84,4 +84,8 @@ jobs:
cd tests/generation
export OUTPUT_PATH="${HOME}/data/gen/qwen_05_gen_test.parquet"
MODEL_ID=Qwen/Qwen2.5-0.5B-Instruct NGPUS_PER_NODE=1 GEN_TP=1 bash ./run_gen_qwen05.sh
rm -rf "${OUTPUT_PATH}"
rm -rf "${OUTPUT_PATH}"
- name: Running multi-turn rollout tests on 8 L20 GPUs
run: |
pip3 install --upgrade vllm==0.8.3
python3 tests/rollout/test_vllm_multi_turn.py
10 changes: 10 additions & 0 deletions examples/grpo_trainer/run_qwen2-7b_seq_balance.sh
@@ -3,10 +3,18 @@ set -x
# If you are using vllm<=0.6.3, you might need to set the following environment variable to avoid bugs:
# export VLLM_ATTENTION_BACKEND=XFORMERS

# For async rollout mode, the dataset should return the raw chat.
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
return_raw_chat="True"
chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=grpo \
data.train_files=$HOME/data/gsm8k/train.parquet \
data.val_files=$HOME/data/gsm8k/test.parquet \
data.return_raw_chat=$return_raw_chat \
data.train_batch_size=1024 \
data.max_prompt_length=512 \
data.max_response_length=1024 \
@@ -27,6 +35,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.fsdp_config.optimizer_offload=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=$rollout_mode \
actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
actor_rollout_ref.rollout.gpu_memory_utilization=0.6 \
actor_rollout_ref.rollout.n=5 \
actor_rollout_ref.ref.fsdp_config.param_offload=True \
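The `actor_rollout_ref.rollout.chat_scheduler` option added above takes a dotted class path. The diff does not show how verl resolves that string, but a minimal sketch of the usual dynamic-import pattern (an assumption, not verl's actual async-server code) looks like this:

```python
# Hedged sketch: resolve a dotted chat_scheduler path into a class.
# The real resolution logic lives in verl's async rollout server and may differ.
import importlib


def load_chat_scheduler(dotted_path: str):
    """Split 'pkg.module.ClassName' into module and class name, then import it."""
    module_name, _, class_name = dotted_path.rpartition(".")
    module = importlib.import_module(module_name)
    return getattr(module, class_name)


# Example with the scheduler added in this PR (requires the repo root on PYTHONPATH).
SchedulerCls = load_chat_scheduler(
    "examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler"
)
```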
151 changes: 151 additions & 0 deletions examples/ppo_trainer/naive_chat_scheduler.py
@@ -0,0 +1,151 @@
# Copyright 2024 Bytedance Ltd. and/or its affiliates
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import asyncio
from typing import Any, Dict, List

import torch
from omegaconf import DictConfig
from openai.types.chat.chat_completion import ChatCompletion
from tensordict import TensorDict

from verl.protocol import DataProto
from verl.workers.rollout.async_server import ChatCompletionScheduler


class NaiveChatCompletionScheduler(ChatCompletionScheduler):
"""
A very naive implementation of ChatCompletionScheduler for demo purposes;
it only does single-turn chat completion.
"""

def __init__(
self,
config: DictConfig,
model_path: str,
server_addresses: List[str],
max_cache_size: int = 10000,
):
super().__init__(config, model_path, server_addresses, max_cache_size)

async def generate_sequences(self, batch: DataProto, **sampling_params) -> DataProto:
kwargs = dict(
n=self.config.n,
max_completion_tokens=self.config.response_length,
temperature=self.config.temperature,
top_p=self.config.top_p,
)

do_sample = batch.meta_info.get("do_sample", True)
is_validate = batch.meta_info.get("validate", False)
if not do_sample or is_validate:
kwargs["n"] = 1
kwargs["temperature"] = 0

kwargs.update(sampling_params)
print(f"[NaiveChatCompletionScheduler] generate_sequences sampling params: {kwargs}")

async def callback(completions: ChatCompletion, info: Dict[str, Any], exception: Exception):
conversation, batch_conversations, batch_index = (
info["conversation"],
info["batch_conversations"],
info["batch_index"],
)

conversations = []
for choice in completions.choices:
chat = conversation.copy()
chat.append({"role": choice.message.role, "content": choice.message.content})
conversations.append(chat)
batch_conversations[batch_index] = conversations

# NOTE: we can call tools and resubmit chat completions here.
# call_tools(completions, info)
# await self.submit_chat_completions(callback2, ...)

tasks, batch_conversations = [], [None] * len(batch)
for batch_index, conversation in enumerate(batch.non_tensor_batch["raw_prompt"]):
# raw_prompt: [{"role": "user", "content": "..."}, {"role": "assistant", "content": "..."}, ...]
tasks.append(
asyncio.create_task(
self.submit_chat_completions(
callback=callback,
callback_additional_info={
"batch_conversations": batch_conversations,
"batch_index": batch_index,
"conversation": list(conversation),
},
model=self.model_name,
messages=conversation,
**kwargs,
)
)
)
await asyncio.gather(*tasks)
print("[NaiveChatCompletionScheduler] generate_sequences done")

return self._postprocess(batch, batch_conversations, kwargs["n"])

def _postprocess(
self, batch: DataProto, batch_conversations: List[List[List[Dict[str, str]]]], n: int
) -> DataProto:
# NOTE: consistent with batch version of generate_sequences in vllm_rollout_spmd.py
# prompts: left pad
# responses: right pad
# input_ids: prompt + response
# attention_mask: [0,0,0,0,1,1,1,1, | 1,1,1,0,0,0,0,0]
# position_ids: [0,0,0,0,0,1,2,3, | 4,5,6,7,8,9,10,11]

# prompts: [prompt] from input dataset
prompts = [
self.tokenizer.apply_chat_template(prompt, add_generation_prompt=True, tokenize=False)
for prompt in batch.non_tensor_batch["raw_prompt"]
]

# flatten batch_conversations if n > 1
assert len(batch_conversations) == len(prompts)
batch_conversations = [conversation for conversations in batch_conversations for conversation in conversations]
assert len(batch_conversations) == len(prompts) * n

# sequences: [prompt + response]
sequences = [
self.tokenizer.apply_chat_template(conversation, add_generation_prompt=False, tokenize=False)
for conversation in batch_conversations
]

# responses: [response]
# TODO: mask out tools calling tokens?
responses = [sequence[len(prompts[i // n]) :] for i, sequence in enumerate(sequences)]

prompts = self.tokenizer(prompts, return_tensors="pt", padding="longest", padding_side="left")
responses = self.tokenizer(responses, return_tensors="pt", padding="longest", padding_side="right")
if n > 1:
prompts["input_ids"] = prompts["input_ids"].repeat_interleave(n, dim=0)
prompts["attention_mask"] = prompts["attention_mask"].repeat_interleave(n, dim=0)

input_ids = torch.cat([prompts["input_ids"], responses["input_ids"]], dim=1)
attention_mask = torch.cat([prompts["attention_mask"], responses["attention_mask"]], dim=1)
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

batch = TensorDict(
{
"prompts": prompts["input_ids"],
"responses": responses["input_ids"],
"input_ids": input_ids,
"attention_mask": attention_mask,
"position_ids": position_ids,
},
batch_size=len(input_ids),
)

return DataProto(batch=batch)
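To make the padding convention in `_postprocess` concrete, here is a tiny self-contained sketch (toy token ids with pad id 0, not part of the PR) of the layout it targets: left-padded prompts, right-padded responses, `input_ids` as their concatenation, and `position_ids` derived from the attention mask with the same formula as the code above.

```python
import torch

# One toy sample: prompt tokens [5, 6] left-padded to length 4,
# response tokens [7, 8] right-padded to length 4 (0 is the pad id here).
prompt_ids = torch.tensor([[0, 0, 5, 6]])
response_ids = torch.tensor([[7, 8, 0, 0]])
prompt_mask = torch.tensor([[0, 0, 1, 1]])
response_mask = torch.tensor([[1, 1, 0, 0]])

input_ids = torch.cat([prompt_ids, response_ids], dim=1)
attention_mask = torch.cat([prompt_mask, response_mask], dim=1)
# Same formula as _postprocess: positions count up over real tokens, pads stay 0.
position_ids = (attention_mask.cumsum(dim=1) - 1) * attention_mask

print(input_ids)       # tensor([[0, 0, 5, 6, 7, 8, 0, 0]])
print(attention_mask)  # tensor([[0, 0, 1, 1, 1, 1, 0, 0]])
print(position_ids)    # tensor([[0, 0, 0, 1, 2, 3, 0, 0]])
```

Note also that with `n > 1` samples per prompt, `repeat_interleave` (rather than a plain tile/repeat) keeps each prompt row adjacent to its own n responses, matching the flattened order of `batch_conversations`.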
10 changes: 10 additions & 0 deletions examples/ppo_trainer/run_qwen2-7b_seq_balance.sh
@@ -8,10 +8,18 @@ math_test_path=$HOME/data/math/test.parquet
train_files="['$gsm8k_train_path', '$math_train_path']"
test_files="['$gsm8k_test_path', '$math_test_path']"

# For async rollout mode, the dataset should return the raw chat.
rollout_mode="sync"
if [ "$rollout_mode" = "async" ]; then
return_raw_chat="True"
chat_scheduler=examples.ppo_trainer.naive_chat_scheduler.NaiveChatCompletionScheduler
fi

python3 -m verl.trainer.main_ppo \
algorithm.adv_estimator=gae \
data.train_files="$train_files" \
data.val_files="$test_files" \
data.return_raw_chat=$return_raw_chat \
data.train_batch_size=4096 \
data.max_prompt_length=4096 \
data.max_response_length=4096 \
@@ -29,6 +37,8 @@ python3 -m verl.trainer.main_ppo \
actor_rollout_ref.actor.use_kl_loss=False \
actor_rollout_ref.rollout.tensor_model_parallel_size=2 \
actor_rollout_ref.rollout.name=vllm \
actor_rollout_ref.rollout.mode=$rollout_mode \
actor_rollout_ref.rollout.chat_scheduler=$chat_scheduler \
actor_rollout_ref.rollout.gpu_memory_utilization=0.5 \
actor_rollout_ref.rollout.log_prob_max_token_len_per_gpu=24000 \
critic.optim.lr=1e-5 \
2 changes: 1 addition & 1 deletion recipe/dapo/src/dapo_ray_trainer.py
@@ -40,7 +40,7 @@ class RayDAPOTrainer(RayPPOTrainer):
Note that this trainer runs on the driver process on a single CPU/GPU node.
"""

def fit(self):
async def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC
5 changes: 3 additions & 2 deletions recipe/dapo/src/main_dapo.py
@@ -76,7 +76,8 @@ def run_ppo(config) -> None:

@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
class TaskRunner:
def run(self, config):

async def run(self, config):
# print initial config
from pprint import pprint

@@ -201,7 +202,7 @@ def run(self, config):
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
await trainer.fit()


if __name__ == "__main__":
8 changes: 7 additions & 1 deletion recipe/prime/main_prime.py
@@ -29,6 +29,8 @@
Note that we don't combine the main with ray_trainer, as ray_trainer is used by other mains.
"""

import asyncio

import hydra
import ray

@@ -53,6 +55,10 @@ def run_prime(config, compute_score=None):

@ray.remote(num_cpus=1) # please make sure main_task is not scheduled on head
def main_task(config, compute_score=None):
asyncio.run(_main_task(config, compute_score))


async def _main_task(config, compute_score=None):
# print initial config
from pprint import pprint

@@ -142,7 +148,7 @@ def main_task(config, compute_score=None):
val_reward_fn=val_reward_fn,
)
trainer.init_workers()
trainer.fit()
await trainer.fit()


if __name__ == "__main__":
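The recipe changes above all follow one pattern: `fit()` becomes a coroutine, and the synchronous Ray entry point bridges into it with `asyncio.run`. A stripped-down sketch of that pattern (illustrative names only, no verl imports):

```python
import asyncio


class ToyTrainer:
    async def fit(self):
        # Stand-in for awaiting async rollout / generation calls inside the training loop.
        await asyncio.sleep(0)
        print("training loop done")


def main_task():
    """Synchronous entry point (e.g. a Ray remote function) driving the async trainer."""
    async def _main():
        trainer = ToyTrainer()
        await trainer.fit()

    asyncio.run(_main())


if __name__ == "__main__":
    main_task()
```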
2 changes: 1 addition & 1 deletion recipe/prime/prime_ray_trainer.py
@@ -331,7 +331,7 @@ def _load_checkpoint(self):
if isinstance(self.train_dataloader.dataset, RLHFDataset):
self.train_dataloader.dataset.resume_dataset_state()

def fit(self):
async def fit(self):
"""
The training loop of PPO.
The driver process only needs to call the compute functions of the worker group through RPC to construct the PPO dataflow.
4 changes: 3 additions & 1 deletion tests/ray/test_worker_group_basics.py
@@ -68,7 +68,9 @@ def foo_custom(self, x, y):
@ray.remote(num_gpus=0.1)
def remote_call_wg(worker_names):
class_with_args = RayClassWithInitArgs(cls=TestActor, x=2)
worker_group = RayWorkerGroup.from_detached(worker_names=worker_names, ray_cls_with_init=class_with_args)
worker_group = RayWorkerGroup.from_detached(
worker_names=worker_names, ray_cls_with_init=class_with_args, name_prefix=None
)
print(worker_group.worker_names)

output_ref = worker_group.foo_custom(x=[1, 2], y=[5, 6])
Expand Down