Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 14 additions & 7 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,17 +1,24 @@
FROM anyscale/ray:2.44.0-slim-py312-cu124
FROM anyscale/ray:2.44.0-slim-py312-cu128

RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev

# the cuda compiler here is needed for deepspeed
RUN wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
RUN sudo sh cuda_12.4.0_550.54.14_linux.run --silent --toolkit
RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \
&& sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run

RUN curl -LsSf https://astral.sh/uv/install.sh | sh
RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc

RUN sudo apt-get update \
&& sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \
libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata
RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk
libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \
&& sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*

RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \
&& sudo apt-get clean \
&& sudo rm -rf /var/lib/apt/lists/*

Comment on lines 3 to +20
Copy link
Preview

Copilot AI Jul 9, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Multiple RUN apt-get update and cleanup layers can be consolidated into a single RUN block to reduce image layers and overall size.

Copilot uses AI. Check for mistakes.

# NOTE: vllm installation in base environment is needed for uv + vLLM to work
RUN pip install vllm==0.8.5
RUN pip install ray==2.44.0
RUN pip install vllm==0.9.2 \
&& pip install ray==2.44.0 \
&& rm -rf ~/.cache/pip
8 changes: 4 additions & 4 deletions skyrl-train/docs/getting-started/installation.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Installation

Requirements
------------
- CUDA version >=12.4
- CUDA version >=12.4 (Recommended: 12.8)
- `uv <https://docs.astral.sh/uv/>`_

We use `uv <https://docs.astral.sh/uv/>`_ to manage dependencies. We also make use of the `uv` and `ray` integration to manage dependencies for ray workers.
Expand All @@ -14,15 +14,15 @@ If you're running on an existing Ray cluster, make sure to use Ray 2.44.0 and Py
Docker (recommended)
---------------------

We provide a docker image with the base dependencies ``sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.4`` for quick setup.
We provide a docker image with the base dependencies ``sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.8`` for quick setup.

1. Make sure to have `NVIDIA Container Runtime <https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html>`_ installed.

2. You can launch the container using the following command:

.. code-block:: bash

docker run -it --runtime=nvidia --gpus all --name skyrl-train sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.4 /bin/bash
docker run -it --runtime=nvidia --gpus all --name skyrl-train sumanthrh/skyrl-train-ray-2.44.0-py3.12-cu12.8 /bin/bash

3. Inside the launched container, setup the latest version of the project:

Expand All @@ -39,7 +39,7 @@ Install without Dockerfile

For installation without the Dockerfile, make sure you meet the pre-requisities:

- CUDA 12.4
- CUDA 12.4 (Recommended: 12.8)
- `uv <https://docs.astral.sh/uv/>`_
- `ray <https://docs.ray.io/en/latest/>`_ 2.44.0

Expand Down
17 changes: 12 additions & 5 deletions skyrl-train/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ classifiers = [
]

dependencies = [
"flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
"flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
"loguru",
"tqdm",
"tensorboard",
Expand Down Expand Up @@ -51,6 +51,7 @@ conflicts = [

[tool.uv.sources]
skyrl-gym = { path = "./skyrl-gym" , editable = true }
torch = { index = "pytorch-cu128" }

[project.optional-dependencies]
deepspeed = [
Expand All @@ -72,20 +73,26 @@ docs = [
"sphinx-autobuild>=2021.3.14"
]
vllm = [
"vllm==0.8.5",
"vllm==0.9.2",
# NOTE (sumanthrh): We explictly use a flashinfer wheel from their index.
# The wheels on PyPI don't come with pre-compiled kernels and the package will JIT compile them at runtime (terribly slow).
"flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
]
sglang = [
"sglang[srt,openai]==0.4.6.post4",
"sglang[srt,openai]==0.4.8.post1",
"torch-memory-saver>=0.0.5",
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can do sglang[srt,openai,torch_memory_saver]==0.4.8.post1

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks!

# The version is pinned to 0.2.5 because sglang requires this
# NOTE (sumanthrh): This can be made a common dependency, but then different inference engines can pin different compatible flashinfer versions and it might quickly break.
"flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
"flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
"torch==2.7.1",
]


[[tool.uv.index]]
name = "pytorch-cu128"
url = "https://download.pytorch.org/whl/cu128"
explicit = true

[tool.setuptools.packages.find]
include = ["skyrl_train*"]

Expand Down
10 changes: 7 additions & 3 deletions skyrl-train/skyrl_train/inference_engines/vllm/vllm_engine.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List, Any
from typing import List, Any, Dict
import ray
import torch
import asyncio
Expand All @@ -16,6 +16,7 @@
InferenceEngineOutput,
NamedWeightUpdateRequest,
)
from skyrl_train.utils import str_to_torch_dtype


def setup_envvars_for_vllm(kwargs, bundle_indices):
Expand Down Expand Up @@ -86,10 +87,11 @@ def init_weight_update_communicator(
f"rank={rank}, world_size={world_size}, group_name={group_name}",
)

def update_weight(self, name, dtype, shape):
def update_weight(self, name: str, dtype: str, shape: List[int]):
import torch

"""Broadcast weight to all vllm workers from source rank 0 (actor model)"""
dtype: torch.dtype = str_to_torch_dtype(dtype)
assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
weight = torch.empty(shape, dtype=dtype, device="cuda")
torch.distributed.broadcast(weight, 0, group=self._model_update_group)
Expand All @@ -98,9 +100,11 @@ def update_weight(self, name, dtype, shape):

del weight

def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None):
def update_weight_cuda_ipc(self, name: str, dtype: str, shape: List[int], ipc_handles: Dict[str, Any]):
import torch

dtype: torch.dtype = str_to_torch_dtype(dtype)

device = torch.cuda.current_device()
props = torch.cuda.get_device_properties(device)
physical_gpu_id = str(props.uuid)
Expand Down
17 changes: 11 additions & 6 deletions skyrl-train/skyrl_train/utils/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,12 +246,17 @@ def initialize_ray(cfg: DictConfig):
"NCCL_P2P_DISABLE": "0",
"CUDA_LAUNCH_BLOCKING": "1",
}
if cfg.generator.backend == "vllm" and not os.environ.get("VLLM_USE_V1", False):
logger.info(
"`VLLM_USE_V1` is not specified, setting `VLLM_USE_V1` to 1. To override, set `VLLM_USE_V1` explicitly"
)
env_vars["VLLM_USE_V1"] = "1"
env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"
if cfg.generator.backend == "vllm":
# NOTE (sumanthrh): In vllm >= 0.9.0, we need to explicitly allow for serialization via pickle for collective RPCs.
# During weight transfer, we use IPC handles, which contains a `function` object and requires pickling.
env_vars["VLLM_ALLOW_INSECURE_SERIALIZATION"] = "1"

if not os.environ.get("VLLM_USE_V1", False):
logger.info(
"`VLLM_USE_V1` is not specified, setting `VLLM_USE_V1` to 1. To override, set `VLLM_USE_V1` explicitly"
)
env_vars["VLLM_USE_V1"] = "1"
env_vars["VLLM_ENABLE_V1_MULTIPROCESSING"] = "0"

# TODO: this can be removed if we standardize on env files.
# But it's helpful for a quickstart
Expand Down
5 changes: 3 additions & 2 deletions skyrl-train/skyrl_train/workers/deepspeed/deepspeed_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
PolicyLoss,
ValueLoss,
)
from skyrl_train.utils import torch_dtype_to_str


class DeepSpeedPolicyWorkerBase(PolicyWorkerBase):
Expand Down Expand Up @@ -142,7 +143,7 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
inference_engine_client.update_named_weight(
{
"name": name,
"dtype": generator_dtype,
"dtype": torch_dtype_to_str(generator_dtype),
"shape": shape,
}
)
Expand Down Expand Up @@ -184,7 +185,7 @@ def gather_and_broadcast(param):
inference_engine_client.update_named_weight(
{
"name": name,
"dtype": generator_dtype,
"dtype": torch_dtype_to_str(generator_dtype),
"shape": shape,
"extras": {
"ipc_handles": ipc_handles,
Expand Down
8 changes: 4 additions & 4 deletions skyrl-train/skyrl_train/workers/fsdp/fsdp_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from skyrl_train.models import Actor, get_llm_for_sequence_regression
from skyrl_train.distributed.fsdp_strategy import FSDPStrategy
from skyrl_train.utils import get_physical_gpu_id
from skyrl_train.utils.utils import str_to_torch_dtype
from skyrl_train.utils import get_physical_gpu_id, torch_dtype_to_str, str_to_torch_dtype
from skyrl_train.training_batch import TrainingInputBatch, TrainingOutputBatch
from skyrl_train.distributed.fsdp_utils import fsdp_version, get_init_weight_context_manager
from skyrl_train.workers.worker import (
Expand Down Expand Up @@ -127,7 +126,7 @@ async def broadcast_to_inference_engines(self, inference_engine_client):
inference_engine_client.update_named_weight(
{
"name": name,
"dtype": generator_dtype,
"dtype": torch_dtype_to_str(generator_dtype),
"shape": shape,
}
)
Expand Down Expand Up @@ -157,6 +156,7 @@ def gather_and_broadcast(param):
ipc_handle = reduce_tensor(weight)

ipc_handle = {get_physical_gpu_id(): ipc_handle}

ipc_handle_list = [None] * torch.distributed.get_world_size()
torch.distributed.all_gather_object(ipc_handle_list, ipc_handle)

Expand All @@ -171,7 +171,7 @@ def gather_and_broadcast(param):
inference_engine_client.update_named_weight(
{
"name": name,
"dtype": generator_dtype,
"dtype": torch_dtype_to_str(generator_dtype),
"shape": shape,
"extras": {
"ipc_handles": ipc_handles,
Expand Down
Loading