import torch

from .engine import EngineArgs
from .model import InitArgs, InternalModelConfig, PeftArgs, TrainerArgs
from .torchtune import TorchtuneArgs


def get_model_config(
    base_model: str,
    output_dir: str,
    config: "InternalModelConfig | None",
) -> "InternalModelConfig":
    from ..local.checkpoints import get_last_checkpoint_dir

    if config is None:
        config = InternalModelConfig()
    enable_sleep_mode = config.get("engine_args", {}).get("enable_sleep_mode", True)
    init_args = InitArgs(
        model_name=base_model,
        max_seq_length=32768,
        load_in_4bit=True,  # False for LoRA 16bit
        fast_inference=True,  # Enable vLLM fast inference
        # vLLM args
        disable_log_stats=False,
        enable_prefix_caching=True,
        gpu_memory_utilization=(
            0.79 if enable_sleep_mode else 0.55
        ),  # Reduce if out of memory
        max_lora_rank=8,
        use_async=True,
    )
    if config.get("_decouple_vllm_and_unsloth", False):
        init_args["fast_inference"] = False
        init_args.pop("disable_log_stats")
        init_args.pop("enable_prefix_caching")
        init_args.pop("gpu_memory_utilization")
        init_args.pop("max_lora_rank")
        init_args.pop("use_async")
    engine_args = EngineArgs(
        disable_log_requests=True,
        # Multi-step processing is not supported for the Xformers attention backend,
        # which is the fallback for devices with compute capability < 8.0
        num_scheduler_steps=(
            16
            if config.get("torchtune_args") is None
            and not config.get("_decouple_vllm_and_unsloth", False)
            and torch.cuda.get_device_capability()[0] >= 8
            else 1
        ),
        enable_sleep_mode=enable_sleep_mode,
        generation_config="vllm",
    )
    engine_args.update(config.get("engine_args", {}))
    init_args.update(config.get("init_args", {}))
    if last_checkpoint_dir := get_last_checkpoint_dir(output_dir):
        init_args["model_name"] = last_checkpoint_dir
        if config.get("torchtune_args") is not None:
            engine_args["model"] = last_checkpoint_dir
    elif config.get("torchtune_args") is not None:
        engine_args["model"] = base_model
    if config.get("_decouple_vllm_and_unsloth", False):
        engine_args["model"] = base_model
    peft_args = PeftArgs(
        r=8,  # Choose any number > 0! Suggested 8, 16, 32, 64, 128
        target_modules=[
            "q_proj",
            "k_proj",
            "v_proj",
            "o_proj",
            "gate_proj",
            "up_proj",
            "down_proj",
        ],  # Remove QKVO if out of memory
        lora_alpha=16,
        # Enable long context finetuning
        use_gradient_checkpointing="unsloth",  # type: ignore
        random_state=3407,
    )
    peft_args.update(config.get("peft_args", {}))
    trainer_args = TrainerArgs(
        learning_rate=5e-6,
        adam_beta1=0.9,
        adam_beta2=0.99,
        weight_decay=0.1,
        lr_scheduler_type="constant",
        optim="paged_adamw_8bit",
        logging_steps=1,
        per_device_train_batch_size=2,
        gradient_accumulation_steps=1,  # Increase to 4 for smoother training
        num_generations=2,  # Decrease if out of memory
        max_grad_norm=0.1,
        save_strategy="no",
        output_dir=output_dir,
        disable_tqdm=True,
        report_to="none",
    )
    trainer_args.update(config.get("trainer_args", {}))
    if config.get("torchtune_args") is not None:
        torchtune_args = TorchtuneArgs(model="qwen3_32b", model_type="QWEN3")
        torchtune_args.update(config.get("torchtune_args", {}) or {})
    else:
        torchtune_args = None
    return InternalModelConfig(
        init_args=init_args,
        engine_args=engine_args,
        peft_args=peft_args,
        trainer_args=trainer_args,
        torchtune_args=torchtune_args,
        _decouple_vllm_and_unsloth=config.get("_decouple_vllm_and_unsloth", False),
    )
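A minimal usage sketch of get_model_config follows; the base model name, output path, and the trainer override below are illustrative assumptions, not defaults defined in this module, and it assumes a CUDA device is available since the defaults query torch.cuda.get_device_capability().

# Usage sketch (assumed values): user defaults are merged into the built-in ones via .update().
config = get_model_config(
    base_model="Qwen/Qwen2.5-7B-Instruct",  # illustrative model name
    output_dir="./outputs/example-run",  # illustrative output path
    config=InternalModelConfig(
        trainer_args={"per_device_train_batch_size": 4},  # example override of a trainer default
    ),
)
print(config["init_args"]["model_name"])  # points at the last checkpoint dir if one exists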