-
Notifications
You must be signed in to change notification settings - Fork 102
[Trainer/Algorithm] Support registering custom policy loss functions + refactor adv estimator registry to allow registration outside ray workers #126
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Summary of Changes
Hello @erictang000, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!
This pull request significantly enhances the extensibility of the training framework by introducing a robust mechanism for registering custom policy loss functions. It involves refactoring the existing PPO policy loss into a standalone function and implementing a new registry system to manage these functions. Additionally, a configuration parameter has been renamed for improved clarity and consistency.
Highlights
- Extensibility for Policy Losses: Introduced a
PolicyLossRegistry
and a@register_policy_loss
decorator, enabling users to easily define and register their own custom policy loss functions. This allows for greater flexibility in experimenting with different PPO loss formulations. - Policy Loss Refactoring: The core PPO policy loss logic has been refactored from an
nn.Module
class (PolicyLoss
) into a standalone function (ppo_policy_loss
) withinppo_utils.py
. This improves modularity, simplifies the loss calculation, and facilitates the new custom registration mechanism. - Configuration Parameter Renaming: The configuration parameter
trainer.algorithm.ppo_loss_type
has been renamed totrainer.algorithm.policy_loss_type
across documentation, default configurations, and example scripts. This change enhances clarity and consistency with the new extensibility features. - New Example for Custom Losses: A new example (
examples/algorithm/custom_policy_loss
) has been added to demonstrate how to register and utilize a custom policy loss function, providing a clear guide for developers.
Using Gemini Code Assist
The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.
Invoking Gemini
You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command>
or @gemini-code-assist <command>
. Below is a summary of the supported commands.
Feature | Command | Description |
---|---|---|
Code Review | /gemini review |
Performs a code review for the current pull request in its current state. |
Pull Request Summary | /gemini summary |
Provides a summary of the current pull request in its current state. |
Comment | @gemini-code-assist | Responds in comments when explicitly tagged, both in issue comments and review comments. |
Help | /gemini help |
Displays a list of available commands. |
Customization
To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/
folder in the base of the repository. Detailed instructions can be found here.
Limitations & Feedback
Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments or fill out our survey to provide feedback.
You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.
Footnotes
-
Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution. ↩
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a registry for custom policy loss functions, refactoring the existing PPO loss logic into a more extensible, function-based approach. This is a great improvement, following the pattern already established for advantage estimators. The changes are well-implemented across the codebase, including updates to documentation, examples, and tests.
I've identified one critical issue in the new example code that would cause a runtime error due to an incorrect function signature. I also found a couple of medium-severity issues in the documentation and an example script that could be confusing to users.
Overall, this is a valuable contribution that enhances the framework's flexibility. Addressing the feedback will ensure the new feature is robust and easy to use.
skyrl-train/examples/algorithm/custom_policy_loss/main_custom_policy_loss.py
Outdated
Show resolved
Hide resolved
skyrl-train/examples/algorithm/custom_policy_loss/run_custom_policy_loss.sh
Outdated
Show resolved
Hide resolved
loss_type=self.cfg.trainer.algorithm.ppo_loss_type, | ||
loss_reduction=self.cfg.trainer.algorithm.loss_reduction, | ||
) | ||
policy_loss_func = PolicyLossRegistry.get(self.cfg.trainer.algorithm.policy_loss_type) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this is broken right now because we have 2 levels of ray worker nesting here :( so registering in the entrypoint doesn't propagate to here...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The simplest way to fix this is probably to make the registry a named actor and fetch it inside the training worker. Another option would be to write stuff directly to the ray object store.
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a significant and well-executed refactoring to support custom policy loss functions and improve the registration mechanism for both policy losses and advantage estimators. The core of the change is the new BaseFunctionRegistry
which cleverly uses a named Ray actor to allow function registration from the main process, making them available to all Ray workers. This is a great improvement for extensibility.
The changes are consistent, with updates to documentation, examples, and configuration to reflect the new API. The new tests, especially those for cross-process registration via Ray, are thorough and provide good coverage for the new functionality.
My review includes a few suggestions for improving code clarity and robustness, and points out a minor bug in one of the new tests. Overall, this is a high-quality contribution.
skyrl-train/examples/algorithm/custom_policy_loss/run_custom_policy_loss.sh
Outdated
Show resolved
Hide resolved
skyrl-train/examples/algorithm/custom_policy_loss/main_custom_policy_loss.py
Show resolved
Hide resolved
"regular", | ||
"dual_clip", | ||
), f"invalid ppo_loss_type: {cfg.trainer.algorithm.ppo_loss_type}. Must be one of `['regular', 'dual_clip']`" | ||
assert ( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this validation means that users have to register their functions outside of the ray entrypoint in their scripts like:
# Register the custom policy loss
PolicyLossRegistry.register("reinforce", compute_reinforce_policy_loss)
@ray.remote(num_cpus=1)
def skyrl_entrypoint(cfg: DictConfig):
exp = BasePPOExp(cfg)
exp.run()
since config validation happens outside of the ray entrypoint as well. I think this is the natural way people would be inclined to do just noting it here.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left a few small comments, but overall this approach looks good to me and gives a clean API for researchers.
def __str__(self): | ||
return self.value |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think you can avoid this by subclassing StrEnum
def register_policy_loss(name: Union[str, PolicyLossType]): | ||
"""Decorator to register a policy loss function.""" | ||
|
||
def decorator(func: Callable): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Should use https://docs.python.org/3/library/functools.html#functools.wraps here
return loss, clip_ratio | ||
|
||
|
||
def reduce_loss(loss: torch.Tensor, loss_mask: Optional[torch.Tensor], loss_reduction: str) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can we make loss_reduction a literal type for static checking here?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
cls._synced_to_actor = False | ||
|
||
@classmethod | ||
def _get_or_create_actor(cls): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This should probably raise if Ray isn't initialized
@classmethod | ||
def _sync_local_to_actor(cls): | ||
"""Sync all local functions to Ray actor (one-time when Ray becomes available).""" | ||
if cls._synced_to_actor or not ray.is_initialized(): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
again, I think you might want to raise here if ray isn't initialized, otherwise the training could fail later when we try to get a function from the registry
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good overall. There's some implicit behaviour around how registry access works with and without the presence of the named actor, and it would be good to atleast document.
if name in cls._functions: | ||
raise ValueError(f"{cls._function_type} '{name}' already registered") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
could there be a case where the same function is being registered multiple times because of .register
line being run in the driver + a worker process?
Currently, I think not, because the worker processes will have a fresh import of ppo_utils.py
(like say a FSDPWorker, or even skyrl_entrypoint
) and thus the registry state is empty, and the .register
code will only run on the driver.
return loss, clip_ratio | ||
|
||
|
||
def reduce_loss(loss: torch.Tensor, loss_mask: Optional[torch.Tensor], loss_reduction: str) -> torch.Tensor: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
+1
def register(cls, name: Union[str, Enum], func: Callable): | ||
"""Register a function.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Please add a detailed docstring for register, we should document behavior in both cases with and without the named actor being present.
@classmethod | ||
def get(cls, name: str) -> Callable: | ||
"""Get a function by name.""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Similar comment on docstring as above. Pretty sure registries will be used heavily so we should go the extra mile
assert clip_ratio_2 == 0.6 | ||
finally: | ||
# Clean up | ||
if "cross_process_test" in PolicyLossRegistry.list_available(): | ||
PolicyLossRegistry.unregister("cross_process_test") | ||
if "cross_process_test_2" in PolicyLossRegistry.list_available(): | ||
PolicyLossRegistry.unregister("cross_process_test_2") | ||
if "cross_process_adv_test" in AdvantageEstimatorRegistry.list_available(): | ||
AdvantageEstimatorRegistry.unregister("cross_process_adv_test") | ||
PolicyLossRegistry._ray_actor = None | ||
PolicyLossRegistry._synced_to_actor = False | ||
AdvantageEstimatorRegistry._ray_actor = None | ||
AdvantageEstimatorRegistry._synced_to_actor = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Shouldn't the named actor be cleaned up here (or i'm missing something) ? good time to use the ray init fixture probably.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
added a ray.kill on the actor - the current cpu test ray init fixture is session level so didn't want to call ray.shutdown
here
finally: | ||
# Clean up | ||
if "named_actor_test" in AdvantageEstimatorRegistry.list_available(): | ||
AdvantageEstimatorRegistry.unregister("named_actor_test") | ||
AdvantageEstimatorRegistry._ray_actor = None | ||
AdvantageEstimatorRegistry._synced_to_actor = False |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Another point on cleanup, since ther'es no ray.shutdown call here
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
can you add the registry classes to the API reference please?
/gemini review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request introduces a robust registry system for custom advantage estimators and policy loss functions, leveraging Ray named actors for synchronization across processes. This is a significant improvement for the library's extensibility. The refactoring of PolicyLoss
from a class to a function is a great simplification. The changes are well-tested and documented. My review focuses on improving the type hints for the new decorators to fully support custom functions, making the test cleanup more robust, and fixing minor issues in the new documentation.
* [Trainer] Support per-token rewards in trainer (NovaSky-AI#109) * Add check for whether p2p access is supported - allows code to run on L4/L40S after NovaSky-AI#73 upgrade to cuda 12.8 (NovaSky-AI#108) # Overview After NovaSky-AI#73, the main code path no longer runs on GPUs without P2P support (potentially due to cuda 12.8 upgrade?) - an error would be thrown like ```bash torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:3353, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.26.2 ncclUnhandledCudaError: Call to CUDA function failed. Last error: Cuda failure 217 'peer access is not supported between these two devices' ``` This PR adds a check for whether peer access is supported (using torch/cuda) between all GPUs on a node to the ray initialization, and sets relevant NCCL env vars to allow the code to run on these machine types. ```python if not peer_access_supported(): logger.info("Peer access is not supported, disabling P2P and SHM") env_vars["NCCL_P2P_DISABLE"] = "1" env_vars["NCCL_SHM_DISABLE"] = "1" ``` Example running on L40S: <img width="1854" height="227" alt="image" src="https://github.com/user-attachments/assets/1cca46b5-6e16-4ae7-9a33-df52d138bdeb" /> * [dependencies] Upgrade ray to 2.48.0 (NovaSky-AI#106) # What does this PR do Upgrades ray to 2.48.0, which allows us to remove the pip install vllm in the Dockerfile as a fallback for when uv + vllm does not resolve dependencies with the vllm + ray backend correctly. We leave the previous Dockerfile in `docker/Dockerfile.ray244` for backwards compatibility --------- Co-authored-by: Sumanth R Hegde <[email protected]> * fix issue with NovaSky-AI#108 that broke gpu ci (NovaSky-AI#112) missed an argument in `gpu_ci/conftest.py` for `peer_access_supported()` - fix for gpu ci to run Passing now with update: <img width="1811" height="861" alt="image" src="https://github.com/user-attachments/assets/70011c54-1e33-44b5-83a0-616029f891d2" /> And main runs (and disables p2p access) correctly: <img width="2067" height="203" alt="image" src="https://github.com/user-attachments/assets/399aff67-cc51-4588-a632-47698073593c" /> * Add warning for certain uv versions due to `uv run --with` regression (NovaSky-AI#113) # What does this PR do? Adds a warning for uv versions 0.8.0, 0.8.1 and 0.8.2 due to a bug in the uv run --with flag for "Running in ray cluster" section. These are relatively new versions and thus it's better to have this detail in the documentation for users. <img width="692" height="458" alt="Screenshot 2025-07-25 at 6 09 15 PM" src="https://github.com/user-attachments/assets/f1997eac-2867-4552-8ef7-eea8741e32b6" /> <img width="779" height="568" alt="Screenshot 2025-07-25 at 6 09 19 PM" src="https://github.com/user-attachments/assets/5080d328-c934-4864-91a8-932902dea934" /> --------- Signed-off-by: SumanthRH <[email protected]> * [GPU CI] Only trigger workflow for relevant changes in `skyrl-train` (NovaSky-AI#114) * [bug] Loading saved HF weights errors (NovaSky-AI#118) Addresses NovaSky-AI#97 * [DAPO] Add support for overlong filtering (NovaSky-AI#111) ## What does this PR do? Adds `apply_overlong_filtering` to the generator config, and provides a generator utility method `apply_overlong_filtering()` for post-processing the loss mask. I originally implemented this using the `stop_reasons` to determine whether the sequence was truncated, but instead switched to looking for `eos_token` in the response IDs for a more general approach. ## Tests Added CPU tests for the utility method and for SkyRL Gym Generator's use of the utility method. * [skyrl-gym] GSM8k - LLM Judge example (NovaSky-AI#74) * Fix MLFlow logging (NovaSky-AI#121) This is a small change to make the MLFlow integration work. Currently this fails with a Pandas error when trying to flatten an Omega dict; we need to convert to a regular Python dictionary. Can confirm this works on our MLFlow setup: <img width="1406" height="683" alt="image" src="https://github.com/user-attachments/assets/fcee526a-815e-4f08-bf25-d2709779ced7" /> * [Trainer] Support registering custom advantage estimators (NovaSky-AI#115) ## What does this PR do? Adds an `AdvantageEstimatorRegistry` to support custom advantage estimation methods without modifying the skyrl-train package. Added `examples/algorithm/custom_advantage_estimator` folder to give quick example of how to register a custom adv est function. ## Tests Adding cpu test to ensure registration works. * [checkpointing] Add HF model config and tokenizer config to checkpoint folder (NovaSky-AI#124) # Overview Adds the HF model config and tokenizer config to `ckpt_path/huggingface` for deepspeed and fsdp. So now the checkpoint directory will be: ``` {ckpt_path}/ ├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint ├── global_step_10/ # Checkpoint at training step 10 │ ├── policy/ # Policy model checkpoint directory │ │ ├── fsdp_config.json # stores fsdp version and world size │ │ ├── huggingface/ │ │ ├── config.json # model config │ │ ├── tokenizer_config.json # tokenizer config │ │ ├── generation_config.json # generation config │ │ ├── ... # other tokenizer config files │ │ ├── model_state.pt # Model parameters │ │ ├── optimizer_state.pt # Optimizer state │ │ └── lr_scheduler_state.pt # Learning rate scheduler state ``` For deepspeed it will be similar but without `fsdp_config.json` ``` {ckpt_path}/ ├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint ├── global_step_10/ # Checkpoint at training step 10 │ ├── policy/ # Policy model checkpoint directory │ │ ├── huggingface/ │ │ ├── config.json # model config │ │ ├── tokenizer_config.json # tokenizer config │ │ ├── generation_config.json # generation config │ │ ├── ... # other tokenizer config files │ │ ├── ... # deepspeed checkpointing files ``` * Fix discord link (NovaSky-AI#125) * Fix broken link (NovaSky-AI#128) * [Trainer/Algorithm] Support registering custom policy loss functions + refactor adv estimator registry to allow registration outside ray workers (NovaSky-AI#126) # Overview - Adds support for registering custom policy loss functions, similar to NovaSky-AI#115, - Refactors the policy loss to be a function in `ppo_utils.py` instead of a (`nn.Module` in `worker.py`) - Introduces a breaking change in renaming `trainer.algorithm.ppo_loss_type` to `trainer.algorithm.policy_loss_type` - Addresses Issue NovaSky-AI#116 by creating a new `BaseFunctionRegistry` class that uses a [named actor](https://docs.ray.io/en/latest/ray-core/actors/named-actors.html) to support the following pattern: ```python # Example of custom policy loss: "simple_baseline" def compute_simple_baseline_policy_loss( log_probs: torch.Tensor, ... ): return torch.randn(1, device=log_probs.device), 0.0 # Register the custom policy loss - outside of the ray worker PolicyLossRegistry.register("simple_baseline", compute_simple_baseline_policy_loss) @ray.remote(num_cpus=1) def skyrl_entrypoint(cfg: DictConfig): exp = BasePPOExp(cfg) exp.run() @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: DictConfig) -> None: # validate the arguments validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) ``` this change was necessary for `PolicyLossRegistry` to be accessible, since the worker `actor_loss_fn` attribute is set in `init_model` within the `worker` actor, which is a ray actor created from within the skyrl_entrypoint ray task (and registering within the entrypoint wouldn't propagate down another layer). - updates AdvantageEstimatorRegistry to extend the same `BaseFunctionRegistry` class Example runs: Custom advantage (mean of reward) <img width="956" height="326" alt="image" src="https://github.com/user-attachments/assets/1b7222bc-fbb9-49b1-876d-265b71201087" /> Custom policy loss (reinforce - just (-logprobs * advantages).mean()) <img width="939" height="330" alt="image" src="https://github.com/user-attachments/assets/cbed7ef5-b3e7-4e32-beba-b52b80879f47" /> * [SkyAgent] Upload initial refactored code (NovaSky-AI#131) # What does this PR do? Uploading our initial refactored code for SkyAgent --------- Signed-off-by: SumanthRH <[email protected]> Co-authored-by: Shiyi Cao <[email protected]> Co-authored-by: Dacheng Li <[email protected]> * [trainer] add more robust generation output validation (NovaSky-AI#132) # Overview Adds a `validate_generation_output` function in `trainer_utils.py` with more robust validation of generation output format. Specifically, given ``` class GeneratorOutput(TypedDict): prompt_token_ids: List[List[int]] response_ids: List[List[int]] rewards: Union[List[float], List[List[float]]] loss_masks: List[List[int]] stop_reasons: Optional[List[str]] rollout_metrics: Optional[Dict[str, Any]] ``` We expect - all list attributes should have the same length and be the same length as the input batch of prompts at dim=0 - non zero length lists - response_ids, loss masks, and rewards (if token level rewards) should be the same length - the sum of loss masks should be non-zero (logging a warning if it is not) verified gsm8k run still works: <img width="563" height="330" alt="image" src="https://github.com/user-attachments/assets/eeefebcb-d5fc-486d-b906-f4344b1e2779" /> --------- Co-authored-by: Sumanth R Hegde <[email protected]> * [Trainer] GSPO support (NovaSky-AI#120) This PR adds support for [Group Sequence Policy Optimization (GSPO)](https://arxiv.org/abs/2507.18071), the hotness du jour from Alibaba Qwen. The implementation in this PR is loosely based on [this one](huggingface/trl#3775) from TRL. It adds an `importance_sampling_level` config option which can be `token` (PPO/GRPO) or `sequence` (GSPO). I ran a short/small GSM8k run with Qwen2.5-0.5B and the loss curves look okay: <img width="314" height="240" alt="image" src="https://github.com/user-attachments/assets/f52d7c64-416c-4419-aa96-4a03c9048007" /> However, I had to hack a few things to get this to run on Datadog's cloud infra (including changing some dependency versions) so I'd encourage one of the maintainers to reproduce these results locally before merging. * [SkyAgent] Add initial docs (NovaSky-AI#134) # What does this PR do? Adds initial documentation for SkyAgent. We are still actively cleaning this package up, but I thought initial documentation will be helpful for anyone who stumbles across this. The documentation folder is still in `skyrl-train`, and much of the docs also refer to "SkyRL" when they are really referring to "SkyRL-train", so to avoid any confusion, I have just added this as a simple page on the sidebar. We need to make the docs be mono-repo wide and structure it better but I'm leaving it for a future PR. --------- Signed-off-by: SumanthRH <[email protected]> * [trainer/algorithm] Implement DAPO and Polaris style dynamic sampling + add DAPO docs + example (NovaSky-AI#130) # Overview This PR introduces filter (DAPO) and replace (Polaris/WebSailor) style dynamic sampling strategies. The dynamic sampling strategy can be configured as below: ```yaml # dynamic sampling parameters dynamic_sampling: type: null # filter (DAPO), replace (POLARIS/WebSailor), or null max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only) ``` This PR also adds a docs page describing how to enable all DAPO features, and adds an example GSM8K script where all these features are used. ## Minor Changes Some minor changes to make this dynamic sampling implementation clean: - the utils `Timer` class now updates the dict instead of overwriting in order to correctly track generation time w/ dynamic sampling, which means we need to make sure to reset `all_timings` in any trainer - The use of `self.weights_manager` is a little tricky for the dynamic sampling - introduced the the `ConditionalWeightsManager` to make the added code in the training loop as clean as possible ## Example runs <img width="413" height="264" alt="image" src="https://github.com/user-attachments/assets/072f716a-3632-42bb-a5f7-5f9d6064bd93" /> Generation time for dapo style filtering increases as the training run goes on, while it is stable for polaris and the baseline. <img width="419" height="265" alt="image" src="https://github.com/user-attachments/assets/887df550-e4b9-4623-b578-b4809a9f403f" /> We can see that the training pass @ n metric is 1 for both polaris and dapo style filtering as expected. <img width="421" height="259" alt="image" src="https://github.com/user-attachments/assets/bb63af77-1fbb-4d89-9216-b028f1551ea7" /> For GSM8k + Qwen 1.5B, the sampling strategy (as well as the full DAPO run) results in minimal gains - need larger models/harder dataset to test more fully DAPO sampling Example Run: ```bash (skyrl_entrypoint pid=222117) 2025-08-04 23:13:13.439 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:13.737 | INFO | skyrl_train.weights_manager:__enter__:76 - Started: 'sync_weights_to_inference_engines' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.401 | INFO | skyrl_train.weights_manager:__enter__:76 - Finished: 'sync_weights_to_inference_engines', time cost: 2.66s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.401 | INFO | skyrl_train.weights_manager:__enter__:80 - Started: 'offload_policy_model_to_cpu' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.842 | INFO | skyrl_train.weights_manager:__enter__:80 - Finished: 'offload_policy_model_to_cpu', time cost: 0.44s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.888 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:13 [executor_base.py:227] It took 0.243244 seconds to wake up tags ['weights']. [repeated 4x across cluster] (AsyncVLLMInferenceEngine pid=223854) INFO 08-04 23:13:16 [executor_base.py:227] It took 0.040547 seconds to wake up tags ['kv_cache']. (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:16 [block_pool.py:316] Successfully reset prefix cache [repeated 7x across cluster] (AsyncVLLMInferenceEngine pid=223855) INFO 08-04 23:13:16 [executor_base.py:227] It took 0.041721 seconds to wake up tags ['kv_cache']. (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.378 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 17.49s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:433 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:434 - Dynamic sampling: 460 < 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:435 - Resample batch 1, continue sampling... (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:436 - ================================================== (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.trainer:train:245 - Finished: 'step', time cost: 20.96s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.407 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.445 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.014 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 17.57s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:433 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:434 - Dynamic sampling: 941 < 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:435 - Resample batch 2, continue sampling... (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:436 - ================================================== (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.030 | INFO | skyrl_train.trainer:train:245 - Finished: 'step', time cost: 17.62s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.033 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.074 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.380 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 16.31s (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.396 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:439 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.396 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:440 - Dynamic sampling: collected 1467 >= 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.397 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:443 - ================================================== (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:12 [gpu_worker.py:98] Sleep mode freed 61.88 GiB memory, 4.98 GiB memory is still in use. [repeated 3x across cluster] (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:12 [executor_base.py:211] It took 1.264572 seconds to fall asleep. [repeated 3x across cluster] ``` Polaris Style example run: ```bash (skyrl_entrypoint pid=306764) 2025-08-05 00:30:01.648 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (AsyncVLLMInferenceEngine pid=308521) INFO 08-05 00:29:58 [executor_base.py:227] It took 0.240372 seconds to wake up tags ['weights']. [repeated 4x across cluster] (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:30:01 [executor_base.py:227] It took 0.040980 seconds to wake up tags ['kv_cache']. (AsyncVLLMInferenceEngine pid=308521) INFO 08-05 00:30:00 [block_pool.py:316] Successfully reset prefix cache [repeated 7x across cluster] (AsyncVLLMInferenceEngine pid=308518) INFO 08-05 00:30:01 [executor_base.py:227] It took 0.041175 seconds to wake up tags ['kv_cache']. (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.663 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 15.01s (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.679 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:316 - Replace sampling: 629 good UIDs out of 1024 total prompts (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:320 - ============= Dynamic sampling replace =========== (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:321 - Number of good prompts: 629 (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:322 - Number of bad prompts: 395 (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.694 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:352 - After replacement - Replaced 395 bad prompts (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.694 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:353 - ================================================== (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:29:57 [gpu_worker.py:98] Sleep mode freed 62.14 GiB memory, 6.28 GiB memory is still in use. [repeated 3x across cluster] (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:29:57 [executor_base.py:211] It took 1.331663 seconds to fall asleep. ``` ## Full DAPO example run From example script <img width="417" height="262" alt="image" src="https://github.com/user-attachments/assets/2592a06f-8b8a-4cf1-a29e-321bff819eb0" /> <img width="909" height="325" alt="image" src="https://github.com/user-attachments/assets/50922afd-1424-4183-9329-4f1f340287eb" /> --------- Co-authored-by: Sumanth R Hegde <[email protected]> * [algorithm] Support Dr. GRPO + refactor where policy/critic loss functions are set (NovaSky-AI#133) # Overview ## Dr GRPO Adds `loss_reduction`: `seq_mean_token_sum_norm ` option, and `grpo_norm_by_std` option to support Dr. GRPO So to run Dr. GRPO, set: ```yaml trainer: algorithm: grpo_norm_by_std: false loss_reduction: "seq_mean_token_sum_norm" ... ``` Example run: <img width="906" height="317" alt="image" src="https://github.com/user-attachments/assets/ce9db2ef-253e-45c8-adba-1ef8a270bbd9" /> Reward looks similar <img width="419" height="263" alt="image" src="https://github.com/user-attachments/assets/a4bc4d8c-f3c1-4bad-a497-0297dc30bc27" /> Magnitude of policy loss is lower as expected (since we are normalizing by a larger constant rather than taking the mean) ## Refactor where Critic/Policy Loss are set Changes ppo critic `ValueLoss` to just a function instead of a `nn.Module` for consistency with `policy_loss`, and adds new algorithm field to cfg that require evaluating field values in `utils::validate_cfg` (this runs before entrypoint code, allowing users to modify the cfg further by subclassing `BasePPOExp`) PPO example still running after this refactor: <img width="421" height="262" alt="image" src="https://github.com/user-attachments/assets/88985da3-1403-49c6-8cb5-f1434151fd9e" /> * [fix] move algorithm folder -> algorithms (NovaSky-AI#136) left the algorithm folder in NovaSky-AI#133, move it over * [Logging] Forward mlflow env vars to ray runtime env (NovaSky-AI#135) This PR forward the `MLFLOW_TRACKING_URI` and `MLFLOW_TRACKING_TOKEN` environment variable to the ray runtime env during its initialization. This will enable users to simply provide the above env vars at the driver and be able to use MLFlow for experiment tracking. * data folder * some stuff * updates --------- Signed-off-by: SumanthRH <[email protected]> Co-authored-by: Sumanth R Hegde <[email protected]> Co-authored-by: Eric Tang <[email protected]> Co-authored-by: Tyler Griggs <[email protected]> Co-authored-by: Shu Liu <[email protected]> Co-authored-by: Ben Cohen <[email protected]> Co-authored-by: Shiyi Cao <[email protected]> Co-authored-by: Dacheng Li <[email protected]> Co-authored-by: Etienne Brodu <[email protected]>
* [Trainer] Support per-token rewards in trainer (NovaSky-AI#109) * Add check for whether p2p access is supported - allows code to run on L4/L40S after NovaSky-AI#73 upgrade to cuda 12.8 (NovaSky-AI#108) # Overview After NovaSky-AI#73, the main code path no longer runs on GPUs without P2P support (potentially due to cuda 12.8 upgrade?) - an error would be thrown like ```bash torch.distributed.DistBackendError: NCCL error in: /pytorch/torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:3353, unhandled cuda error (run with NCCL_DEBUG=INFO for details), NCCL version 2.26.2 ncclUnhandledCudaError: Call to CUDA function failed. Last error: Cuda failure 217 'peer access is not supported between these two devices' ``` This PR adds a check for whether peer access is supported (using torch/cuda) between all GPUs on a node to the ray initialization, and sets relevant NCCL env vars to allow the code to run on these machine types. ```python if not peer_access_supported(): logger.info("Peer access is not supported, disabling P2P and SHM") env_vars["NCCL_P2P_DISABLE"] = "1" env_vars["NCCL_SHM_DISABLE"] = "1" ``` Example running on L40S: <img width="1854" height="227" alt="image" src="https://github.com/user-attachments/assets/1cca46b5-6e16-4ae7-9a33-df52d138bdeb" /> * [dependencies] Upgrade ray to 2.48.0 (NovaSky-AI#106) # What does this PR do Upgrades ray to 2.48.0, which allows us to remove the pip install vllm in the Dockerfile as a fallback for when uv + vllm does not resolve dependencies with the vllm + ray backend correctly. We leave the previous Dockerfile in `docker/Dockerfile.ray244` for backwards compatibility --------- Co-authored-by: Sumanth R Hegde <[email protected]> * fix issue with NovaSky-AI#108 that broke gpu ci (NovaSky-AI#112) missed an argument in `gpu_ci/conftest.py` for `peer_access_supported()` - fix for gpu ci to run Passing now with update: <img width="1811" height="861" alt="image" src="https://github.com/user-attachments/assets/70011c54-1e33-44b5-83a0-616029f891d2" /> And main runs (and disables p2p access) correctly: <img width="2067" height="203" alt="image" src="https://github.com/user-attachments/assets/399aff67-cc51-4588-a632-47698073593c" /> * Add warning for certain uv versions due to `uv run --with` regression (NovaSky-AI#113) # What does this PR do? Adds a warning for uv versions 0.8.0, 0.8.1 and 0.8.2 due to a bug in the uv run --with flag for "Running in ray cluster" section. These are relatively new versions and thus it's better to have this detail in the documentation for users. <img width="692" height="458" alt="Screenshot 2025-07-25 at 6 09 15 PM" src="https://github.com/user-attachments/assets/f1997eac-2867-4552-8ef7-eea8741e32b6" /> <img width="779" height="568" alt="Screenshot 2025-07-25 at 6 09 19 PM" src="https://github.com/user-attachments/assets/5080d328-c934-4864-91a8-932902dea934" /> --------- Signed-off-by: SumanthRH <[email protected]> * [GPU CI] Only trigger workflow for relevant changes in `skyrl-train` (NovaSky-AI#114) * [bug] Loading saved HF weights errors (NovaSky-AI#118) Addresses NovaSky-AI#97 * [DAPO] Add support for overlong filtering (NovaSky-AI#111) ## What does this PR do? Adds `apply_overlong_filtering` to the generator config, and provides a generator utility method `apply_overlong_filtering()` for post-processing the loss mask. I originally implemented this using the `stop_reasons` to determine whether the sequence was truncated, but instead switched to looking for `eos_token` in the response IDs for a more general approach. ## Tests Added CPU tests for the utility method and for SkyRL Gym Generator's use of the utility method. * [skyrl-gym] GSM8k - LLM Judge example (NovaSky-AI#74) * Fix MLFlow logging (NovaSky-AI#121) This is a small change to make the MLFlow integration work. Currently this fails with a Pandas error when trying to flatten an Omega dict; we need to convert to a regular Python dictionary. Can confirm this works on our MLFlow setup: <img width="1406" height="683" alt="image" src="https://github.com/user-attachments/assets/fcee526a-815e-4f08-bf25-d2709779ced7" /> * [Trainer] Support registering custom advantage estimators (NovaSky-AI#115) ## What does this PR do? Adds an `AdvantageEstimatorRegistry` to support custom advantage estimation methods without modifying the skyrl-train package. Added `examples/algorithm/custom_advantage_estimator` folder to give quick example of how to register a custom adv est function. ## Tests Adding cpu test to ensure registration works. * [checkpointing] Add HF model config and tokenizer config to checkpoint folder (NovaSky-AI#124) # Overview Adds the HF model config and tokenizer config to `ckpt_path/huggingface` for deepspeed and fsdp. So now the checkpoint directory will be: ``` {ckpt_path}/ ├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint ├── global_step_10/ # Checkpoint at training step 10 │ ├── policy/ # Policy model checkpoint directory │ │ ├── fsdp_config.json # stores fsdp version and world size │ │ ├── huggingface/ │ │ ├── config.json # model config │ │ ├── tokenizer_config.json # tokenizer config │ │ ├── generation_config.json # generation config │ │ ├── ... # other tokenizer config files │ │ ├── model_state.pt # Model parameters │ │ ├── optimizer_state.pt # Optimizer state │ │ └── lr_scheduler_state.pt # Learning rate scheduler state ``` For deepspeed it will be similar but without `fsdp_config.json` ``` {ckpt_path}/ ├── latest_ckpt_global_step.txt # Holds the global step of the latest checkpoint ├── global_step_10/ # Checkpoint at training step 10 │ ├── policy/ # Policy model checkpoint directory │ │ ├── huggingface/ │ │ ├── config.json # model config │ │ ├── tokenizer_config.json # tokenizer config │ │ ├── generation_config.json # generation config │ │ ├── ... # other tokenizer config files │ │ ├── ... # deepspeed checkpointing files ``` * Fix discord link (NovaSky-AI#125) * Fix broken link (NovaSky-AI#128) * [Trainer/Algorithm] Support registering custom policy loss functions + refactor adv estimator registry to allow registration outside ray workers (NovaSky-AI#126) # Overview - Adds support for registering custom policy loss functions, similar to NovaSky-AI#115, - Refactors the policy loss to be a function in `ppo_utils.py` instead of a (`nn.Module` in `worker.py`) - Introduces a breaking change in renaming `trainer.algorithm.ppo_loss_type` to `trainer.algorithm.policy_loss_type` - Addresses Issue NovaSky-AI#116 by creating a new `BaseFunctionRegistry` class that uses a [named actor](https://docs.ray.io/en/latest/ray-core/actors/named-actors.html) to support the following pattern: ```python # Example of custom policy loss: "simple_baseline" def compute_simple_baseline_policy_loss( log_probs: torch.Tensor, ... ): return torch.randn(1, device=log_probs.device), 0.0 # Register the custom policy loss - outside of the ray worker PolicyLossRegistry.register("simple_baseline", compute_simple_baseline_policy_loss) @ray.remote(num_cpus=1) def skyrl_entrypoint(cfg: DictConfig): exp = BasePPOExp(cfg) exp.run() @hydra.main(config_path=config_dir, config_name="ppo_base_config", version_base=None) def main(cfg: DictConfig) -> None: # validate the arguments validate_cfg(cfg) initialize_ray(cfg) ray.get(skyrl_entrypoint.remote(cfg)) ``` this change was necessary for `PolicyLossRegistry` to be accessible, since the worker `actor_loss_fn` attribute is set in `init_model` within the `worker` actor, which is a ray actor created from within the skyrl_entrypoint ray task (and registering within the entrypoint wouldn't propagate down another layer). - updates AdvantageEstimatorRegistry to extend the same `BaseFunctionRegistry` class Example runs: Custom advantage (mean of reward) <img width="956" height="326" alt="image" src="https://github.com/user-attachments/assets/1b7222bc-fbb9-49b1-876d-265b71201087" /> Custom policy loss (reinforce - just (-logprobs * advantages).mean()) <img width="939" height="330" alt="image" src="https://github.com/user-attachments/assets/cbed7ef5-b3e7-4e32-beba-b52b80879f47" /> * [SkyAgent] Upload initial refactored code (NovaSky-AI#131) # What does this PR do? Uploading our initial refactored code for SkyAgent --------- Signed-off-by: SumanthRH <[email protected]> Co-authored-by: Shiyi Cao <[email protected]> Co-authored-by: Dacheng Li <[email protected]> * [trainer] add more robust generation output validation (NovaSky-AI#132) # Overview Adds a `validate_generation_output` function in `trainer_utils.py` with more robust validation of generation output format. Specifically, given ``` class GeneratorOutput(TypedDict): prompt_token_ids: List[List[int]] response_ids: List[List[int]] rewards: Union[List[float], List[List[float]]] loss_masks: List[List[int]] stop_reasons: Optional[List[str]] rollout_metrics: Optional[Dict[str, Any]] ``` We expect - all list attributes should have the same length and be the same length as the input batch of prompts at dim=0 - non zero length lists - response_ids, loss masks, and rewards (if token level rewards) should be the same length - the sum of loss masks should be non-zero (logging a warning if it is not) verified gsm8k run still works: <img width="563" height="330" alt="image" src="https://github.com/user-attachments/assets/eeefebcb-d5fc-486d-b906-f4344b1e2779" /> --------- Co-authored-by: Sumanth R Hegde <[email protected]> * [Trainer] GSPO support (NovaSky-AI#120) This PR adds support for [Group Sequence Policy Optimization (GSPO)](https://arxiv.org/abs/2507.18071), the hotness du jour from Alibaba Qwen. The implementation in this PR is loosely based on [this one](huggingface/trl#3775) from TRL. It adds an `importance_sampling_level` config option which can be `token` (PPO/GRPO) or `sequence` (GSPO). I ran a short/small GSM8k run with Qwen2.5-0.5B and the loss curves look okay: <img width="314" height="240" alt="image" src="https://github.com/user-attachments/assets/f52d7c64-416c-4419-aa96-4a03c9048007" /> However, I had to hack a few things to get this to run on Datadog's cloud infra (including changing some dependency versions) so I'd encourage one of the maintainers to reproduce these results locally before merging. * [SkyAgent] Add initial docs (NovaSky-AI#134) # What does this PR do? Adds initial documentation for SkyAgent. We are still actively cleaning this package up, but I thought initial documentation will be helpful for anyone who stumbles across this. The documentation folder is still in `skyrl-train`, and much of the docs also refer to "SkyRL" when they are really referring to "SkyRL-train", so to avoid any confusion, I have just added this as a simple page on the sidebar. We need to make the docs be mono-repo wide and structure it better but I'm leaving it for a future PR. --------- Signed-off-by: SumanthRH <[email protected]> * [trainer/algorithm] Implement DAPO and Polaris style dynamic sampling + add DAPO docs + example (NovaSky-AI#130) # Overview This PR introduces filter (DAPO) and replace (Polaris/WebSailor) style dynamic sampling strategies. The dynamic sampling strategy can be configured as below: ```yaml # dynamic sampling parameters dynamic_sampling: type: null # filter (DAPO), replace (POLARIS/WebSailor), or null max_sample_batches: 30 # sample at most this many batches before stopping, -1 to sample forever min_replace_ratio: 0.3 # minimum proportion of good samples with which to replace bad samples (for replace strategy only) ``` This PR also adds a docs page describing how to enable all DAPO features, and adds an example GSM8K script where all these features are used. ## Minor Changes Some minor changes to make this dynamic sampling implementation clean: - the utils `Timer` class now updates the dict instead of overwriting in order to correctly track generation time w/ dynamic sampling, which means we need to make sure to reset `all_timings` in any trainer - The use of `self.weights_manager` is a little tricky for the dynamic sampling - introduced the the `ConditionalWeightsManager` to make the added code in the training loop as clean as possible ## Example runs <img width="413" height="264" alt="image" src="https://github.com/user-attachments/assets/072f716a-3632-42bb-a5f7-5f9d6064bd93" /> Generation time for dapo style filtering increases as the training run goes on, while it is stable for polaris and the baseline. <img width="419" height="265" alt="image" src="https://github.com/user-attachments/assets/887df550-e4b9-4623-b578-b4809a9f403f" /> We can see that the training pass @ n metric is 1 for both polaris and dapo style filtering as expected. <img width="421" height="259" alt="image" src="https://github.com/user-attachments/assets/bb63af77-1fbb-4d89-9216-b028f1551ea7" /> For GSM8k + Qwen 1.5B, the sampling strategy (as well as the full DAPO run) results in minimal gains - need larger models/harder dataset to test more fully DAPO sampling Example Run: ```bash (skyrl_entrypoint pid=222117) 2025-08-04 23:13:13.439 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:13.737 | INFO | skyrl_train.weights_manager:__enter__:76 - Started: 'sync_weights_to_inference_engines' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.401 | INFO | skyrl_train.weights_manager:__enter__:76 - Finished: 'sync_weights_to_inference_engines', time cost: 2.66s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.401 | INFO | skyrl_train.weights_manager:__enter__:80 - Started: 'offload_policy_model_to_cpu' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.842 | INFO | skyrl_train.weights_manager:__enter__:80 - Finished: 'offload_policy_model_to_cpu', time cost: 0.44s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:16.888 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:13 [executor_base.py:227] It took 0.243244 seconds to wake up tags ['weights']. [repeated 4x across cluster] (AsyncVLLMInferenceEngine pid=223854) INFO 08-04 23:13:16 [executor_base.py:227] It took 0.040547 seconds to wake up tags ['kv_cache']. (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:16 [block_pool.py:316] Successfully reset prefix cache [repeated 7x across cluster] (AsyncVLLMInferenceEngine pid=223855) INFO 08-04 23:13:16 [executor_base.py:227] It took 0.041721 seconds to wake up tags ['kv_cache']. (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.378 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 17.49s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:433 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:434 - Dynamic sampling: 460 < 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:435 - Resample batch 1, continue sampling... (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:436 - ================================================== (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.395 | INFO | skyrl_train.trainer:train:245 - Finished: 'step', time cost: 20.96s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.407 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:34.445 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.014 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 17.57s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:433 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:434 - Dynamic sampling: 941 < 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:435 - Resample batch 2, continue sampling... (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.029 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:436 - ================================================== (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.030 | INFO | skyrl_train.trainer:train:245 - Finished: 'step', time cost: 17.62s (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.033 | INFO | skyrl_train.trainer:train:245 - Started: 'step' (skyrl_entrypoint pid=222117) 2025-08-04 23:13:52.074 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.380 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 16.31s (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.396 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:439 - ============= Dynamic sampling filter ============= (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.396 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:440 - Dynamic sampling: collected 1467 >= 1024 prompts (skyrl_entrypoint pid=222117) 2025-08-04 23:14:08.397 | INFO | skyrl_train.utils.trainer_utils:handle_filter_sampling:443 - ================================================== (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:12 [gpu_worker.py:98] Sleep mode freed 61.88 GiB memory, 4.98 GiB memory is still in use. [repeated 3x across cluster] (AsyncVLLMInferenceEngine pid=223856) INFO 08-04 23:13:12 [executor_base.py:211] It took 1.264572 seconds to fall asleep. [repeated 3x across cluster] ``` Polaris Style example run: ```bash (skyrl_entrypoint pid=306764) 2025-08-05 00:30:01.648 | INFO | skyrl_train.trainer:train:261 - Started: 'generate' (AsyncVLLMInferenceEngine pid=308521) INFO 08-05 00:29:58 [executor_base.py:227] It took 0.240372 seconds to wake up tags ['weights']. [repeated 4x across cluster] (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:30:01 [executor_base.py:227] It took 0.040980 seconds to wake up tags ['kv_cache']. (AsyncVLLMInferenceEngine pid=308521) INFO 08-05 00:30:00 [block_pool.py:316] Successfully reset prefix cache [repeated 7x across cluster] (AsyncVLLMInferenceEngine pid=308518) INFO 08-05 00:30:01 [executor_base.py:227] It took 0.041175 seconds to wake up tags ['kv_cache']. (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.663 | INFO | skyrl_train.trainer:train:261 - Finished: 'generate', time cost: 15.01s (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.679 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:316 - Replace sampling: 629 good UIDs out of 1024 total prompts (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:320 - ============= Dynamic sampling replace =========== (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:321 - Number of good prompts: 629 (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.680 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:322 - Number of bad prompts: 395 (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.694 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:352 - After replacement - Replaced 395 bad prompts (skyrl_entrypoint pid=306764) 2025-08-05 00:30:16.694 | INFO | skyrl_train.utils.trainer_utils:handle_replace_sampling:353 - ================================================== (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:29:57 [gpu_worker.py:98] Sleep mode freed 62.14 GiB memory, 6.28 GiB memory is still in use. [repeated 3x across cluster] (AsyncVLLMInferenceEngine pid=308520) INFO 08-05 00:29:57 [executor_base.py:211] It took 1.331663 seconds to fall asleep. ``` ## Full DAPO example run From example script <img width="417" height="262" alt="image" src="https://github.com/user-attachments/assets/2592a06f-8b8a-4cf1-a29e-321bff819eb0" /> <img width="909" height="325" alt="image" src="https://github.com/user-attachments/assets/50922afd-1424-4183-9329-4f1f340287eb" /> --------- Co-authored-by: Sumanth R Hegde <[email protected]> * [algorithm] Support Dr. GRPO + refactor where policy/critic loss functions are set (NovaSky-AI#133) # Overview ## Dr GRPO Adds `loss_reduction`: `seq_mean_token_sum_norm ` option, and `grpo_norm_by_std` option to support Dr. GRPO So to run Dr. GRPO, set: ```yaml trainer: algorithm: grpo_norm_by_std: false loss_reduction: "seq_mean_token_sum_norm" ... ``` Example run: <img width="906" height="317" alt="image" src="https://github.com/user-attachments/assets/ce9db2ef-253e-45c8-adba-1ef8a270bbd9" /> Reward looks similar <img width="419" height="263" alt="image" src="https://github.com/user-attachments/assets/a4bc4d8c-f3c1-4bad-a497-0297dc30bc27" /> Magnitude of policy loss is lower as expected (since we are normalizing by a larger constant rather than taking the mean) ## Refactor where Critic/Policy Loss are set Changes ppo critic `ValueLoss` to just a function instead of a `nn.Module` for consistency with `policy_loss`, and adds new algorithm field to cfg that require evaluating field values in `utils::validate_cfg` (this runs before entrypoint code, allowing users to modify the cfg further by subclassing `BasePPOExp`) PPO example still running after this refactor: <img width="421" height="262" alt="image" src="https://github.com/user-attachments/assets/88985da3-1403-49c6-8cb5-f1434151fd9e" /> * [fix] move algorithm folder -> algorithms (NovaSky-AI#136) left the algorithm folder in NovaSky-AI#133, move it over * [Logging] Forward mlflow env vars to ray runtime env (NovaSky-AI#135) This PR forward the `MLFLOW_TRACKING_URI` and `MLFLOW_TRACKING_TOKEN` environment variable to the ray runtime env during its initialization. This will enable users to simply provide the above env vars at the driver and be able to use MLFlow for experiment tracking. * data folder * some stuff * updates --------- Signed-off-by: SumanthRH <[email protected]> Co-authored-by: Sumanth R Hegde <[email protected]> Co-authored-by: Eric Tang <[email protected]> Co-authored-by: Tyler Griggs <[email protected]> Co-authored-by: Shu Liu <[email protected]> Co-authored-by: Ben Cohen <[email protected]> Co-authored-by: Shiyi Cao <[email protected]> Co-authored-by: Dacheng Li <[email protected]> Co-authored-by: Etienne Brodu <[email protected]>
…d training (#161) # What does this PR do? Supports a list of weights during weight sync for colocated training. During colocated training, we use CUDA IPC for weight syncing. The current impl is syncing weights param by param, which can be pretty inefficient. In this PR, we sycn tensors in batches of a configurable parameter (default 1GB). That is, we collect ipc metadata until the total size of underlying tensors is 1GB and forward to the inference engine. Each TP rank will materialize all tensors in this list (i.e additional memory usage of 1GB here) and issue a single load_weights call. **How much faster is it?** Even for a 14B model on a 8xH100 node (TP2), the weight sync time can reduce from around 4.4s to 1.6s (60% reduction). This will matter much more for larger models. This PR is needed for the FlashRL integration to work well, because we have a custom load weights impl that - long story short - allcoates new storage in each call and also issues some `empty_cache` calls. Without batching, the load weights call will be too slow in such cases. This PR reduces time for weight sync for a 1.5B model with flashrl from 5 mins to < 5s. I've tested the PR with our E2E tests for colocated and non-colocated and also tested the remote engine codepath. This PR also makes the following changes: - Fixes bug introduced in #145 for the codepath with trajectory based routing when `response_ids` is not returned by the engine. - Fixes bug introduced in #126 for starting remote servers. import of `skyrl_train.utils.ppo_utils` will trigger registering. IN some cases, like with the vllm server init, we will not call `sync_registries` and there will be an error. The solution is to import guard `skyrl_train.utils.ppo_utils` unless the user themselves import it (for custom functions) or they go through the main entrypoint ( main -> `initialize_ray`-> sync) TODO: - [x] Verify non-colocated training works - [x] Run e2e test --------- Signed-off-by: SumanthRH <[email protected]>
Overview
ppo_utils.py
instead of a (nn.Module
inworker.py
)trainer.algorithm.ppo_loss_type
totrainer.algorithm.policy_loss_type
BaseFunctionRegistry
class that uses a named actor to support the following pattern:this change was necessary for
PolicyLossRegistry
to be accessible, since the workeractor_loss_fn
attribute is set ininit_model
within theworker
actor, which is a ray actor created from within the skyrl_entrypoint ray task (and registering within the entrypoint wouldn't propagate down another layer).BaseFunctionRegistry
classExample runs:

Custom advantage (mean of reward)
Custom policy loss (reinforce - just (-logprobs * advantages).mean())
