
Commit 6630b7b

corbt and claude committed
Fix client import error by vendoring transformers constants

This fixes issue #230, where importing art.Model on the client side failed due to missing transformers dependencies. The issue was introduced in PR #194, which moved backend dependencies to an optional group but didn't account for art.dev importing transformers at the module level.

The fix vendors all transformers enum constants into art/dev/model.py with source comments, removing the hard dependency on transformers for client imports while maintaining full compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 786900a commit 6630b7b
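
The art/dev/model.py hunk that does the vendoring isn't shown in this excerpt. As a rough sketch of the pattern the commit message describes, a module-level transformers import is replaced by a local copy annotated with its source; the enum name and members below are illustrative stand-ins, not necessarily the constants this commit vendored:

# Before: a module-level transformers import, which fails on client-only
# installs where the backend dependency group is absent.
# from transformers.trainer_utils import SchedulerType

# After: vendor the enum locally, with a comment recording its source.
from enum import Enum


class SchedulerType(str, Enum):
    # Vendored from transformers.trainer_utils.SchedulerType
    # (illustrative subset; the real enum has more members).
    LINEAR = "linear"
    COSINE = "cosine"
    CONSTANT = "constant"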

File tree

14 files changed: +1913 −146 lines changed

dev/profile.ipynb

Lines changed: 2 additions & 2 deletions

@@ -16,7 +16,7 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "from art.dev.model import get_model_config\n",
+    "from art.dev.get_model_config import get_model_config\n",
    "from art.local.state import ModelState\n",
    "\n",
    "config = get_model_config(\n",
@@ -188,4 +188,4 @@
  },
  "nbformat": 4,
  "nbformat_minor": 2
-}
+}

src/art/cli.py

Lines changed: 9 additions & 0 deletions

@@ -1,5 +1,7 @@
 from fastapi import FastAPI, Body
 from fastapi.responses import StreamingResponse
+from fastapi import Request
+from fastapi.responses import JSONResponse
 import json
 import pydantic
 import socket
@@ -13,6 +15,7 @@
 from .trajectories import TrajectoryGroup
 from .types import TrainConfig
 from .utils.deploy_model import LoRADeploymentProvider
+from .errors import ARTError

 app = typer.Typer()

@@ -44,6 +47,12 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:

 backend = LocalBackend()
 app = FastAPI()
+
+# Add exception handler for ARTError
+@app.exception_handler(ARTError)
+async def art_error_handler(request: Request, exc: ARTError):
+    return JSONResponse(status_code=exc.status_code, content={"detail": exc.detail})
+
 app.get("/healthcheck")(lambda: {"status": "ok"})
 app.post("/close")(backend.close)
 app.post("/register")(backend.register)
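
The handler registered above assumes ARTError exposes status_code and detail attributes; src/art/errors.py isn't included in this excerpt, so the following is a minimal sketch of a compatible definition, not the commit's actual code:

# Hypothetical shape of ARTError, inferred from how the handler reads it.
class ARTError(Exception):
    def __init__(self, status_code: int, detail: str) -> None:
        super().__init__(detail)
        self.status_code = status_code  # becomes the HTTP status code
        self.detail = detail  # serialized as {"detail": ...}

With this in place, an endpoint that raises ARTError(404, "model not found") would return a 404 response with body {"detail": "model not found"} instead of a generic 500.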

src/art/dev/__init__.py

Lines changed: 0 additions & 2 deletions

@@ -1,6 +1,5 @@
 from .engine import EngineArgs
 from .model import (
-    get_model_config,
     InternalModelConfig,
     InitArgs,
     PeftArgs,
@@ -12,7 +11,6 @@

 __all__ = [
     "EngineArgs",
-    "get_model_config",
     "InternalModelConfig",
     "InitArgs",
     "PeftArgs",

src/art/dev/get_model_config.py

Lines changed: 111 additions & 0 deletions

@@ -0,0 +1,111 @@
+import torch
+
+from .engine import EngineArgs
+from .model import InitArgs, PeftArgs, TrainerArgs
+from .torchtune import TorchtuneArgs
+from .model import InternalModelConfig
+
+
+def get_model_config(
+    base_model: str,
+    output_dir: str,
+    config: "InternalModelConfig | None",
+) -> "InternalModelConfig":
+    from ..local.checkpoints import get_last_checkpoint_dir
+
+    if config is None:
+        config = InternalModelConfig()
+    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
+    init_args = InitArgs(
+        model_name=base_model,
+        max_seq_length=32768,
+        load_in_4bit=True,  # False for LoRA 16bit
+        fast_inference=True,  # Enable vLLM fast inference
+        # vLLM args
+        disable_log_stats=False,
+        enable_prefix_caching=True,
+        gpu_memory_utilization=(
+            0.79 if enable_sleep_mode else 0.55
+        ),  # Reduce if out of memory
+        max_lora_rank=8,
+        use_async=True,
+    )
+    if config.get("_decouple_vllm_and_unsloth", False):
+        init_args["fast_inference"] = False
+        init_args.pop("disable_log_stats")
+        init_args.pop("enable_prefix_caching")
+        init_args.pop("gpu_memory_utilization")
+        init_args.pop("max_lora_rank")
+        init_args.pop("use_async")
+    engine_args = EngineArgs(
+        disable_log_requests=True,
+        # Multi-step processing is not supported for the Xformers attention backend,
+        # which is the fallback for devices with compute capability < 8.0
+        num_scheduler_steps=(
+            16
+            if config.get("torchtune_args") is None
+            and not config.get("_decouple_vllm_and_unsloth", False)
+            and torch.cuda.get_device_capability()[0] >= 8
+            else 1
+        ),
+        enable_sleep_mode=enable_sleep_mode,
+        generation_config="vllm",
+    )
+    engine_args.update(config.get("engine_args", {}))
+    init_args.update(config.get("init_args", {}))
+    if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
+        init_args["model_name"] = last_checkpoint_dir
+        if config.get("torchtune_args") is not None:
+            engine_args["model"] = last_checkpoint_dir
+    elif config.get("torchtune_args") is not None:
+        engine_args["model"] = base_model
+    if config.get("_decouple_vllm_and_unsloth", False):
+        engine_args["model"] = base_model
+    peft_args = PeftArgs(
+        r=8,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],  # Remove QKVO if out of memory
+        lora_alpha=16,
+        # Enable long context finetuning
+        use_gradient_checkpointing="unsloth",  # type: ignore
+        random_state=3407,
+    )
+    peft_args.update(config.get("peft_args", {}))
+    trainer_args = TrainerArgs(
+        learning_rate=5e-6,
+        adam_beta1=0.9,
+        adam_beta2=0.99,
+        weight_decay=0.1,
+        lr_scheduler_type="constant",
+        optim="paged_adamw_8bit",
+        logging_steps=1,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
+        num_generations=2,  # Decrease if out of memory
+        max_grad_norm=0.1,
+        save_strategy="no",
+        output_dir=output_dir,
+        disable_tqdm=True,
+        report_to="none",
+    )
+    trainer_args.update(config.get("trainer_args", {}))
+    if config.get("torchtune_args") is not None:
+        torchtune_args = TorchtuneArgs(model="qwen3_32b", model_type="QWEN3")
+        torchtune_args.update(config.get("torchtune_args", {}) or {})
+    else:
+        torchtune_args = None
+    return InternalModelConfig(
+        init_args=init_args,
+        engine_args=engine_args,
+        peft_args=peft_args,
+        trainer_args=trainer_args,
+        torchtune_args=torchtune_args,
+        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
+    )
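
For reference, a minimal sketch of calling the relocated helper from backend code; the model id and output directory below are illustrative, not taken from the commit:

from art.dev.get_model_config import get_model_config

# Illustrative arguments; any HF model id and writable directory work.
# Note: this call requires a CUDA device, since num_scheduler_steps is
# chosen via torch.cuda.get_device_capability().
config = get_model_config(
    base_model="Qwen/Qwen2.5-7B-Instruct",
    output_dir="./.art/models/my-model",
    config=None,  # use the defaults constructed inside the helper
)
print(config["init_args"]["model_name"])  # base model, or last checkpoint if any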
