
Conversation

@ryang-max

@ryang-max ryang-max commented Apr 27, 2025

What does this PR do?

This PR adds SGLang as an alternative rollout engine for GRPO training.

Reference PR: #3094

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline,
    Pull Request section?
  • Was this discussed/approved via a GitHub issue? Please add a link
    to it if that's the case.
  • Did you make sure to update the documentation with your changes?
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

TODO

  • Add documentation on using SGLang as the rollout engine
  • Performance comparison with vLLM (preliminary results: precision is the same; update_weights is faster than with vLLM; generation is a bit slower and needs more profiling)
  • Implement and test DeepSpeed ZeRO-3 support, same as vLLM

@ryang-max ryang-max changed the title [Feat] Support SGLang as rollout server of GRPO trainer [Feat] Support SGLang as rollout engine of GRPO trainer Apr 27, 2025
@ryang-max
Author

ryang-max commented Apr 28, 2025

Sometimes there's a strange error in the training process:

[rank0]: Traceback (most recent call last):
[rank0]:   File "/sgl-workspace/ryang/trl/trl/scripts/grpo_test/grpo_sgl_test.py", line 63, in <module>
[rank0]:     trainer.train()
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2245, in train
[rank0]:     return inner_training_loop(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 2560, in _inner_training_loop
[rank0]:     tr_loss_step = self.training_step(model, inputs, num_items_in_batch)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/trainer.py", line 3730, in training_step
[rank0]:     inputs = self._prepare_inputs(inputs)
[rank0]:   File "/sgl-workspace/ryang/trl/trl/extras/profiling.py", line 87, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "/sgl-workspace/ryang/trl/trl/trainer/grpo_trainer.py", line 964, in _prepare_inputs
[rank0]:     accumulated_local_batch = self._generate_and_score_completions(accumulated_local_batch)
[rank0]:   File "/sgl-workspace/ryang/trl/trl/trainer/grpo_trainer.py", line 1132, in _generate_and_score_completions
[rank0]:     ref_per_token_logps = self._get_per_token_logps(
[rank0]:   File "/sgl-workspace/ryang/trl/trl/extras/profiling.py", line 87, in wrapper
[rank0]:     return func(self, *args, **kwargs)
[rank0]:   File "/sgl-workspace/ryang/trl/trl/trainer/grpo_trainer.py", line 859, in _get_per_token_logps
[rank0]:     logits = model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 819, in forward
[rank0]:     return model_forward(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/accelerate/utils/operations.py", line 807, in __call__
[rank0]:     return convert_to_fp32(self.model_forward(*args, **kwargs))
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/utils/deprecation.py", line 172, in wrapped_func
[rank0]:     return func(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 823, in forward
[rank0]:     outputs: BaseModelOutputWithPast = self.model(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/utils/generic.py", line 965, in wrapper
[rank0]:     output = func(self, *args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/transformers/models/qwen2/modeling_qwen2.py", line 505, in forward
[rank0]:     inputs_embeds = self.embed_tokens(input_ids)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1739, in _wrapped_call_impl
[rank0]:     return self._call_impl(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1750, in _call_impl
[rank0]:     return forward_call(*args, **kwargs)
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/sparse.py", line 190, in forward
[rank0]:     return F.embedding(
[rank0]:   File "/usr/local/lib/python3.10/dist-packages/torch/nn/functional.py", line 2551, in embedding
[rank0]:     return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
[rank0]: RuntimeError: Expected tensor for argument #1 'indices' to have one of the following scalar types: Long, Int; but got torch.cuda.FloatTensor instead (while checking arguments for embedding)

This needs further investigation.
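The error itself just says that input_ids reached the embedding layer as a float tensor. Below is a minimal sketch of the failure mode and the usual fix (casting the indices back to an integer dtype), under the unconfirmed assumption that the completion ids coming back from the rollout step are the culprit:

```python
# Minimal reproduction sketch of the RuntimeError above; the assumption that the
# rollout engine returns completion ids as floats is a guess, not confirmed.
import torch
import torch.nn as nn

embed = nn.Embedding(num_embeddings=100, embedding_dim=8)

bad_ids = torch.tensor([[1.0, 2.0, 3.0]])  # float indices, like the failing input_ids
try:
    embed(bad_ids)
except RuntimeError as err:
    print(err)  # "Expected tensor for argument #1 'indices' to have ... Long, Int ..."

good_ids = bad_ids.long()  # casting to int64 before the forward pass avoids the error
print(embed(good_ids).shape)  # torch.Size([1, 3, 8])
```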

@kashif
Collaborator

kashif commented Apr 28, 2025

@ryang-max, why did you change some of the help strings for the different options?

@ryang-max
Author

@ryang-max, why did you change some of the help strings for the different options?

Thanks, that's unexpected (something probably got modified during development); I'll change them back.

Member


In the final version, the files in scripts should be removed

@ryang-max
Author

ryang-max commented May 6, 2025

Sorry for the late reply; I've been on vacation these days and didn't notice the comments. Thanks to @kashif for helping clean up the code! @renxinx Could you please help follow up on this PR to get it merged?

@renxinx

renxinx commented May 7, 2025

Hi @kashif, the files in scripts were removed.

@kashif
Collaborator

kashif commented May 7, 2025

Shall we also remove the commented-out code?

@kashif
Collaborator

kashif commented May 7, 2025

@ryang-max Shall we also add an sglang-server CLI script?

@kashif
Collaborator

kashif commented May 13, 2025

@ryang-max I have added a TRL CLI command, trl sglang-serve, to start the SGLang server.

@renxinx

renxinx commented May 22, 2025

@kashif Can you please review and see whether we can merge or whether further modification is needed? Thanks 🙏
I used 1% of the trl-lib/tldr dataset for training and evaluation. Or are there other ways to do a quick verification?

Model ROUGE-1 ROUGE-2 ROUGE-L
Qwen2-0.5B (base) 0.1243 0.0471 0.0890
Fine-tuned (GRPO) 0.1448 0.0549 0.1034
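For reference, a minimal sketch of how ROUGE numbers like these can be computed with the evaluate library; the 1% train slice, the completion column name, and the placeholder predictions are assumptions rather than how the table above was actually produced:

```python
# Sketch of a quick ROUGE check with the `evaluate` library (assumptions: the
# trl-lib/tldr dataset exposes a "completion" column, and real predictions would
# come from the base or GRPO-fine-tuned model rather than the placeholder below).
import evaluate
from datasets import load_dataset

eval_ds = load_dataset("trl-lib/tldr", split="train[:1%]")
references = eval_ds["completion"]

# Placeholder predictions keep the sketch runnable; replace with model generations.
predictions = [ref.split("\n")[0] for ref in references]

rouge = evaluate.load("rouge")
print(rouge.compute(predictions=predictions, references=references))
# -> {'rouge1': ..., 'rouge2': ..., 'rougeL': ..., 'rougeLsum': ...}
```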

@kashif
Collaborator

kashif commented May 22, 2025

Thanks @ryang-max. BTW, can you add instructions for how you ran the benchmark or the SGLang run to the docs, in the GRPO docs' SGLang section? 🙇🏽

@qgallouedec
Member

The main question remains: how is weight transfer managed? I understand there's serialization. It must be super slow on big models, right? Have you benchmarked with a 30B+ model?

@renxinx

renxinx commented May 24, 2025

@qgallouedec
We use Checkpoint-based Updates:

  1. Update the GRPOConfig by adding a checkpoint_path parameter.
  2. Write model checkpoints at regular intervals.
  3. Use the existing /update_weights_from_disk endpoint provided by the SGLang server.

This approach avoids modifying SGLang’s internal initialization routines and leverages its existing, stable checkpoint-loading capabilities.
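A rough sketch of the flow described above, assuming a locally running SGLang server; the server URL, helper name, and timeout are illustrative and not the PR's actual code:

```python
# Illustrative sketch of checkpoint-based weight updates, not the PR's code.
# Assumes an SGLang server is already listening on SERVER_URL and that
# /update_weights_from_disk accepts a checkpoint directory as "model_path".
import requests

SERVER_URL = "http://127.0.0.1:30000"  # assumed local SGLang server address

def push_checkpoint_to_sglang(checkpoint_path: str) -> None:
    """Ask the SGLang server to reload its weights from a checkpoint on disk."""
    resp = requests.post(
        f"{SERVER_URL}/update_weights_from_disk",
        json={"model_path": checkpoint_path},
        timeout=600,  # loading a large checkpoint can take a while
    )
    resp.raise_for_status()

# e.g. called after the trainer writes a checkpoint at the configured interval:
# push_checkpoint_to_sglang("outputs/checkpoint-500")
```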

We have only benchmarked with a 7B model so far.

@renxinx

renxinx commented May 24, 2025

@qgallouedec I've added instructions on how to run SGLang to the docs.

@qgallouedec
Member

Thanks, but my question about speed remains: how much time does it take to update the model? Do you have a benchmark to share? A wandb run?

@ryang-max
Author

ryang-max commented May 26, 2025

Thanks, but my question about speed remains: how much time does it take to update the model? Do you have a benchmark to share? A wandb run?

That's a fair question; I actually noticed this issue just a few days ago. Although we use the SGLang server's update_weights_from_tensor API, we pass a serialization of the full tensor (including its data); however, the intended use of this API is to serialize only the metadata, with the actual tensor data loaded directly from the GPU.

Currently we have only run a test case on a 0.5B model, and it shows a clear increase in rollout time compared with the vLLM-based rollout. (@renxinx Could you provide some comparison results for this?) We can improve update_weights in the future to pass only the metadata, and perhaps use update_weights_from_distributed for multi-node, large-model training.
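To make the size difference concrete, here is a small illustrative sketch (not the PR's actual code) comparing full-tensor serialization with metadata-only sharing via CUDA IPC, which is roughly what loading the data directly from the GPU amounts to in PyTorch:

```python
# Illustrative only: compares serializing a full weight tensor with passing just
# the metadata/IPC handle so the receiver maps the GPU memory directly.
import pickle
import torch
from torch.multiprocessing.reductions import reduce_tensor

assert torch.cuda.is_available(), "CUDA IPC handles require a GPU"
weight = torch.randn(4096, 4096, device="cuda")  # ~67 MB of fp32 weights

# Full serialization: the payload carries every element of the tensor.
full_payload = pickle.dumps(weight.cpu())
print(f"full tensor payload: {len(full_payload) / 1e6:.1f} MB")

# Metadata-only: reduce_tensor returns a rebuild function plus a small tuple of
# handles, sizes, and strides; the tensor data itself never leaves GPU memory.
rebuild_fn, metadata = reduce_tensor(weight)
print(f"metadata fields: {len(metadata)} (a few hundred bytes, not megabytes)")
```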

Thanks! @qgallouedec

@qgallouedec
Member

closing in favour of #3627

@qgallouedec qgallouedec closed this Nov 5, 2025