Commit d637626

feat: Add experimental vLLM & Unsloth decoupling support
1 parent 3b75c10 commit d637626

File tree

7 files changed: +365 -19 lines changed


dev/yes-no-maybe.ipynb

Lines changed: 5 additions & 1 deletion
@@ -2,7 +2,7 @@
  "cells": [
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -46,6 +46,10 @@
     "    name=\"001\",\n",
     "    project=\"yes-no-maybe\",\n",
     "    base_model=\"Qwen/Qwen2.5-7B-Instruct\",\n",
+    "    _internal_config=art.dev.InternalModelConfig(\n",
+    "        _decouple_vllm_and_unsloth=True,\n",
+    "        engine_args=art.dev.EngineArgs(gpu_memory_utilization=0.7),\n",
+    "    ),\n",
     ")\n",
     "await model.register(backend)\n",
     "\n",

src/art/__init__.py

Lines changed: 7 additions & 0 deletions
@@ -9,6 +9,13 @@
 if os.environ.get("IMPORT_UNSLOTH", "0") == "1":
     import unsloth  # type: ignore # noqa: F401

+if os.environ.get("IMPORT_PEFT", "0") == "1":
+    # torch.cuda.MemPool doesn't currently support expandable_segments which is used in sleep mode
+    conf = os.environ["PYTORCH_CUDA_ALLOC_CONF"].split(",")
+    if "expandable_segments:True" in conf:
+        conf.remove("expandable_segments:True")
+        os.environ["PYTORCH_CUDA_ALLOC_CONF"] = ",".join(conf)
+
 from . import dev
 from .backend import Backend
 from .batches import trajectory_group_batches
src/art/dev/model.py

Lines changed: 12 additions & 0 deletions
@@ -38,13 +38,21 @@ def get_model_config(
         max_lora_rank=8,
         use_async=True,
     )
+    if config.get("_decouple_vllm_and_unsloth", False):
+        init_args["fast_inference"] = False
+        init_args.pop("disable_log_stats")
+        init_args.pop("enable_prefix_caching")
+        init_args.pop("gpu_memory_utilization")
+        init_args.pop("max_lora_rank")
+        init_args.pop("use_async")
     engine_args = EngineArgs(
         disable_log_requests=True,
         # Multi-step processing is not supported for the Xformers attention backend
         # which is the fallback for devices with compute capability < 8.0
         num_scheduler_steps=(
             16
             if config.get("torchtune_args") is None
+            and not config.get("_decouple_vllm_and_unsloth", False)
             and torch.cuda.get_device_capability()[0] >= 8
             else 1
         ),
@@ -59,6 +67,8 @@ def get_model_config(
         engine_args["model"] = last_checkpoint_dir
     elif config.get("torchtune_args") is not None:
         engine_args["model"] = base_model
+    if config.get("_decouple_vllm_and_unsloth", False):
+        engine_args["model"] = base_model
     peft_args = PeftArgs(
         r=8,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
         target_modules=[
@@ -105,6 +115,7 @@ def get_model_config(
         peft_args=peft_args,
         trainer_args=trainer_args,
         torchtune_args=torchtune_args,
+        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
     )


@@ -123,6 +134,7 @@ class InternalModelConfig(TypedDict, total=False):
     peft_args: "PeftArgs"
     trainer_args: "TrainerArgs"
     torchtune_args: TorchtuneArgs | None
+    _decouple_vllm_and_unsloth: bool


 class InitArgs(TypedDict, total=False):
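
The net effect on the Unsloth init arguments is easier to read outside the diff. A minimal sketch using plain dicts instead of the library's InitArgs TypedDict; the key names are taken from the diff, but the starting values below are illustrative, not the library's defaults:

# Illustrative starting values; the real ones are built inside get_model_config.
init_args = {
    "disable_log_stats": False,
    "enable_prefix_caching": True,
    "gpu_memory_utilization": 0.7,
    "max_lora_rank": 8,
    "use_async": True,
}

decouple = True  # stands in for config.get("_decouple_vllm_and_unsloth", False)
if decouple:
    # Unsloth no longer drives vLLM, so its fast_inference path is disabled
    # and the vLLM-specific knobs are stripped from the Unsloth init args.
    init_args["fast_inference"] = False
    for key in (
        "disable_log_stats",
        "enable_prefix_caching",
        "gpu_memory_utilization",
        "max_lora_rank",
        "use_async",
    ):
        init_args.pop(key)

print(init_args)  # {'fast_inference': False}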

src/art/local/backend.py

Lines changed: 11 additions & 7 deletions
@@ -112,29 +112,33 @@ async def register(
     async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..torchtune.service import TorchtuneService
         from ..unsloth.service import UnslothService
+        from ..unsloth.decoupled_service import DecoupledUnslothService

         if model.name not in self._services:
             config = dev.get_model_config(
                 base_model=model.base_model,
                 output_dir=get_model_dir(model=model, art_path=self._path),
                 config=model._internal_config,
             )
-            service_class = (
-                TorchtuneService
-                if config.get("torchtune_args") is not None
-                else UnslothService
-            )
+            if config.get("torchtune_args") is not None:
+                service_class = TorchtuneService
+            elif config.get("_decouple_vllm_and_unsloth", False):
+                service_class = DecoupledUnslothService
+            else:
+                service_class = UnslothService
             self._services[model.name] = service_class(
                 model_name=model.name,
                 base_model=model.base_model,
                 config=config,
                 output_dir=get_model_dir(model=model, art_path=self._path),
             )
-
         if not self._in_process:
             # Kill all "model-service" processes to free up GPU memory
             subprocess.run(["pkill", "-9", "model-service"])
-            if isinstance(self._services[model.name], UnslothService):
+            if isinstance(
+                self._services[model.name],
+                (UnslothService, DecoupledUnslothService),
+            ):
                 # To enable sleep mode, import peft before unsloth
                 # Unsloth will issue warnings, but everything appears to be okay
                 if config.get("engine_args", {}).get("enable_sleep_mode", False):