Merged
21 changes: 14 additions & 7 deletions docker/Dockerfile
@@ -1,17 +1,24 @@
FROM anyscale/ray:2.44.0-slim-py312-cu124
FROM anyscale/ray:2.44.0-slim-py312-cu128

RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev

# the cuda compiler here is needed for deepspeed
RUN wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
RUN sudo sh cuda_12.4.0_550.54.14_linux.run --silent --toolkit
RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \
&& sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run

RUN curl -LsSf https://astral.sh/uv/install.sh | sh
RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc

RUN sudo apt-get update \
&& sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \
libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata
RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk
libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \
&& sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \
&& sudo apt-get clean \
&& sudo rm -rf /var/lib/apt/lists/*

Comment on lines 3 to +20

Copilot AI Jul 9, 2025

[nitpick] Multiple RUN apt-get update and cleanup layers can be consolidated into a single RUN block to reduce image layers and overall size.

# NOTE: vllm installation in base environment is needed for uv + vLLM to work
RUN pip install vllm==0.8.5
RUN pip install ray==2.44.0
RUN pip install vllm==0.9.2 \
&& pip install ray==2.44.0 \
&& rm -rf ~/.cache/pip
2 changes: 1 addition & 1 deletion skyrl-train/README.md
@@ -42,7 +42,7 @@ A quick start guide for installation and your first training run is provided bel

The only requirements are:

- CUDA version >=12.4
- CUDA version 12.8
- [uv](https://docs.astral.sh/uv/)

If you're running on an existing Ray cluster, make sure to use Ray 2.44.0 and Python 3.12. If not, proceed with the installation instructions below.
2 changes: 1 addition & 1 deletion skyrl-train/docs/configuration/config.rst
@@ -388,7 +388,7 @@ Generator Configuration
min_p: 0.0
top_k: -1

use_conversation_multi_turn: false
use_conversation_multi_turn: true

# sampling params for evaluation
eval_sampling_params:
2 changes: 1 addition & 1 deletion skyrl-train/docs/examples/multi_turn_text2sql.rst
@@ -146,7 +146,7 @@ Now that we have our dataset and database files, let's walk through the some of
- Chat templating and loss masking for multi-turn conversations are handled by the ``SkyRLGymGenerator`` class.

- In the above example, we set ``use_conversation_multi_turn=false`` to enforce that the multi-turn conversation is formatted as a single assistant response.
- If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn.
- If you want to use a conversation-based format, you can set ``use_conversation_multi_turn=true`` and the model will generate a separate assistant response for each turn. This is supported only with ``backend="vllm"`` as of now.
- See :code_link:`skyrl_train/generators/skyrl_gym_generator.py` for more details on both options!

Launching Your Training Run
8 changes: 4 additions & 4 deletions skyrl-train/docs/getting-started/installation.rst
@@ -3,7 +3,7 @@ Installation

Requirements
------------
- CUDA version >=12.4
- CUDA version 12.8
- `uv <https://docs.astral.sh/uv/>`_

We use `uv <https://docs.astral.sh/uv/>`_ to manage dependencies. We also make use of the `uv` and `ray` integration to manage dependencies for ray workers.
@@ -14,15 +14,15 @@ If you're running on an existing Ray cluster, make sure to use Ray 2.44.0 and Py
Docker (recommended)
---------------------

We provide a docker image with the base dependencies ``sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.4`` for quick setup.
We provide a docker image with the base dependencies ``sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.8`` for quick setup.

1. Make sure to have `NVIDIA Container Runtime <https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html>`_ installed.

2. You can launch the container using the following command:

.. code-block:: bash

docker run -it --runtime=nvidia --gpus all --name skyrl-train sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.4 /bin/bash
docker run -it --runtime=nvidia --gpus all --name skyrl-train sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.8 /bin/bash

3. Inside the launched container, set up the latest version of the project:

@@ -39,7 +39,7 @@ Install without Dockerfile

For installation without the Dockerfile, make sure you meet the prerequisites:

- CUDA 12.4
- CUDA 12.8
- `uv <https://docs.astral.sh/uv/>`_
- `ray <https://docs.ray.io/en/latest/>`_ 2.44.0

6 changes: 5 additions & 1 deletion skyrl-train/examples/remote_inference_engine/run_remote.sh
@@ -9,13 +9,17 @@ set -x

DATA_DIR="$HOME/data/gsm8k"

BACKEND="vllm" # or "sglang"
TP=4

uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
data.train_data="['$DATA_DIR/train.parquet']" \
data.val_data="['$DATA_DIR/validation.parquet']" \
trainer.policy.model.path="Qwen/Qwen2.5-1.5B-Instruct" \
generator.run_engines_locally=False \
generator.remote_inference_engine_urls="['127.0.0.1:8001']" \
generator.override_existing_update_group=True \
generator.inference_engine_tensor_parallel_size="$TP" \
generator.backend="$BACKEND" \
generator.sampling_params.temperature=0.6 \
generator.sampling_params.top_p=0.95 \
trainer.algorithm.advantage_estimator="grpo" \
12 changes: 12 additions & 0 deletions skyrl-train/examples/remote_inference_engine/run_sglang_server.sh
@@ -0,0 +1,12 @@
# Launches sglang server for Qwen2.5-1.5B-Instruct on 4 GPUs.
# bash examples/remote_inference_engine/run_sglang_server.sh
set -x

CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --isolated --extra sglang -m \
skyrl_train.inference_engines.sglang.sglang_server \
--model-path Qwen/Qwen2.5-1.5B-Instruct \
--tp 4 \
--host 127.0.0.1 \
--port 8001 \
--context-length 4096 \
--dtype bfloat16
@@ -2,7 +2,9 @@
# bash examples/remote_inference_engine/run_vllm_server.sh
set -x

uv run --isolated --extra vllm -m skyrl_train.inference_engines.vllm.vllm_server \
# NOTE (sumanthrh): Currently, there's an issue with distributed executor backend ray for vllm 0.9.2.
# For standalone server, we use mp for now.
CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --isolated --extra vllm -m skyrl_train.inference_engines.vllm.vllm_server \
Member Author

This currently fails with vllm 0.9.2:

uv run --isolated --extra vllm -m  vllm.entrypoints.api_server     --model Qwen/Qwen2.5-1.5B-Instruct     --tensor-parallel-size 4     --host 127.0.0.1     --port 8001     --seed 42     --max-model-len 4096     --enable-prefix-caching     --enable-chunked-prefill     --dtype bfloat16     --gpu-memory-utilization 0.9     --enable-sleep-mode     --max-num_batched_tokens 8192     --max-num-seqs 1024     --trust-remote-code     --distributed-executor-backend ray

There is an issue here that we can dig into later. Since the remote server can use any backend, I'm using the mp backend for now.

--model Qwen/Qwen2.5-1.5B-Instruct \
--tensor-parallel-size 4 \
--host 127.0.0.1 \
@@ -17,5 +19,5 @@ uv run --isolated --extra vllm -m skyrl_train.inference_engines.vllm.vllm_server
--max-num_batched_tokens 8192 \
--max-num-seqs 1024 \
--trust-remote-code \
--distributed-executor-backend ray \
--distributed-executor-backend mp \
--worker-extension-cls skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap
1 change: 1 addition & 0 deletions skyrl-train/examples/search/run_search.sh
@@ -42,6 +42,7 @@ uv run --isolated --frozen --extra vllm -m skyrl_train.entrypoints.main_base \
generator.sampling_params.max_generate_length=500 \
generator.async_engine=true \
generator.batched=false \
generator.use_conversation_multi_turn=false \
generator.n_samples_per_prompt=5 \
generator.max_turns=4 \
generator.use_conversation_multi_turn=false \
1 change: 1 addition & 0 deletions skyrl-train/examples/text_to_sql/run_skyrl_sql.sh
@@ -52,6 +52,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.async_engine=true \
generator.batched=false \
environment.env_class=text2sql \
generator.use_conversation_multi_turn=false \
Member Author

Since #65 has landed, the same fix needs to be made for the search example.

generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=6 \
1 change: 1 addition & 0 deletions skyrl-train/examples/text_to_sql/run_sql_deepspeed.sh
@@ -42,6 +42,7 @@ uv run --isolated --frozen --extra vllm --extra deepspeed -m skyrl_train.entrypo
generator.async_engine=true \
generator.batched=false \
environment.env_class=text2sql \
generator.use_conversation_multi_turn=false \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=5 \
1 change: 1 addition & 0 deletions skyrl-train/examples/text_to_sql/run_sql_fsdp.sh
@@ -53,6 +53,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.async_engine=true \
generator.batched=false \
environment.env_class=text2sql \
generator.use_conversation_multi_turn=false \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=6 \
1 change: 1 addition & 0 deletions skyrl-train/examples/text_to_sql/run_sql_fsdp_2node.sh
@@ -46,6 +46,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
generator.async_engine=true \
generator.batched=false \
environment.env_class=text2sql \
generator.use_conversation_multi_turn=false \
generator.n_samples_per_prompt=5 \
generator.gpu_memory_utilization=0.7 \
generator.max_turns=5 \
22 changes: 16 additions & 6 deletions skyrl-train/pyproject.toml
@@ -19,7 +19,7 @@ classifiers = [
]

dependencies = [
"flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
"flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
"loguru",
"tqdm",
"tensorboard",
@@ -51,6 +51,8 @@ conflicts = [

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym" , editable = true }
torch = { index = "pytorch-cu128" }
torchvision = { index = "pytorch-cu128" }

[project.optional-dependencies]
deepspeed = [
@@ -72,20 +74,28 @@ docs = [
"sphinx-autobuild>=2021.3.14"
]
vllm = [
"vllm==0.8.5",
"vllm==0.9.2",
# NOTE (sumanthrh): We explicitly use a flashinfer wheel from their index.
# The wheels on PyPI don't come with pre-compiled kernels and the package will JIT compile them at runtime (terribly slow).
"flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
"torch==2.7.0",
"torchvision"
Member Author

Added explicit torch and torchvision dependencies.

The torch version need not be specified, but it is helpful to know what we're dealing with. It does not hurt, since uv dependency resolution will catch any issues.

The main reason is to be able to use torch compiled for CUDA 12.8. By default, the CUDA version used is 12.6.

]
sglang = [
"sglang[srt,openai]==0.4.6.post4",
"torch-memory-saver>=0.0.5",
"sglang[srt,openai,torch_memory_saver]==0.4.8.post1",
# The version is pinned to 0.2.5 because sglang requires this
# NOTE (sumanthrh): This can be made a common dependency, but then different inference engines can pin different compatible flashinfer versions and it might quickly break.
"flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
"torch==2.7.1",
"torchvision",
]


[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.setuptools.packages.find]
include = ["skyrl_train*"]

2 changes: 1 addition & 1 deletion skyrl-train/skyrl_train/config/ppo_base_config.yaml
@@ -172,7 +172,7 @@ generator:
# whether to use a conversation based format for multi-turn generations
# if false, append multi-turn model responses and env observations to the original assistant response
# if true, each multi-turn model response and env observations is stored in a separate assistant/user message respectively
use_conversation_multi_turn: false
use_conversation_multi_turn: true
Member Author

This should be true by default, because the standard approach is to return observations as separate turns in a multi-turn conversation. We noticed some performance degradation for Qwen 7B models when using true instead of false, but that seems model-specific.

sglang also has some issues with this flag. I think its /completions endpoint behaves differently in some way, which caused me to get empty responses.

I will make sure to add this sglang caveat to the docs as well.

Member

I am tracking this and other sglang feature support tasks here: #82

If there's more to say about this issue and what you observed, could you dump it into the issue?

Member Author

SumanthRH, Jul 14, 2025

Yes, sounds good. Let me dig into this separately. The main issue is probably a difference between how sglang and vllm implement the /v1/completions API.

Member Author

The empty response is caused by our defaults (which might have to change as well).
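
For readers following this thread, the two layouts controlled by `use_conversation_multi_turn` can be sketched roughly as below. This is a minimal illustration based only on the config comments above; `build_messages`, the sample prompt, and the observation text are hypothetical and are not taken from `SkyRLGymGenerator`.

```python
from typing import Dict, List, Tuple

Message = Dict[str, str]


def build_messages(
    prompt: str,
    turns: List[Tuple[str, str]],
    use_conversation_multi_turn: bool,
) -> List[Message]:
    """Lay out (model response, env observation) pairs in one of two formats.

    Hypothetical helper for illustration; not SkyRL's actual implementation.
    """
    messages: List[Message] = [{"role": "user", "content": prompt}]
    if use_conversation_multi_turn:
        # Conversation format: each model response is its own assistant
        # message and each observation is its own user message.
        for response, observation in turns:
            messages.append({"role": "assistant", "content": response})
            if observation:
                messages.append({"role": "user", "content": observation})
    else:
        # Single-response format: responses and observations are appended to
        # one running assistant message.
        merged = "".join(response + observation for response, observation in turns)
        messages.append({"role": "assistant", "content": merged})
    return messages


turns = [("SELECT 1;", "<observation>1</observation>"), ("The answer is 1.", "")]
conversation = build_messages("How many rows?", turns, use_conversation_multi_turn=True)
single = build_messages("How many rows?", turns, use_conversation_multi_turn=False)
```

With the flag set to true the sketch yields alternating assistant and user messages; with false it yields a single assistant message containing every response and observation in order.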


# sampling params for evaluation
eval_sampling_params:
3 changes: 1 addition & 2 deletions skyrl-train/skyrl_train/inference_engines/base.py
@@ -1,6 +1,5 @@
from abc import ABC, abstractmethod
from typing import List, Dict, TypedDict, Any, Optional, Hashable
import torch

MessageType = Dict[str, str]
ConversationType = List[MessageType]
@@ -21,7 +20,7 @@ class InferenceEngineOutput(TypedDict):

class NamedWeightUpdateRequest(TypedDict):
name: str
dtype: torch.dtype
dtype: str
shape: List[int]
extras: Optional[Dict[str, Any]]

@@ -5,7 +5,6 @@
InferenceEngineOutput,
NamedWeightUpdateRequest,
)
from skyrl_train.utils import torch_dtype_to_str
from typing import List, Optional, Dict, Any
import json
import asyncio
@@ -107,7 +106,6 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest):
raise ValueError(
"Remote inference engines do not support CUDA IPC weight updates. Only local engines support IPC."
)

if self.engine_backend == "vllm":
weight_update_method = "update_weight"
elif self.engine_backend == "sglang":
@@ -120,7 +118,7 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest):
f"{self.url}/{weight_update_method}",
json={
"name": request["name"],
"dtype": torch_dtype_to_str(request["dtype"]),
Member Author

Minor edit here: since the inference engines expect this to be a string anyway, I have changed the type to a string for consistency.

"dtype": request["dtype"],
"shape": request["shape"],
},
)
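
One practical effect of typing `dtype` as `str` is that the request fields can be serialized to JSON without a conversion step. The sketch below illustrates this under stated assumptions: `build_update_payload`, the weight name, and the `"bfloat16"` spelling are hypothetical, and the exact string format the engines expect is not shown in this diff.

```python
import json
from typing import Any, Dict, List, Optional, TypedDict


class NamedWeightUpdateRequest(TypedDict):
    name: str
    dtype: str  # a plain string after this change, no longer torch.dtype
    shape: List[int]
    extras: Optional[Dict[str, Any]]


def build_update_payload(request: NamedWeightUpdateRequest) -> Dict[str, Any]:
    """Pick the fields sent to the remote engine (hypothetical helper)."""
    return {
        "name": request["name"],
        "dtype": request["dtype"],
        "shape": request["shape"],
    }


request: NamedWeightUpdateRequest = {
    "name": "model.layers.0.mlp.weight",  # example weight name, assumed
    "dtype": "bfloat16",  # exact spelling expected by the engines is assumed
    "shape": [4096, 4096],
    "extras": None,
}

# json.dumps works directly because every field is a JSON-compatible type.
body = json.dumps(build_update_payload(request))
```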
13 changes: 0 additions & 13 deletions skyrl-train/skyrl_train/inference_engines/sglang/sglang.patch

This file was deleted.

17 changes: 0 additions & 17 deletions skyrl-train/skyrl_train/inference_engines/sglang/sglang_server.py
@@ -4,22 +4,6 @@
from sglang.srt.entrypoints.http_server import launch_server
from sglang.srt.server_args import prepare_server_args, ServerArgs
from sglang.srt.utils import kill_process_tree
import sglang
from pathlib import Path
import subprocess

PATCH_FILE_PATH = Path(__file__).parent / "sglang.patch"


def apply_patch():
try:
sglang_path = Path(sglang.__file__).parent
subprocess.run(
["patch", "-p1", "-d", str(sglang_path), "-i", str(PATCH_FILE_PATH), "--batch", "--forward"], check=True
)
except Exception as e:
print(f"Failed to patch sglang: {e}", file=sys.stderr)
sys.exit(1)


class SGLangServer:
@@ -34,7 +18,6 @@ def run_server(self) -> None:


if __name__ == "__main__":
apply_patch()
server_args = prepare_server_args(sys.argv[1:])
sglang_server = SGLangServer(server_args)
sglang_server.run_server()
12 changes: 6 additions & 6 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
@@ -1,5 +1,5 @@
import os
from typing import List, Any
from typing import List, Any, Dict
import ray
import torch
import asyncio
@@ -16,6 +16,7 @@
InferenceEngineOutput,
NamedWeightUpdateRequest,
)
from skyrl_train.utils import str_to_torch_dtype


def setup_envvars_for_vllm(kwargs, bundle_indices):
@@ -86,10 +87,9 @@ def init_weight_update_communicator(
f"rank={rank}, world_size={world_size}, group_name={group_name}",
)

def update_weight(self, name, dtype, shape):
import torch

def update_weight(self, name: str, dtype: str, shape: List[int]):
"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
dtype = str_to_torch_dtype(dtype)
Member Author

This is a vllm 0.9.2 change, AFAIK. Previously we passed the string directly here, but now model_config.dtype is of type torch.dtype.

assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
weight = torch.empty(shape, dtype=dtype, device="cuda")
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
@@ -98,9 +98,9 @@ def update_weight(self, name, dtype, shape):

del weight

def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None):
import torch
def update_weight_cuda_ipc(self, name: str, dtype: str, shape: List[int], ipc_handles: Dict[str, Any]):

dtype = str_to_torch_dtype(dtype)
device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
physical_gpu_id = str(props.uuid)
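
The string-to-dtype conversion added to both update methods could look roughly like the sketch below. The real `str_to_torch_dtype` in `skyrl_train.utils` is not part of this diff, so `parse_dtype`, the accepted spellings, and the stand-in namespace are all assumptions; the stand-in replaces the `torch` module only so the example runs without torch installed.

```python
from types import SimpleNamespace
from typing import Any


def parse_dtype(name: str, namespace: Any) -> Any:
    """Resolve a dtype string such as "bfloat16" or "torch.bfloat16".

    Hypothetical stand-in for str_to_torch_dtype; pass the torch module as
    `namespace` in real use.
    """
    attr = name.split(".")[-1]  # drop an optional "torch." prefix
    if not hasattr(namespace, attr):
        raise ValueError(f"unknown dtype string: {name!r}")
    return getattr(namespace, attr)


# Stand-in for the torch module; the values are placeholders, not real dtypes.
fake_torch = SimpleNamespace(bfloat16="bf16-dtype", float16="fp16-dtype")

src_dtype = parse_dtype("torch.bfloat16", fake_torch)
model_dtype = fake_torch.bfloat16  # plays the role of model_config.dtype
assert src_dtype == model_dtype, f"mismatch dtype: src {src_dtype}, dst {model_dtype}"
```

Converting before the comparison matters because a string never compares equal to a dtype object, so the existing assertion would otherwise fail. A production version would likely also check that the resolved attribute is actually a dtype, which this sketch omits.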