
Commit 9339fec

corbt and claude committed
Fix client import error by vendoring transformers constants
This fixes issue #230, where importing art.Model on the client side failed due to missing transformers dependencies. The issue was introduced in PR #194, which moved backend dependencies to an optional group but didn't account for art.dev importing transformers at the module level.

The fix vendors all transformers enum constants into art/dev/model.py with source comments, removing the hard dependency on transformers for client imports while maintaining full compatibility.

🤖 Generated with [Claude Code](https://claude.ai/code)

Co-Authored-By: Claude <[email protected]>
1 parent 786900a commit 9339fec
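
The compatibility claim rests on a Python detail worth spelling out: members of an enum that mixes in `str` compare equal to their raw string values, so call sites that passed either the transformers enum or a plain string behave identically with the vendored copies. A minimal standalone sketch (not part of the commit) of that pattern:

```python
# Standalone sketch of the vendoring pattern: a (str, Enum) member is
# interchangeable with its plain string value, so vendored constants stay
# drop-in compatible with the transformers originals.
from enum import Enum


class OptimizerNames(str, Enum):  # same shape as the vendored class below
    PAGED_ADAMW_8BIT = "paged_adamw_8bit"


# Equality with the raw string holds because str.__eq__ is inherited.
assert OptimizerNames.PAGED_ADAMW_8BIT == "paged_adamw_8bit"
# Lookup by value also works, as it does for the transformers enum.
assert OptimizerNames("paged_adamw_8bit") is OptimizerNames.PAGED_ADAMW_8BIT
```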

4 files changed: +212 -116 lines changed


src/art/dev/__init__.py

Lines changed: 0 additions & 2 deletions
```diff
@@ -1,6 +1,5 @@
 from .engine import EngineArgs
 from .model import (
-    get_model_config,
     InternalModelConfig,
     InitArgs,
     PeftArgs,
@@ -12,7 +11,6 @@
 
 __all__ = [
     "EngineArgs",
-    "get_model_config",
     "InternalModelConfig",
     "InitArgs",
     "PeftArgs",
```

src/art/dev/get_model_config.py

Lines changed: 111 additions & 0 deletions
```diff
@@ -0,0 +1,111 @@
+import torch
+
+from .engine import EngineArgs
+from .model import InitArgs, PeftArgs, TrainerArgs
+from .torchtune import TorchtuneArgs
+from .model import InternalModelConfig
+
+
+def get_model_config(
+    base_model: str,
+    output_dir: str,
+    config: "InternalModelConfig | None",
+) -> "InternalModelConfig":
+    from ..local.checkpoints import get_last_checkpoint_dir
+
+    if config is None:
+        config = InternalModelConfig()
+    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
+    init_args = InitArgs(
+        model_name=base_model,
+        max_seq_length=32768,
+        load_in_4bit=True,  # False for LoRA 16bit
+        fast_inference=True,  # Enable vLLM fast inference
+        # vLLM args
+        disable_log_stats=False,
+        enable_prefix_caching=True,
+        gpu_memory_utilization=(
+            0.79 if enable_sleep_mode else 0.55
+        ),  # Reduce if out of memory
+        max_lora_rank=8,
+        use_async=True,
+    )
+    if config.get("_decouple_vllm_and_unsloth", False):
+        init_args["fast_inference"] = False
+        init_args.pop("disable_log_stats")
+        init_args.pop("enable_prefix_caching")
+        init_args.pop("gpu_memory_utilization")
+        init_args.pop("max_lora_rank")
+        init_args.pop("use_async")
+    engine_args = EngineArgs(
+        disable_log_requests=True,
+        # Multi-step processing is not supported for the Xformers attention backend
+        # which is the fallback for devices with compute capability < 8.0
+        num_scheduler_steps=(
+            16
+            if config.get("torchtune_args") is None
+            and not config.get("_decouple_vllm_and_unsloth", False)
+            and torch.cuda.get_device_capability()[0] >= 8
+            else 1
+        ),
+        enable_sleep_mode=enable_sleep_mode,
+        generation_config="vllm",
+    )
+    engine_args.update(config.get("engine_args", {}))
+    init_args.update(config.get("init_args", {}))
+    if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
+        init_args["model_name"] = last_checkpoint_dir
+        if config.get("torchtune_args") is not None:
+            engine_args["model"] = last_checkpoint_dir
+    elif config.get("torchtune_args") is not None:
+        engine_args["model"] = base_model
+    if config.get("_decouple_vllm_and_unsloth", False):
+        engine_args["model"] = base_model
+    peft_args = PeftArgs(
+        r=8,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
+        target_modules=[
+            "q_proj",
+            "k_proj",
+            "v_proj",
+            "o_proj",
+            "gate_proj",
+            "up_proj",
+            "down_proj",
+        ],  # Remove QKVO if out of memory
+        lora_alpha=16,
+        # Enable long context finetuning
+        use_gradient_checkpointing="unsloth",  # type: ignore
+        random_state=3407,
+    )
+    peft_args.update(config.get("peft_args", {}))
+    trainer_args = TrainerArgs(
+        learning_rate=5e-6,
+        adam_beta1=0.9,
+        adam_beta2=0.99,
+        weight_decay=0.1,
+        lr_scheduler_type="constant",
+        optim="paged_adamw_8bit",
+        logging_steps=1,
+        per_device_train_batch_size=2,
+        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
+        num_generations=2,  # Decrease if out of memory
+        max_grad_norm=0.1,
+        save_strategy="no",
+        output_dir=output_dir,
+        disable_tqdm=True,
+        report_to="none",
+    )
+    trainer_args.update(config.get("trainer_args", {}))
+    if config.get("torchtune_args") is not None:
+        torchtune_args = TorchtuneArgs(model="qwen3_32b", model_type="QWEN3")
+        torchtune_args.update(config.get("torchtune_args", {}) or {})
+    else:
+        torchtune_args = None
+    return InternalModelConfig(
+        init_args=init_args,
+        engine_args=engine_args,
+        peft_args=peft_args,
+        trainer_args=trainer_args,
+        torchtune_args=torchtune_args,
+        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
+    )
```
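
A usage sketch for the relocated function, not taken from the commit: the caller-supplied config is merged over the defaults via `dict.update`, so only overridden keys change. Assumes a CUDA device (the default `num_scheduler_steps` branch queries device capability), no prior checkpoints under the output dir, and hypothetical model id and path:

```python
# Hypothetical usage sketch: override one trainer default, keep the rest.
from art.dev.get_model_config import get_model_config
from art.dev.model import InternalModelConfig

config = get_model_config(
    base_model="Qwen/Qwen2.5-7B-Instruct",  # hypothetical base model
    output_dir="/tmp/art-demo",  # hypothetical output directory
    config=InternalModelConfig(
        trainer_args={"learning_rate": 1e-5},  # merged over the 5e-6 default
    ),
)
assert config["trainer_args"]["learning_rate"] == 1e-5  # override applied
assert config["trainer_args"]["optim"] == "paged_adamw_8bit"  # default kept
```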

src/art/dev/model.py

Lines changed: 99 additions & 113 deletions
```diff
@@ -1,122 +1,108 @@
-import torch
-from transformers.debug_utils import DebugOption
-from transformers.training_args import OptimizerNames
-from transformers.trainer_utils import (
-    FSDPOption,
-    HubStrategy,
-    IntervalStrategy,
-    SaveStrategy,
-    SchedulerType,
-)
+from enum import Enum
 from typing_extensions import TypedDict
 
 from .engine import EngineArgs
 from .torchtune import TorchtuneArgs
 
 
-def get_model_config(
-    base_model: str,
-    output_dir: str,
-    config: "InternalModelConfig | None",
-) -> "InternalModelConfig":
-    from ..local.checkpoints import get_last_checkpoint_dir
+# Vendored from transformers.training_args.OptimizerNames
+class OptimizerNames(str, Enum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
+    ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+    ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
+    ADEMAMIX = "ademamix"
+    SGD = "sgd"
+    ADAGRAD = "adagrad"
+    ADAMW_BNB = "adamw_bnb_8bit"
+    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    ADEMAMIX_8BIT = "ademamix_8bit"
+    LION_8BIT = "lion_8bit"
+    LION = "lion_32bit"
+    PAGED_ADAMW = "paged_adamw_32bit"
+    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_ADEMAMIX = "paged_ademamix_32bit"
+    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
+    PAGED_LION = "paged_lion_32bit"
+    PAGED_LION_8BIT = "paged_lion_8bit"
+    RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"
+    GALORE_ADAMW = "galore_adamw"
+    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+    GALORE_ADAFACTOR = "galore_adafactor"
+    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    LOMO = "lomo"
+    ADALOMO = "adalomo"
+    GROKADAMW = "grokadamw"
+    SCHEDULE_FREE_RADAM = "schedule_free_radam"
+    SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
+    SCHEDULE_FREE_SGD = "schedule_free_sgd"
+    APOLLO_ADAMW = "apollo_adamw"
+    APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
+
+
+# Vendored from transformers.debug_utils.DebugOption
+class DebugOption(str, Enum):
+    UNDERFLOW_OVERFLOW = "underflow_overflow"
+    TPU_METRICS_DEBUG = "tpu_metrics_debug"
+
+
+# Vendored from transformers.trainer_utils.IntervalStrategy
+class IntervalStrategy(str, Enum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+
+
+# Vendored from transformers.trainer_utils.SaveStrategy (which is an alias for IntervalStrategy)
+SaveStrategy = IntervalStrategy
+
+
+# Vendored from transformers.trainer_utils.HubStrategy
+class HubStrategy(str, Enum):
+    END = "end"
+    EVERY_SAVE = "every_save"
+    CHECKPOINT = "checkpoint"
+    ALL_CHECKPOINTS = "all_checkpoints"
+
+
+# Vendored from transformers.trainer_utils.SchedulerType
+class SchedulerType(str, Enum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
+    WARMUP_STABLE_DECAY = "warmup_stable_decay"
+    WORMHOLE = "wormhole"
+
 
-    if config is None:
-        config = InternalModelConfig()
-    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
-    init_args = InitArgs(
-        model_name=base_model,
-        max_seq_length=32768,
-        load_in_4bit=True,  # False for LoRA 16bit
-        fast_inference=True,  # Enable vLLM fast inference
-        # vLLM args
-        disable_log_stats=False,
-        enable_prefix_caching=True,
-        gpu_memory_utilization=(
-            0.79 if enable_sleep_mode else 0.55
-        ),  # Reduce if out of memory
-        max_lora_rank=8,
-        use_async=True,
-    )
-    if config.get("_decouple_vllm_and_unsloth", False):
-        init_args["fast_inference"] = False
-        init_args.pop("disable_log_stats")
-        init_args.pop("enable_prefix_caching")
-        init_args.pop("gpu_memory_utilization")
-        init_args.pop("max_lora_rank")
-        init_args.pop("use_async")
-    engine_args = EngineArgs(
-        disable_log_requests=True,
-        # Multi-step processing is not supported for the Xformers attention backend
-        # which is the fallback for devices with compute capability < 8.0
-        num_scheduler_steps=(
-            16
-            if config.get("torchtune_args") is None
-            and not config.get("_decouple_vllm_and_unsloth", False)
-            and torch.cuda.get_device_capability()[0] >= 8
-            else 1
-        ),
-        enable_sleep_mode=enable_sleep_mode,
-        generation_config="vllm",
-    )
-    engine_args.update(config.get("engine_args", {}))
-    init_args.update(config.get("init_args", {}))
-    if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
-        init_args["model_name"] = last_checkpoint_dir
-        if config.get("torchtune_args") is not None:
-            engine_args["model"] = last_checkpoint_dir
-    elif config.get("torchtune_args") is not None:
-        engine_args["model"] = base_model
-    if config.get("_decouple_vllm_and_unsloth", False):
-        engine_args["model"] = base_model
-    peft_args = PeftArgs(
-        r=8,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-        target_modules=[
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],  # Remove QKVO if out of memory
-        lora_alpha=16,
-        # Enable long context finetuning
-        use_gradient_checkpointing="unsloth",  # type: ignore
-        random_state=3407,
-    )
-    peft_args.update(config.get("peft_args", {}))
-    trainer_args = TrainerArgs(
-        learning_rate=5e-6,
-        adam_beta1=0.9,
-        adam_beta2=0.99,
-        weight_decay=0.1,
-        lr_scheduler_type="constant",
-        optim="paged_adamw_8bit",
-        logging_steps=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
-        num_generations=2,  # Decrease if out of memory
-        max_grad_norm=0.1,
-        save_strategy="no",
-        output_dir=output_dir,
-        disable_tqdm=True,
-        report_to="none",
-    )
-    trainer_args.update(config.get("trainer_args", {}))
-    if config.get("torchtune_args") is not None:
-        torchtune_args = TorchtuneArgs(model="qwen3_32b", model_type="QWEN3")
-        torchtune_args.update(config.get("torchtune_args", {}) or {})
-    else:
-        torchtune_args = None
-    return InternalModelConfig(
-        init_args=init_args,
-        engine_args=engine_args,
-        peft_args=peft_args,
-        trainer_args=trainer_args,
-        torchtune_args=torchtune_args,
-        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
-    )
+# Vendored from transformers.trainer_utils.FSDPOption
+class FSDPOption(str, Enum):
+    FULL_SHARD = "full_shard"
+    SHARD_GRAD_OP = "shard_grad_op"
+    NO_SHARD = "no_shard"
+    HYBRID_SHARD = "hybrid_shard"
+    HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
+    OFFLOAD = "offload"
+    AUTO_WRAP = "auto_wrap"
 
 
 class InternalModelConfig(TypedDict, total=False):
@@ -263,7 +249,7 @@ class TrainerArgs(TypedDict, total=False):
     accelerator_config: dict | str | None
     deepspeed: dict | str | None
     label_smoothing_factor: float
-    optim: "OptimizerNames | str"
+    optim: OptimizerNames | str
     optim_args: str | None
     adafactor: bool
     group_by_length: bool
```
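
A side effect of vendoring: `OptimizerNames` is now defined in the same module as `TrainerArgs`, so the second hunk drops the quotes around the `optim` annotation. And because the vendored enums mix in `str`, both spellings populate the TypedDict identically; a small sketch (not from the commit) under the new module layout:

```python
# Sketch: enum and string spellings of the optim field are equivalent.
from art.dev.model import OptimizerNames, SchedulerType, TrainerArgs

a = TrainerArgs(optim="paged_adamw_8bit")
b = TrainerArgs(optim=OptimizerNames.PAGED_ADAMW_8BIT)
assert a["optim"] == b["optim"]  # str-valued enum equals its raw string

assert SchedulerType.CONSTANT == "constant"  # same property for schedulers
```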

src/art/local/backend.py

Lines changed: 2 additions & 1 deletion
```diff
@@ -113,9 +113,10 @@ async def _get_service(self, model: TrainableModel) -> ModelService:
         from ..torchtune.service import TorchtuneService
         from ..unsloth.service import UnslothService
         from ..unsloth.decoupled_service import DecoupledUnslothService
+        from ..dev.get_model_config import get_model_config
 
         if model.name not in self._services:
-            config = dev.get_model_config(
+            config = get_model_config(
                 base_model=model.base_model,
                 output_dir=get_model_dir(model=model, art_path=self._path),
                 config=model._internal_config,
```
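
The import moves inside `_get_service` alongside the other service imports, keeping backend-only dependencies out of the client import path. A hypothetical regression check for issue #230 (not part of the commit), run in an environment without the backend extras installed:

```python
# Hypothetical regression check for issue #230: a client-side import should
# succeed without transformers installed, and should not pull it in.
import sys

import art  # previously failed via art.dev -> transformers

assert "transformers" not in sys.modules, "client import still loads transformers"
```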
