[Dependencies] Upgrade to torch 2.7 #73
Changes from all commits:
```diff
@@ -1,17 +1,24 @@
-FROM anyscale/ray:2.44.0-slim-py312-cu124
+FROM anyscale/ray:2.44.0-slim-py312-cu128

 RUN sudo apt-get update -y && sudo apt-get install -y wget kmod libxml2 build-essential libnuma-dev

 # the cuda compiler here is needed for deepspeed
-RUN wget https://developer.download.nvidia.com/compute/cuda/12.4.0/local_installers/cuda_12.4.0_550.54.14_linux.run
-RUN sudo sh cuda_12.4.0_550.54.14_linux.run --silent --toolkit
+RUN wget https://developer.download.nvidia.com/compute/cuda/12.8.0/local_installers/cuda_12.8.0_570.86.10_linux.run \
+    && sudo sh cuda_12.8.0_570.86.10_linux.run --silent --toolkit && rm -rf cuda_12.8.0_570.86.10_linux.run

 RUN curl -LsSf https://astral.sh/uv/install.sh | sh
 RUN echo "export RAY_RUNTIME_ENV_HOOK=ray._private.runtime_env.uv_runtime_env_hook.hook" >> /home/ray/.bashrc

 RUN sudo apt-get update \
     && sudo apt-get install -y openssh-server iputils-ping net-tools iproute2 traceroute netcat \
-    libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata
-RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk
+    libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev libxinerama-dev libxcursor-dev tzdata \
+    && sudo apt-get clean && sudo rm -rf /var/lib/apt/lists/*
+
+RUN sudo apt update && sudo apt install --fix-broken && sudo apt install -y default-jre-headless openjdk-8-jdk \
+    && sudo apt-get clean \
+    && sudo rm -rf /var/lib/apt/lists/*

 # NOTE: vllm installation in base environment is needed for uv + vLLM to work
-RUN pip install vllm==0.8.5
-RUN pip install ray==2.44.0
+RUN pip install vllm==0.9.2 \
+    && pip install ray==2.44.0 \
+    && rm -rf ~/.cache/pip
```
```diff
@@ -0,0 +1,12 @@
+# Launches sglang server for Qwen2.5-1.5B-Instruct on 4 GPUs.
+# bash examples/remote_inference_engine/run_sglang_server.sh
+set -x
+
+CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --isolated --extra sglang -m \
+    skyrl_train.inference_engines.sglang.sglang_server \
+    --model-path Qwen/Qwen2.5-1.5B-Instruct \
+    --tp 4 \
+    --host 127.0.0.1 \
+    --port 8001 \
+    --context-length 4096 \
+    --dtype bfloat16
```
```diff
@@ -2,7 +2,9 @@
 # bash examples/remote_inference_engine/run_vllm_server.sh
 set -x

-uv run --isolated --extra vllm -m skyrl_train.inference_engines.vllm.vllm_server \
+# NOTE (sumanthrh): Currently, there's an issue with distributed executor backend ray for vllm 0.9.2.
+# For standalone server, we use mp for now.
+CUDA_VISIBLE_DEVICES=4,5,6,7 uv run --isolated --extra vllm -m skyrl_train.inference_engines.vllm.vllm_server \
 --model Qwen/Qwen2.5-1.5B-Instruct \
 --tensor-parallel-size 4 \
 --host 127.0.0.1 \
@@ -17,5 +19,5 @@
 --max-num_batched_tokens 8192 \
 --max-num-seqs 1024 \
 --trust-remote-code \
---distributed-executor-backend ray \
+--distributed-executor-backend mp \
 --worker-extension-cls skyrl_train.inference_engines.vllm.vllm_engine.WorkerWrap
```

**Review comment:** This fails right now with vllm 0.9.2:

```shell
uv run --isolated --extra vllm -m vllm.entrypoints.api_server --model Qwen/Qwen2.5-1.5B-Instruct --tensor-parallel-size 4 --host 127.0.0.1 --port 8001 --seed 42 --max-model-len 4096 --enable-prefix-caching --enable-chunked-prefill --dtype bfloat16 --gpu-memory-utilization 0.9 --enable-sleep-mode --max-num_batched_tokens 8192 --max-num-seqs 1024 --trust-remote-code --distributed-executor-backend ray
```

There's some issue here that we can dig into later. Since the remote server can use any backend, I'm just using the mp backend for now.
```diff
@@ -52,6 +52,7 @@ uv run --isolated --extra vllm -m skyrl_train.entrypoints.main_base \
 generator.async_engine=true \
 generator.batched=false \
 environment.env_class=text2sql \
+generator.use_conversation_multi_turn=false \
 generator.n_samples_per_prompt=5 \
 generator.gpu_memory_utilization=0.7 \
 generator.max_turns=6 \
```

**Review comment:** Since #65 has landed, the same fix needs to be made for the search example.
```diff
@@ -19,7 +19,7 @@ classifiers = [
 ]

 dependencies = [
-    "flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
+    "flash-attn@https://github.com/Dao-AILab/flash-attention/releases/download/v2.8.0.post2/flash_attn-2.8.0.post2+cu12torch2.7cxx11abiFALSE-cp312-cp312-linux_x86_64.whl",
     "loguru",
     "tqdm",
     "tensorboard",
@@ -51,6 +51,8 @@ conflicts = [

 [tool.uv.sources]
 skyrl-gym = { path = "./skyrl-gym" , editable = true }
+torch = { index = "pytorch-cu128" }
+torchvision = { index = "pytorch-cu128" }

 [project.optional-dependencies]
 deepspeed = [
@@ -72,20 +74,28 @@ docs = [
     "sphinx-autobuild>=2021.3.14"
 ]
 vllm = [
-    "vllm==0.8.5",
+    "vllm==0.9.2",
     # NOTE (sumanthrh): We explicitly use a flashinfer wheel from their index.
     # The wheels on PyPI don't come with pre-compiled kernels and the package will JIT compile them at runtime (terribly slow).
-    "flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
+    "flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
+    "torch==2.7.0",
+    "torchvision"
 ]
 sglang = [
-    "sglang[srt,openai]==0.4.6.post4",
-    "torch-memory-saver>=0.0.5",
+    "sglang[srt,openai,torch_memory_saver]==0.4.8.post1",
-    # The version is pinned to 0.2.5 because sglang requires this
+    # NOTE (sumanthrh): This can be made a common dependency, but then different inference engines can pin different compatible flashinfer versions and it might quickly break.
-    "flashinfer-python@https://github.com/flashinfer-ai/flashinfer/releases/download/v0.2.5/flashinfer_python-0.2.5+cu124torch2.6-cp38-abi3-linux_x86_64.whl#sha256=43d767b912c0c43a04be99595e0123eab9385fc72530a2874b5fb08e3145c0be",
+    "flashinfer-python@https://download.pytorch.org/whl/cu128/flashinfer/flashinfer_python-0.2.6.post1%2Bcu128torch2.7-cp39-abi3-linux_x86_64.whl",
+    "torch==2.7.1",
+    "torchvision",
 ]

+[[tool.uv.index]]
+name = "pytorch-cu128"
+url = "https://download.pytorch.org/whl/cu128"
+explicit = true

 [tool.setuptools.packages.find]
 include = ["skyrl_train*"]
```

**Review comment:** Added explicit torch and torchvision dependencies. The torch versions need not be specified, but it's helpful to know what we're dealing with, and it does not hurt: the main reason is being able to use torch compiled for CUDA 12.8 (by default, the version used is 12.6).
```diff
@@ -172,7 +172,7 @@ generator:
 # whether to use a conversation based format for multi-turn generations
 # if false, append multi-turn model responses and env observations to the original assistant response
 # if true, each multi-turn model response and env observations is stored in a separate assistant/user message respectively
-use_conversation_multi_turn: false
+use_conversation_multi_turn: true

 # sampling params for evaluation
 eval_sampling_params:
```

**Review comment:** sglang also has some issues with this flag. I will make sure to add this caveat with sglang to the docs as well.

**Review comment:** I am tracking this and other sglang feature support tasks here: #82. If there's more to say about this issue and what you observed, could you dump it into the issue?

**Review comment:** Yes, sounds good. Let me dive into this separately. The main issue is probably due to some difference in the implementation.

**Review comment:** The empty response comes up due to our defaults (which might have to change as well).
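The two multi-turn layouts described in the config comment can be sketched as follows. This is an illustrative sketch only: the function and field names below are hypothetical, not SkyRL's actual code.

```python
def build_history(turns, use_conversation_multi_turn):
    """Sketch of the two multi-turn chat layouts.

    turns: list of (model_response, env_observation) pairs.
    """
    if use_conversation_multi_turn:
        # Each model response / env observation becomes its own
        # assistant / user message.
        messages = []
        for response, observation in turns:
            messages.append({"role": "assistant", "content": response})
            messages.append({"role": "user", "content": observation})
        return messages
    # Otherwise, everything is appended into one assistant response.
    merged = "".join(response + observation for response, observation in turns)
    return [{"role": "assistant", "content": merged}]


turns = [("SELECT 1;", "result: 1"), ("SELECT 2;", "result: 2")]
print(len(build_history(turns, True)))   # → 4
print(len(build_history(turns, False)))  # → 1
```

With the flag on, a 2-turn rollout yields four messages; with it off, a single consolidated assistant message.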
```diff
@@ -5,7 +5,6 @@
     InferenceEngineOutput,
     NamedWeightUpdateRequest,
 )
-from skyrl_train.utils import torch_dtype_to_str
 from typing import List, Optional, Dict, Any
 import json
 import asyncio
@@ -107,7 +106,6 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest):
         raise ValueError(
             "Remote inference engines do not support CUDA IPC weight updates. Only local engines support IPC."
         )
-
     if self.engine_backend == "vllm":
         weight_update_method = "update_weight"
     elif self.engine_backend == "sglang":
@@ -120,7 +118,7 @@ async def update_named_weight(self, request: NamedWeightUpdateRequest):
         f"{self.url}/{weight_update_method}",
         json={
             "name": request["name"],
-            "dtype": torch_dtype_to_str(request["dtype"]),
+            "dtype": request["dtype"],
             "shape": request["shape"],
         },
     )
```

**Review comment:** Minor edit here: since the inference engines expect this to be a string anyway, I have changed the datatype to string for consistency.
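After this change, the weight-update request body is directly JSON-serializable, since `dtype` now travels as a plain string end to end. A minimal sketch of the payload (the field names mirror the diff; the concrete values are illustrative):

```python
import json

# Illustrative NamedWeightUpdateRequest-style payload: "dtype" is already a
# string, so no torch_dtype_to_str conversion is needed before serialization.
request = {
    "name": "model.layers.0.self_attn.q_proj.weight",  # example weight name
    "dtype": "bfloat16",
    "shape": [2048, 2048],
}

payload = json.dumps(
    {"name": request["name"], "dtype": request["dtype"], "shape": request["shape"]}
)
print(payload)
```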
This file was deleted.
```diff
@@ -1,5 +1,5 @@
 import os
-from typing import List, Any
+from typing import List, Any, Dict
 import ray
 import torch
 import asyncio
@@ -16,6 +16,7 @@
     InferenceEngineOutput,
     NamedWeightUpdateRequest,
 )
+from skyrl_train.utils import str_to_torch_dtype


 def setup_envvars_for_vllm(kwargs, bundle_indices):
@@ -86,10 +87,9 @@ def init_weight_update_communicator(
         f"rank={rank}, world_size={world_size}, group_name={group_name}",
     )

-    def update_weight(self, name, dtype, shape):
-        import torch
-
+    def update_weight(self, name: str, dtype: str, shape: List[int]):
         """Broadcast weight to all vllm workers from source rank 0 (actor model)"""
+        dtype = str_to_torch_dtype(dtype)
         assert dtype == self.model_config.dtype, f"mismatch dtype: src {dtype}, dst {self.model_config.dtype}"
         weight = torch.empty(shape, dtype=dtype, device="cuda")
         torch.distributed.broadcast(weight, 0, group=self._model_update_group)
@@ -98,9 +98,9 @@ def update_weight(self, name, dtype, shape):

         del weight

-    def update_weight_cuda_ipc(self, name, dtype, shape, ipc_handles=None):
-        import torch
+    def update_weight_cuda_ipc(self, name: str, dtype: str, shape: List[int], ipc_handles: Dict[str, Any]):
+        dtype = str_to_torch_dtype(dtype)
         device = torch.cuda.current_device()
         props = torch.cuda.get_device_properties(device)
         physical_gpu_id = str(props.uuid)
```

**Review comment:** This is a vllm 0.9.2 change AFAIK. Previously we passed the string directly here, but now it needs to be converted.
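A minimal sketch of what `str_to_torch_dtype` might look like (the real helper in `skyrl_train.utils` may differ): resolve the name as an attribute of the `torch` module, so weight-update requests can carry dtypes as plain strings.

```python
import torch

def str_to_torch_dtype(name: str) -> torch.dtype:
    """Map a dtype name like "bfloat16" to the matching torch.dtype.

    Hypothetical sketch; the actual skyrl_train.utils implementation may differ.
    """
    dtype = getattr(torch, name, None)
    if not isinstance(dtype, torch.dtype):
        raise ValueError(f"unsupported dtype string: {name!r}")
    return dtype

print(str_to_torch_dtype("bfloat16"))  # → torch.bfloat16
```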
**Copilot review comment (nitpick):** Multiple `RUN apt-get update` and cleanup layers can be consolidated into a single `RUN` block to reduce image layers and overall size.
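The consolidation Copilot suggests could look roughly like the fragment below. This is a sketch under the assumption that the two apt install steps have no ordering dependency, not the PR's actual Dockerfile; the package list is abbreviated from the diff.

```dockerfile
# Sketch only: a single apt-get update/install/clean sequence in one RUN
# layer, so stale package lists never persist in an intermediate layer.
RUN sudo apt-get update \
    && sudo apt-get install -y --no-install-recommends \
        openssh-server iputils-ping net-tools iproute2 traceroute netcat \
        libopenexr-dev libxi-dev libglfw3-dev libglew-dev libomp-dev \
        libxinerama-dev libxcursor-dev tzdata \
        default-jre-headless openjdk-8-jdk \
    && sudo apt-get clean \
    && sudo rm -rf /var/lib/apt/lists/*
```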