feat: Multi-device training #163
Conversation
- Introduced a new Jupyter notebook for vLLM usage examples. - Implemented `get_llm` function for model downloading and initialization. - Added memory management functions for offloading and waking up the KV cache. - Created a new module for multi-device support in vLLM.
- Added context variable management for worker instances in vllm.py. - Implemented functions to pause and resume worker processes in vllm.ipynb. - Updated Jupyter notebook to demonstrate new worker management capabilities.
- Added patch_worker_sleep function to manage worker sleep state. - Enhanced execute_model to wake up sleeping workers before execution. - Integrated OS signal handling for process suspension.
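The OS-signal approach mentioned in this commit can be sketched in isolation. This is a hedged toy example, not the actual worker-management code: the `sleep` child process stands in for a real vLLM worker, which would be suspended the same way via its PID.

```python
import os
import signal
import subprocess

# Toy stand-in for a worker process (assumption: any POSIX process works the same way).
proc = subprocess.Popen(["sleep", "30"])

os.kill(proc.pid, signal.SIGSTOP)  # suspend: the process stops consuming CPU
# ... training or other GPU work could run here while the worker is frozen ...
os.kill(proc.pid, signal.SIGCONT)  # resume the worker

proc.terminate()
proc.wait()
```

SIGSTOP cannot be caught or ignored by the target process, which makes it a reliable way to freeze a worker without its cooperation.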
… worker management in vllm.py - Added asyncio support for managing worker sleep state in vllm.ipynb. - Updated signal handling to include resuming workers after sleep. - Enhanced get_llm function to utilize the new patch_worker_sleep functionality.
…pynb - Set execution count for the initial code cell to 1. - Added import for torch to facilitate dynamic tensor parallel size. - Updated tensor_parallel_size to use the number of available CUDA devices. - Reset execution count for a later code cell to null.
… functionality in vllm.py - Set execution count for a code cell in vllm.ipynb to 5. - Simplified worker sleep management by removing the execute_model override in vllm.py. - Ensured workers wake up automatically after sleep without needing to check a separate state.
… vllm.ipynb and vllm.py - Set execution counts for multiple code cells in vllm.ipynb to ensure proper execution flow. - Enhanced worker sleep management in vllm.py by introducing controlled sleep durations and removing unnecessary garbage collection calls. - Improved signal handling for worker wake-up processes to ensure smoother operation.
- Introduced a new method to patch the Executor.get_class static method for customized executor selection. - Implemented a CustomExecutor subclass that extends the base executor with additional functionality, including a modified sleep method. - Enhanced the dynamic creation of executor classes to improve flexibility in executor management.
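The `Executor.get_class` patching pattern described above can be illustrated with a self-contained toy. The `Executor` class here is a stand-in for demonstration, not vLLM's actual executor API:

```python
class Executor:
    """Stand-in base class (an assumption for illustration, not vLLM's real API)."""

    @staticmethod
    def get_class(config):
        return Executor

    def sleep(self):
        return "base sleep"


_original_get_class = Executor.get_class


def patched_get_class(config):
    # Resolve the class the original hook would have chosen...
    base = _original_get_class(config)

    # ...then dynamically create a subclass that extends it with custom behavior,
    # e.g. a modified sleep method.
    class CustomExecutor(base):
        def sleep(self):
            return "custom " + super().sleep()

    return CustomExecutor


# Install the patch so callers transparently receive the custom subclass.
Executor.get_class = staticmethod(patched_get_class)
```

Because the subclass is created inside the patched hook, it always extends whatever class the original selection logic returns, keeping the customization orthogonal to executor selection.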
….ipynb and vllm.py - Set execution counts to null for multiple code cells in vllm.ipynb to reset execution state. - Removed unnecessary output data from code cells to clean up notebook execution results. - Updated the sleep task management in vllm.py to improve clarity and functionality.
…ss functions in vllm.py - Eliminated the patch_worker_sleep function to streamline worker management. - Removed the patch_executor_get_class function, simplifying executor class handling. - Cleaned up imports and unnecessary comments to enhance code clarity.
- Introduced a new Jupyter notebook (trl.ipynb) for training with the TRL framework. - Included code for setting up the environment, loading datasets, and configuring the GRPOTrainer. - Enhanced notebook with HTML styling for better output visualization.
…ion files - Added torchtune as a dependency in pyproject.toml for enhanced model tuning capabilities. - Introduced new configuration files for torchtune, including a YAML config for multi-device fine-tuning and a training script. - Created initial structure for torchtune with an __init__.py file and a recipe.py for training logic. - Implemented a shell script to facilitate the training process with specified configurations. - Enhanced the uv.lock file to include new package dependencies and versions.
…c configuration models - Introduced a new THIRD-PARTY-NOTICES file to acknowledge the use of PyTorch TorchTune and its BSD 3-Clause License. - Enhanced recipe.py by adding multiple Pydantic configuration models for components such as Model, Tokenizer, Optimizer, and more, improving configuration management for distributed training. - Updated the FullFinetuneRecipeDistributed class to utilize the new RecipeConfig structure for better clarity and maintainability.
…nce execution flow - Modified the training script to correctly export the model directory using the `uv run` command for downloading the model. - Added an export for the `TORCHTUNE_DIR` to streamline the execution of the tuning process. - Improved the command structure for better clarity and maintainability in the training workflow.
- Added a newline at the end of the train.sh script to ensure proper file formatting and adherence to coding standards.
…ling and type assertions - Adjusted max_steps_per_epoch to default to 0 if not set, ensuring robustness in training configurations. - Enhanced type safety by asserting model types and optimizer handling, specifically for FSDPModule and TransformerDecoder. - Refactored optimizer retrieval logic to streamline the process and improve clarity in the training workflow. - Added type hints and ignored type errors in specific areas to maintain compatibility with type checkers.
…t execution - Increased batch size from 2 to 64 to enhance training efficiency. - Adjusted gradient accumulation steps from 8 to 1 to align with the new batch size. - Modified the training script to ensure proper command execution by appending additional arguments.
…d services - Added TorchtuneArgs to model configuration for enhanced tuning capabilities. - Introduced TorchtuneService to manage model training and server interactions. - Updated LocalBackend to conditionally use TorchtuneService based on configuration. - Refactored service interfaces to accommodate new training logic and improve modularity. - Created initial structure for Unsloth integration, including service and state management.
…rity - Changed the import path for openai_server_task in vllm.ipynb to reflect the new module structure. - Added type ignore comments in guided_completion.py to address type checker warnings. - Modified the environment variable loading in backend.py to exclude None values for cleaner configuration. - Updated function signatures in utils.py to enhance type safety and clarity in return types.
- Updated import paths in vllm.ipynb and related files to reflect the new module structure, consolidating imports under the vllm module. - Deleted the multidevice.vllm.py file as its functionality has been integrated into the vllm module. - Adjusted references in service and state management files to ensure compatibility with the new import structure.
- Modified the `get_openai_server_config` function to allow `lora_path` to be optional, enhancing flexibility in server configuration. - Adjusted the handling of `lora_modules` to conditionally include the path based on the presence of `lora_path`. - Updated the `start_openai_server` method in `TorchtuneService` to utilize the new configuration structure and improved type handling for `EngineArgs`.
…lBackend - Moved the imports for TorchtuneService and UnslothService into the _get_service method of LocalBackend for better encapsulation. - Removed the PackedDataset class from pack.py to streamline the codebase, as it was unused. - Deleted the empty multidevice module to simplify the project structure. - Added a print statement in TorchtuneService to log the training model name and configuration for better debugging.
- Updated the training script in train.ipynb to initialize the model with LocalBackend and improved configuration for Torchtune. - Refactored model configuration to conditionally include TorchtuneArgs and adjusted scheduler steps based on the presence of these arguments. - Improved the handling of engine arguments in the model and OpenAI server configuration for better flexibility and clarity. - Enhanced the LocalBackend service initialization to ensure proper service selection based on the configuration provided.
- Removed the comprehensive log filter and replaced it with a more streamlined logger suppression mechanism in logs.py. - Introduced a patched getLogger function to suppress logs from specific logger prefixes, improving log management. - Cleaned up unnecessary print statements in TorchtuneService and optimized the training yield logic for better performance.
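A patched `getLogger` of the kind this commit describes might look like the following sketch. The suppressed prefixes here are assumptions for illustration:

```python
import logging

# Assumed prefixes to silence; the real list lives in the project's logs.py.
SUPPRESSED_PREFIXES = ("vllm", "torchtune")

_original_get_logger = logging.getLogger


def patched_get_logger(name=None):
    logger = _original_get_logger(name)
    # Raise matching loggers to WARNING so their INFO/DEBUG chatter is dropped.
    if name and name.startswith(SUPPRESSED_PREFIXES):
        logger.setLevel(logging.WARNING)
    return logger


logging.getLogger = patched_get_logger
```

Patching the factory rather than filtering records means suppression applies even to loggers created after the patch runs.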
- Changed the model name from "Qwen/Qwen3-32B" to "006" for clarity in identification. - Simplified the initialization of `init_args` in the model configuration for better readability. - Restored the training call in the training function to ensure proper execution of the training process. - Added a new cell to terminate sandboxes after model registration, enhancing resource management.
- Changed the model name in train.ipynb from "006" to "007" for better version tracking. - Added a safeguard in LocalBackend to ensure logging steps are never negative, improving robustness in logging behavior.
…ered_swe_smith_instances_df - Included "tornadoweb__tornado.d5ac65c1" in the list of dependencies for improved instance filtering accuracy.
…eRecipeDistributed - Eliminated the loss configuration from the RecipeConfig and associated loss initialization in the setup method. - Updated comments and code to reflect the removal of loss handling, streamlining the training process. - Adjusted model behavior to skip the output layer, enhancing clarity in the training flow.
- Added "aiolimiter" as a dependency in pyproject.toml for rate limiting functionality. - Introduced AsyncLimiter in rollout.py to manage request rates, enhancing performance and preventing overload during execution. - Updated uv.lock to include the new aiolimiter package and its version details. - Created a new config.py file to define configuration models for various components, improving code organization and maintainability.
- Adjusted logging configuration to set the OpenAI base client logger level to WARNING, reducing verbosity of retry messages and improving log clarity.
- Updated the retry mechanism in the rollout function to include StreamTerminatedError, improving resilience against stream interruptions during execution.
…red_swe_smith_instances_df - Included "pyupio__safety.7654596b" in the list of dependencies for improved instance filtering accuracy.
…n service - Added new model literals to TorchtuneModel and TorchtuneModelType for comprehensive model support. - Updated TorchtuneArgs to require model and model_type fields, improving type safety. - Enhanced error handling in TorchtuneService to provide clearer logging for early termination of training processes.
…orker context management - Introduced a context manager for timing operations in the TorchtuneService, improving performance tracking during critical tasks. - Updated logging levels based on verbosity and worker rank to reduce unnecessary log output. - Refactored state loading and weight loading processes to utilize the new timing context manager for better performance insights. - Modified worker context management to support an extended worker class, enhancing type safety and functionality.
…ndling in TorchtuneService - Added a new property, torchtune_args, to encapsulate access to the torchtune_args configuration, enhancing code clarity and reducing redundancy. - Updated references to torchtune_args throughout the service methods to utilize the new property, streamlining configuration access. - Improved assertions for configuration presence to ensure robustness during service initialization.
…gument handling - Introduced a dynamic sleep level based on the presence of unfinished requests in the output processor, optimizing worker sleep behavior. - Refactored subprocess argument construction to ensure consistent output directory and metric logger settings, enhancing configuration clarity and maintainability.
…n TorchtuneService - Refactored the sleep mechanism to streamline state loading and PID management, improving clarity and maintainability. - Replaced direct file operations with Path.unlink for better error handling and code readability. - Introduced a dedicated sleep function to encapsulate worker sleep logic, enhancing modularity and separation of concerns.
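The `Path.unlink` pattern referenced above removes a file without a separate existence check or a try/except around `os.remove`. A small sketch, with an assumed PID-file path:

```python
import tempfile
from pathlib import Path

# Assumed location; the real service uses its own PID-file path.
pid_file = Path(tempfile.gettempdir()) / "worker_example.pid"
pid_file.write_text("12345")

pid_file.unlink(missing_ok=True)  # removes the file if present
pid_file.unlink(missing_ok=True)  # safe to call again; no FileNotFoundError
```

`missing_ok=True` (available since Python 3.8) collapses the check-then-delete race into a single idempotent call.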
…rvice - Updated the TorchtuneArgs to include a new async_weight_syncing field, improving configuration flexibility. - Modified weight handling in the TorchtuneService to support asynchronous weight syncing, enhancing performance during training. - Refactored file operations for weights to improve error handling and clarity, ensuring smoother integration with the training process.
…chtuneService - Refactored the sleep function to enhance clarity and maintainability by simplifying PID management and removing redundant error handling. - Improved the logic for waking up the worker by clarifying conditions related to the presence of PID files. - Cleaned up whitespace and formatting in the update_weights function for better readability.
…tuneRecipeDistributed - Moved CPU state management and GPU signaling code to the correct location in FullFinetuneRecipeDistributed, improving clarity and ensuring proper resource handling. - Refactored weight update logic in TorchtuneService to utilize an asynchronous task for better performance and responsiveness during training. - Enhanced the update_worker_weights method for improved readability and maintainability, ensuring robust weight synchronization.
… in FullFinetuneRecipeDistributed - Eliminated unnecessary logging of negative steps in LocalBackend, simplifying the step handling logic. - Introduced a sleep period in FullFinetuneRecipeDistributed to ensure proper resource management after GPU signaling.
…roved resource management - Adjusted the sleep duration from 2 seconds to 4 seconds to ensure proper resource handling after GPU signaling, enhancing stability during training.
- Changed execution count to null and removed unnecessary HTML output in the training notebook for cleaner execution. - Updated model name from "020" to "025" and refined engine arguments for improved GPU memory utilization during training. - Commented out the termination of sandboxes to prevent premature execution during the training process.
… LoRARequest patching - Replaced asdict with replace for more efficient engine argument updates in ModelState. - Updated patch_lora_request to ensure compatibility between Unsloth and vLLM LoRARequest types, adding necessary attributes for seamless integration.
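The `replace`-over-`asdict` change can be illustrated with `dataclasses.replace`, which copies a dataclass while overriding selected fields instead of round-tripping through a dict. `EngineArgs` here is a stand-in dataclass, not vLLM's actual class:

```python
from dataclasses import dataclass, replace


@dataclass(frozen=True)
class EngineArgs:
    """Illustrative stand-in for the real engine-args type."""
    tensor_parallel_size: int = 1
    gpu_memory_utilization: float = 0.9


args = EngineArgs()

# Override one field; all others are copied unchanged, and the original is untouched.
updated = replace(args, gpu_memory_utilization=0.8)
```

Unlike `asdict`, `replace` performs a shallow copy without recursing into nested fields, which avoids both the performance cost and the lossy dict round-trip.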
- Modified the uvicorn.run call in cli.py to specify the asyncio loop, ensuring compatibility with asynchronous operations.
- Removed obsolete package entries for 'aiolimiter', 'antlr4-python3-runtime', 'blobfile', and 'kagglehub'. - Updated version for 'compressed-tensors' from 0.10.1 to 0.9.3. - Adjusted 'flask-cors' version from 6.0.1 to 6.0.0 and 'vllm' from 0.9.1 to 0.8.5.post1. - Changed 'xformers' version from 0.0.30 to 0.0.29.post3 and 'xgrammar' from 0.1.19 to 0.1.18. - Streamlined dependencies for 'triton' to improve compatibility across platforms.
…ry imports - Added import for 'torch' to the training notebook for enhanced model functionality. - Updated 'tensor_parallel_size' to dynamically use the number of available GPUs. - Enabled 'async_weight_syncing' in 'TorchtuneArgs' for improved weight synchronization during training. - Updated Python version in notebook metadata from 3.10.13 to 3.10.16 for compatibility.
Hi @bradhilton, thanks for notifying me! Am I right in saying (from a quick look) that ART doesn't split GPUs between training and inference, but instead uses all GPUs sequentially? Do you think an A100 (80GB) setup, 8x in practice, is fine for Qwen3-32B?
@Danau5tin yep, GPUs are colocated. This means you get high GPU utilization without having to figure out the optimal balance between inference and training GPUs. If there is enough interest, we may add support for separate deployments in the future, whether from those who care most about absolute wall time whatever the cost, or from those willing to tune the training/inference GPU split themselves to maximize utilization. I've been using 8xH200s for Qwen3-32B, which have about double the memory, but if your sequence length is not too long and you lower vLLM's GPU memory utilization you may be okay. You can also try enabling activation offloading to save additional memory. Here's how you initialize your trainable model to use the multi-device backend at the moment:

```python
model = art.TrainableModel(
    name="model-name",
    project="project-name",
    base_model="Qwen/Qwen3-32B",
    _internal_config=art.dev.InternalModelConfig(
        engine_args=art.dev.EngineArgs(
            tensor_parallel_size=torch.cuda.device_count(),  # must be specified for vLLM, at least for now
            gpu_memory_utilization=0.9,  # defaults to 0.9; you will probably need to lower this to avoid OOM errors
        ),
        torchtune_args=art.dev.TorchtuneArgs(
            model="qwen3_32b",  # needed for torchtune model initialization
            model_type="QWEN3",  # needed for torchtune checkpointing
            async_weight_syncing=False,  # defaults to False; enabling relaxes learning guarantees but speeds up training
            enable_activation_offloading=False,  # defaults to False; enable to save VRAM at a small speed cost
        ),
    ),
)
```
Hey @bradhilton, so I migrated my project over to ART, tried out the code, and all looked great! Here are some logs I have:
Created with this code:

```python
all_rewards = [
    trajectory.reward for group in train_groups for trajectory in group.trajectories
]
logger.info(
    f"Collected {len(all_rewards)} trajectories with rewards: {all_rewards}. "
    f"Std dev: {torch.std(torch.tensor(all_rewards)):.4f}"
)
```

Not sure where that warning at the bottom comes from, but it comes up no matter what the rewards in the group are, and I can see from other logs that the trajectories are actually happening as expected. So to see if it was because of the multi-GPU setup, I switched to a single GPU (removed the internal config below, just commented it out):
When I switched to the standard way of training on a single GPU, it just seemed to hang on the first rollout. So I am not sure what is happening there either, because the rollouts worked fine with the multi-GPU setup. It has been a long day, so I have only tried the single-GPU setup, with Unsloth, once before writing this! Perhaps if I tried again it would work. For reference, I am using Qwen 0.6B (just to try to get a successful run, then I will increase model params) on 2x A100 GPUs.
@Danau5tin make sure that the trajectories in a group have differing rewards. Standard deviation should be calculated within each group, not between groups, because between-group variance does not affect the advantage calculation. If you want more timely help, feel free to join the Discord!
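A toy illustration of this point, using only the standard library: advantages depend on the reward spread within each trajectory group, so a group whose rewards are identical has zero standard deviation and contributes no learning signal. The reward values below are made up for the example:

```python
from statistics import pstdev

# Assumed per-group rewards: the first group has spread, the second does not.
groups = [
    [1.0, 0.0, 0.5],
    [0.3, 0.3, 0.3],
]

# Standard deviation computed within each group, not across the pooled rewards.
per_group_std = [pstdev(rewards) for rewards in groups]
```

A group like the second one would produce zero advantages regardless of how its rewards compare to other groups.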