-import torch
-from transformers.debug_utils import DebugOption
-from transformers.training_args import OptimizerNames
-from transformers.trainer_utils import (
-    FSDPOption,
-    HubStrategy,
-    IntervalStrategy,
-    SaveStrategy,
-    SchedulerType,
-)
+from enum import Enum
 from typing_extensions import TypedDict

 from .engine import EngineArgs
 from .torchtune import TorchtuneArgs


-def get_model_config(
-    base_model: str,
-    output_dir: str,
-    config: "InternalModelConfig | None",
-) -> "InternalModelConfig":
-    from ..local.checkpoints import get_last_checkpoint_dir
+# Vendored from transformers.training_args.OptimizerNames
+class OptimizerNames(str, Enum):
+    """
+    Stores the acceptable string identifiers for optimizers.
+    """
+
+    ADAMW_TORCH = "adamw_torch"
+    ADAMW_TORCH_FUSED = "adamw_torch_fused"
+    ADAMW_TORCH_XLA = "adamw_torch_xla"
+    ADAMW_TORCH_NPU_FUSED = "adamw_torch_npu_fused"
+    ADAMW_APEX_FUSED = "adamw_apex_fused"
+    ADAFACTOR = "adafactor"
+    ADAMW_ANYPRECISION = "adamw_anyprecision"
+    ADAMW_TORCH_4BIT = "adamw_torch_4bit"
+    ADAMW_TORCH_8BIT = "adamw_torch_8bit"
+    ADEMAMIX = "ademamix"
+    SGD = "sgd"
+    ADAGRAD = "adagrad"
+    ADAMW_BNB = "adamw_bnb_8bit"
+    ADAMW_8BIT = "adamw_8bit"  # just an alias for adamw_bnb_8bit
+    ADEMAMIX_8BIT = "ademamix_8bit"
+    LION_8BIT = "lion_8bit"
+    LION = "lion_32bit"
+    PAGED_ADAMW = "paged_adamw_32bit"
+    PAGED_ADAMW_8BIT = "paged_adamw_8bit"
+    PAGED_ADEMAMIX = "paged_ademamix_32bit"
+    PAGED_ADEMAMIX_8BIT = "paged_ademamix_8bit"
+    PAGED_LION = "paged_lion_32bit"
+    PAGED_LION_8BIT = "paged_lion_8bit"
+    RMSPROP = "rmsprop"
+    RMSPROP_BNB = "rmsprop_bnb"
+    RMSPROP_8BIT = "rmsprop_bnb_8bit"
+    RMSPROP_32BIT = "rmsprop_bnb_32bit"
+    GALORE_ADAMW = "galore_adamw"
+    GALORE_ADAMW_8BIT = "galore_adamw_8bit"
+    GALORE_ADAFACTOR = "galore_adafactor"
+    GALORE_ADAMW_LAYERWISE = "galore_adamw_layerwise"
+    GALORE_ADAMW_8BIT_LAYERWISE = "galore_adamw_8bit_layerwise"
+    GALORE_ADAFACTOR_LAYERWISE = "galore_adafactor_layerwise"
+    LOMO = "lomo"
+    ADALOMO = "adalomo"
+    GROKADAMW = "grokadamw"
+    SCHEDULE_FREE_RADAM = "schedule_free_radam"
+    SCHEDULE_FREE_ADAMW = "schedule_free_adamw"
+    SCHEDULE_FREE_SGD = "schedule_free_sgd"
+    APOLLO_ADAMW = "apollo_adamw"
+    APOLLO_ADAMW_LAYERWISE = "apollo_adamw_layerwise"
+
+
+# Vendored from transformers.debug_utils.DebugOption
+class DebugOption(str, Enum):
+    UNDERFLOW_OVERFLOW = "underflow_overflow"
+    TPU_METRICS_DEBUG = "tpu_metrics_debug"
+
+
+# Vendored from transformers.trainer_utils.IntervalStrategy
+class IntervalStrategy(str, Enum):
+    NO = "no"
+    STEPS = "steps"
+    EPOCH = "epoch"
+
+
+# Vendored from transformers.trainer_utils.SaveStrategy (which is an alias for IntervalStrategy)
+SaveStrategy = IntervalStrategy
+
+
+# Vendored from transformers.trainer_utils.HubStrategy
+class HubStrategy(str, Enum):
+    END = "end"
+    EVERY_SAVE = "every_save"
+    CHECKPOINT = "checkpoint"
+    ALL_CHECKPOINTS = "all_checkpoints"
+
+
+# Vendored from transformers.trainer_utils.SchedulerType
+class SchedulerType(str, Enum):
+    LINEAR = "linear"
+    COSINE = "cosine"
+    COSINE_WITH_RESTARTS = "cosine_with_restarts"
+    POLYNOMIAL = "polynomial"
+    CONSTANT = "constant"
+    CONSTANT_WITH_WARMUP = "constant_with_warmup"
+    INVERSE_SQRT = "inverse_sqrt"
+    REDUCE_ON_PLATEAU = "reduce_lr_on_plateau"
+    COSINE_WITH_MIN_LR = "cosine_with_min_lr"
+    WARMUP_STABLE_DECAY = "warmup_stable_decay"
+

-    if config is None:
-        config = InternalModelConfig()
-    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
-    init_args = InitArgs(
-        model_name=base_model,
-        max_seq_length=32768,
-        load_in_4bit=True,  # False for LoRA 16bit
-        fast_inference=True,  # Enable vLLM fast inference
-        # vLLM args
-        disable_log_stats=False,
-        enable_prefix_caching=True,
-        gpu_memory_utilization=(
-            0.79 if enable_sleep_mode else 0.55
-        ),  # Reduce if out of memory
-        max_lora_rank=8,
-        use_async=True,
-    )
-    if config.get("_decouple_vllm_and_unsloth", False):
-        init_args["fast_inference"] = False
-        init_args.pop("disable_log_stats")
-        init_args.pop("enable_prefix_caching")
-        init_args.pop("gpu_memory_utilization")
-        init_args.pop("max_lora_rank")
-        init_args.pop("use_async")
-    engine_args = EngineArgs(
-        disable_log_requests=True,
-        # Multi-step processing is not supported for the Xformers attention backend
-        # which is the fallback for devices with compute capability < 8.0
-        num_scheduler_steps=(
-            16
-            if config.get("torchtune_args") is None
-            and not config.get("_decouple_vllm_and_unsloth", False)
-            and torch.cuda.get_device_capability()[0] >= 8
-            else 1
-        ),
-        enable_sleep_mode=enable_sleep_mode,
-        generation_config="vllm",
-    )
-    engine_args.update(config.get("engine_args", {}))
-    init_args.update(config.get("init_args", {}))
-    if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
-        init_args["model_name"] = last_checkpoint_dir
-        if config.get("torchtune_args") is not None:
-            engine_args["model"] = last_checkpoint_dir
-    elif config.get("torchtune_args") is not None:
-        engine_args["model"] = base_model
-    if config.get("_decouple_vllm_and_unsloth", False):
-        engine_args["model"] = base_model
-    peft_args = PeftArgs(
-        r=8,  # Choose any number > 0 ! Suggested 8, 16, 32, 64, 128
-        target_modules=[
-            "q_proj",
-            "k_proj",
-            "v_proj",
-            "o_proj",
-            "gate_proj",
-            "up_proj",
-            "down_proj",
-        ],  # Remove QKVO if out of memory
-        lora_alpha=16,
-        # Enable long context finetuning
-        use_gradient_checkpointing="unsloth",  # type: ignore
-        random_state=3407,
-    )
-    peft_args.update(config.get("peft_args", {}))
-    trainer_args = TrainerArgs(
-        learning_rate=5e-6,
-        adam_beta1=0.9,
-        adam_beta2=0.99,
-        weight_decay=0.1,
-        lr_scheduler_type="constant",
-        optim="paged_adamw_8bit",
-        logging_steps=1,
-        per_device_train_batch_size=2,
-        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
-        num_generations=2,  # Decrease if out of memory
-        max_grad_norm=0.1,
-        save_strategy="no",
-        output_dir=output_dir,
-        disable_tqdm=True,
-        report_to="none",
-    )
-    trainer_args.update(config.get("trainer_args", {}))
-    if config.get("torchtune_args") is not None:
-        torchtune_args = TorchtuneArgs(model="qwen3_32b", model_type="QWEN3")
-        torchtune_args.update(config.get("torchtune_args", {}) or {})
-    else:
-        torchtune_args = None
-    return InternalModelConfig(
-        init_args=init_args,
-        engine_args=engine_args,
-        peft_args=peft_args,
-        trainer_args=trainer_args,
-        torchtune_args=torchtune_args,
-        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
-    )
+# Vendored from transformers.trainer_utils.FSDPOption
+class FSDPOption(str, Enum):
+    FULL_SHARD = "full_shard"
+    SHARD_GRAD_OP = "shard_grad_op"
+    NO_SHARD = "no_shard"
+    HYBRID_SHARD = "hybrid_shard"
+    HYBRID_SHARD_ZERO2 = "hybrid_shard_zero2"
+    OFFLOAD = "offload"
+    AUTO_WRAP = "auto_wrap"


 class InternalModelConfig(TypedDict, total=False):
@@ -263,7 +249,7 @@ class TrainerArgs(TypedDict, total=False):
     accelerator_config: dict | str | None
     deepspeed: dict | str | None
     label_smoothing_factor: float
-    optim: "OptimizerNames | str"
+    optim: OptimizerNames | str
     optim_args: str | None
     adafactor: bool
     group_by_length: bool
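
Because the vendored classes subclass both str and Enum, members compare equal to their plain string identifiers, so fields such as optim: OptimizerNames | str accept either form. A minimal sketch of that behavior (illustrative only, not part of the commit):

# Illustrative checks of the str-Enum behavior the vendored classes rely on.
assert OptimizerNames.PAGED_ADAMW_8BIT == "paged_adamw_8bit"  # str equality holds
assert SchedulerType("cosine") is SchedulerType.COSINE  # lookup by value
assert SaveStrategy.STEPS is IntervalStrategy.STEPS  # SaveStrategy is an alias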