[recipe] feat: Integrate TransferQueue into RayTrainer #14
Conversation
```python
def _initialize_data_system(self):
    num_n_samples = self.config.actor_rollout_ref.rollout.n
    # 1. Initialize TransferQueueStorage
```
TODO: change these comments to English
```python
    )
)
log_rollout_meta.reorder(balanced_idx)
self._log_rollout_data(log_rollout_meta, reward_extra_infos_dict, timing_raw, rollout_data_dir)
```
TODO: clear the metadata after each step finishes
Pull Request Overview
This PR integrates the TransferQueue data system into RayTrainer to enable distributed data management for PPO training. The integration replaces the traditional DataProto-based data handling with BatchMeta objects that serve as metadata references to data stored in the TransferQueue system.
- Adds TransferQueue system initialization and client setup for distributed data storage and management
- Updates data handling throughout the training pipeline to work with BatchMeta metadata objects
- Extends decorator functions to support BatchMeta type checking and operations alongside the existing DataProto support (a rough sketch follows this list)
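As a rough illustration of that last point, a dispatch helper in `verl/single_controller/base/decorator.py` that previously accepted only `DataProto` would also need to accept `BatchMeta`. The sketch below is hypothetical: the import path for `BatchMeta` and the helper name `_check_dispatchable` are assumptions, not the PR's actual code.

```python
# Hypothetical sketch only: illustrates widening an isinstance check so that
# BatchMeta metadata references pass the same validation as DataProto payloads.
from verl.protocol import DataProto
from transfer_queue import BatchMeta  # assumed import path


def _check_dispatchable(data):
    """Accept either a concrete DataProto or a BatchMeta metadata reference."""
    if not isinstance(data, (DataProto, BatchMeta)):
        raise TypeError(f"Expected DataProto or BatchMeta, got {type(data).__name__}")
    return data
```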
Reviewed Changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 6 comments.
File | Description
---|---
verl/single_controller/base/decorator.py | Extended type checking and operations to support BatchMeta objects in data processing functions
recipe/transfer_queue/run_qwen3-8b_transferqueue_npu.sh | Added an NPU-specific training script with TransferQueue configuration parameters
recipe/transfer_queue/ray_trainer.py | New comprehensive RayTrainer implementation with full TransferQueue integration for distributed PPO training
recipe/transfer_queue/main_ppo.py | Main entry point for PPO training with TransferQueue support and worker initialization
recipe/transfer_queue/config/transfer_queue_ppo_trainer.yaml | Hydra configuration file extending the base PPO trainer config
```python
        # 1. Initialize TransferQueueStorage
        total_storage_size = self.config.data.train_batch_size * self.config.trainer.num_global_batch * num_n_samples
        self.data_system_storage_units = {}
        storage_placement_group = get_placement_group(self.config.trainer.num_data_storage_units, num_cpus_per_actor=1)
        for storage_unit_rank in range(self.config.trainer.num_data_storage_units):
            # TransferQueueStorage is launched through Ray as a ray.remote-decorated class
            storage_node = TransferQueueStorageSimpleUnit.options(
                placement_group=storage_placement_group, placement_group_bundle_index=storage_unit_rank
            ).remote(storage_size=math.ceil(total_storage_size / self.config.trainer.num_data_storage_units))
            self.data_system_storage_units[storage_unit_rank] = storage_node
            logging.info(f"TransferQueueStorageSimpleUnit #{storage_unit_rank} has been created.")

        # 2. Initialize TransferQueueController
        # Multiple controller instances are supported for load balancing and large-scale
        # deployment; different controllers can be assigned to different RL computation tasks.
        self.data_system_controllers = {}
        controller_placement_group = get_placement_group(self.config.trainer.num_data_controllers, num_cpus_per_actor=1)
        for controller_rank in range(self.config.trainer.num_data_controllers):
            self.data_system_controllers[controller_rank] = TransferQueueController.options(
                placement_group=controller_placement_group, placement_group_bundle_index=controller_rank
            ).remote(
                num_storage_units=self.config.trainer.num_data_storage_units,
                global_batch_size=self.config.data.train_batch_size,
                num_global_batch=self.config.trainer.num_global_batch,
                num_n_samples=num_n_samples,
            )
            logging.info(f"TransferQueueController #{controller_rank} has been created.")

        # 3. Register the controllers with each storage unit
        # Each storage unit receives every controller's handle, resolves the corresponding
        # IP + port through Ray, and then establishes ZMQ sockets for message transport.
        self.data_system_controller_infos = process_zmq_server_info(self.data_system_controllers)
        self.data_system_storage_unit_infos = process_zmq_server_info(self.data_system_storage_units)
```
[nitpick] The comments are in Chinese. Consider translating them to English for consistency with the rest of the codebase, which uses English comments.
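To make the sizing arithmetic in step 1 concrete: total capacity is the train batch size times the number of global batches times the rollout fan-out, split evenly across storage units with `math.ceil`. The config values below are illustrative assumptions, not taken from this PR.

```python
import math

# Illustrative values (assumptions, not this PR's config):
train_batch_size = 1024        # config.data.train_batch_size
num_global_batch = 2           # config.trainer.num_global_batch
num_n_samples = 8              # config.actor_rollout_ref.rollout.n
num_data_storage_units = 4     # config.trainer.num_data_storage_units

total_storage_size = train_batch_size * num_global_batch * num_n_samples
per_unit_size = math.ceil(total_storage_size / num_data_storage_units)
print(total_storage_size, per_unit_size)  # 16384 4096
```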
```python
self.data_system_client = AsyncTransferQueueClient(
    client_id="Trainer",
    controller_infos=self.data_system_controller_infos[0],
    # TODO: the main-control Client should be aware of all controllers; WorkerGroup and Worker Clients should each be aware of one controller
```
[nitpick] The TODO comment is in Chinese. Consider translating to English: '# TODO: Main controller Client should be aware of all controllers, WorkerGroup and Worker Clients should be aware of one controller'
Suggested change:
```diff
-# TODO: 主控Client感知所有controller,WorkerGroup和Worker的Client感知一个controller
+# TODO: Main controller Client should be aware of all controllers, WorkerGroup and Worker Clients should be aware of one controller
```
```python
# if self.reward_fn is None:
#     raise ValueError("A reward_fn is required for REMAX advantage estimation.")
#
# with marked_timer("gen_max", timing_raw, color="purple"):
#     gen_baseline_meta = deepcopy(gen_meta)
#     gen_baseline_meta.extra_info["do_sample"] = False
#     if not self.async_rollout_mode:
#         gen_baseline_output = self.actor_rollout_wg.generate_sequences(gen_baseline_meta)
#     else:
#         gen_baseline_output = self.async_rollout_manager.generate_sequences(gen_baseline_meta)
#     batch = batch.union(gen_baseline_output)
#     reward_baseline_tensor = self.reward_fn(batch)
#     reward_baseline_tensor = reward_baseline_tensor.sum(dim=-1)
#
#     batch.pop(batch_keys=list(gen_baseline_output.batch.keys()))
#
#     batch.batch["reward_baselines"] = reward_baseline_tensor
#
#     del gen_baseline_batch, gen_baseline_output
```
Large commented-out code block should be removed or properly implemented. If this functionality is needed, consider implementing it properly or moving it to a separate issue for future development.
Suggested change: delete the entire commented-out block above, leaving only:
```python
# (Commented-out block removed)
```
```python
# TODO: (transferqueue) (zjj)
if self.config.reward_model.launch_reward_fn_async:
    future_reward = compute_reward_async.remote(data=batch, reward_fn=self.reward_fn)
else:
    reward_tensor, reward_extra_infos_dict = compute_reward(batch, self.reward_fn)
```
TODO comment with initials should be either resolved or made more descriptive. Consider creating a proper issue or implementing the missing functionality.
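The branch above follows the standard Ray pattern: `.remote()` returns an `ObjectRef` immediately so reward computation can overlap with other work, and the result is materialized later with `ray.get`. A self-contained sketch of that pattern, with a stand-in reward function rather than verl's actual `compute_reward`:

```python
import ray


@ray.remote
def compute_reward_async(data, reward_fn):
    # Runs in a Ray worker; the caller immediately receives an ObjectRef.
    return reward_fn(data)


def dummy_reward_fn(data):
    return sum(data)  # stand-in for a real reward model or rule-based scorer


ray.init(ignore_reinit_error=True)
future_reward = compute_reward_async.remote(data=[1, 2, 3], reward_fn=dummy_reward_fn)
# ... other pipeline stages (e.g. old log-prob computation) can run here ...
print(ray.get(future_reward))  # blocks until the remote task finishes -> 6
```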
```python
# TODO: (transferqueue) (zjj)
if self.config.algorithm.use_kl_in_reward:
    batch, kl_metrics = apply_kl_penalty(
        batch, kl_ctrl=self.kl_ctrl_in_reward, kl_penalty=self.config.algorithm.kl_penalty
    )
    metrics.update(kl_metrics)
```
TODO comment with initials should be either resolved or made more descriptive. Consider creating a proper issue or implementing the missing functionality.
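Conceptually, `apply_kl_penalty` folds a per-token KL term into the rewards, roughly r_t ← r_t − β·KL(π_θ ∥ π_ref). The sketch below uses the simple k1 estimator `logprob − ref_logprob` and a fixed β; verl selects the estimator via `config.algorithm.kl_penalty` and adapts β with a KL controller, so treat this only as an illustration.

```python
import torch


def apply_kl_penalty_sketch(token_rewards, logprobs, ref_logprobs, beta=0.1):
    """Illustration only: subtract a per-token KL estimate from the rewards."""
    kl = logprobs - ref_logprobs          # k1 estimator of KL(pi_theta || pi_ref)
    return token_rewards - beta * kl


rewards = torch.zeros(2, 4)               # (batch, seq_len)
logp, ref_logp = torch.randn(2, 4), torch.randn(2, 4)
print(apply_kl_penalty_sketch(rewards, logp, ref_logp).shape)  # torch.Size([2, 4])
```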
```python
# TODO: (transferqueue) (zjj)
batch = compute_advantage(
    batch,
    adv_estimator=self.config.algorithm.adv_estimator,
    gamma=self.config.algorithm.gamma,
    lam=self.config.algorithm.lam,
    num_repeat=self.config.actor_rollout_ref.rollout.n,
    norm_adv_by_std_in_grpo=norm_adv_by_std_in_grpo,
    config=self.config.algorithm,
)
```
TODO comment with initials should be either resolved or made more descriptive. Consider creating a proper issue or implementing the missing functionality.
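For the GAE estimator path, `gamma` and `lam` enter `compute_advantage` through the usual recursion δ_t = r_t + γ·V_{t+1} − V_t and A_t = δ_t + γλ·A_{t+1}. A minimal single-sequence sketch (no masking or batching, unlike verl's implementation):

```python
import torch


def gae_sketch(rewards, values, gamma=1.0, lam=1.0):
    """Illustrative GAE over one unmasked sequence of length T."""
    T = rewards.shape[0]
    advantages = torch.zeros(T)
    last_adv = 0.0
    for t in reversed(range(T)):
        next_value = values[t + 1] if t + 1 < T else 0.0
        delta = rewards[t] + gamma * next_value - values[t]   # TD residual
        last_adv = delta + gamma * lam * last_adv             # discounted rollup
        advantages[t] = last_adv
    returns = advantages + values
    return advantages, returns


adv, ret = gae_sketch(torch.tensor([0.0, 0.0, 1.0]), torch.tensor([0.1, 0.2, 0.3]))
print(adv)  # tensor([0.9000, 0.8000, 0.7000])
```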
Merged acbd595 into TransferQueue:main_tq_submodule.
What does this PR do?
Checklist Before Starting
- Title format: `[{modules}] {type}: {description}` (this will be checked by the CI)
  - `{modules}` include `fsdp`, `megatron`, `sglang`, `vllm`, `rollout`, `trainer`, `ci`, `training_utils`, `recipe`, `hardware`, `deployment`, `ray`, `worker`, `single_controller`, `misc`, `perf`, `model`, `algo`, `env`, `tool`, `ckpt`, `doc`, `data`, like `[megatron, fsdp, doc]`
  - `{type}` is in `feat`, `fix`, `refactor`, `chore`, `test`
  - If the PR breaks any API, add `[BREAKING]` to the beginning of the title, e.g. `[BREAKING][fsdp, megatron] feat: dynamic batching`
Test
API and Usage Example
```python
# Add code snippet or script demonstrating how to use this
```
Design & Code Changes
Checklist Before Submitting
Important: Please check all the following items before requesting a review, otherwise the reviewer might deprioritize this PR for review.

- Apply pre-commit checks: `pre-commit install && pre-commit run --all-files --show-diff-on-failure --color=always`
- Once your PR is ready for CI, send a message in the `ci-request` channel in the `verl` Slack workspace. (If not accessible, please try the Feishu group (飞书群).)