SLURM + FSDP2 support #272
base: main
Conversation
…management and memory efficiency
…clarity and reliability
Co-authored-by: Siddharth Narayanan <[email protected]>
…4__robusttraining`)

To optimize the existing code for speed, we can make use of more efficient operations for tensor handling and avoid unnecessary list operations within the function. Here is the rewritten program.

### Changes Made

1. Directly used the `torch.chunk` function to split the tensor and handle the resulting chunks as a tuple.
2. Precomputed the number of real chunks and initialized the `dummy_chunk_flags` list with appropriate lengths to avoid list appends in a loop.
3. Used tuple concatenation to efficiently add the necessary dummy chunks.
4. Converted the chunks to a list only once, just before returning, to maintain the same return type as before.

These changes ensure that the operations, particularly list appending and tensor manipulations, are as efficient as possible.
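The optimized code itself is not quoted in this excerpt, so here is a minimal sketch of the pattern the bot describes. The helper name `split_into_chunks` and its signature are hypothetical, not the PR's actual function:

```python
import torch

def split_into_chunks(tensor: torch.Tensor, num_chunks: int):
    # Split once with torch.chunk; it returns a tuple and may yield fewer
    # chunks than requested when the batch is small.
    chunks = torch.chunk(tensor, num_chunks, dim=0)
    num_real = len(chunks)
    num_dummy = num_chunks - num_real
    # Build the flag list in one shot instead of appending inside a loop.
    dummy_chunk_flags = [False] * num_real + [True] * num_dummy
    if num_dummy > 0:
        # Pad with dummy chunks via tuple concatenation, reusing the first
        # chunk's shape so downstream consumers see consistent tensors.
        dummy = torch.zeros_like(chunks[0])
        chunks = chunks + (dummy,) * num_dummy
    # Convert to a list only once, just before returning.
    return list(chunks), dummy_chunk_flags
```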
⚡️ Codeflash found optimizations for this PR: 📄 89% (0.89x) speedup for
…4__robusttraining`)

Sure, I can make the given code more efficient. Here are the main improvements:

1. Simplify the chunk splitting and dummy chunk creation to use fewer operations.
2. Avoid repetitive appending in a loop by pre-determining the length and constructing the final list accordingly.

Here is the optimized version of the provided code. Improvements made:

1. Instead of using a conditional and loop to append dummy chunks, I pre-determine the number of necessary dummy chunks and extend the list in one operation.
2. Created the `dummy_chunk_flags` list in one go, thus avoiding repeated appending operations.

With these changes, the function should run faster while maintaining the intended behavior.
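This second suggestion is the list-based variant of the same idea. A hedged sketch, again with hypothetical names:

```python
import torch

def split_into_chunks_list(tensor: torch.Tensor, num_chunks: int):
    # Hypothetical list-based variant: pre-compute the pad count and build
    # each collection in a single operation rather than appending in a loop.
    chunks = list(torch.chunk(tensor, num_chunks, dim=0))
    num_dummy = num_chunks - len(chunks)
    dummy_chunk_flags = [False] * len(chunks) + [True] * num_dummy
    # Extend once with all dummy chunks instead of repeated .append() calls.
    chunks.extend(torch.zeros_like(chunks[0]) for _ in range(num_dummy))
    return chunks, dummy_chunk_flags
```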
⚡️ Codeflash found optimizations for this PR: 📄 12% (0.12x) speedup for
Pull Request Overview
This PR introduces support for vanilla FSDP2 and SLURM-based parallelization for transformer models. Key changes include:
- A new backend implementation for transformer handling using FSDP2 in src/ldp/nn/handlers/transformer_handler_fsdp2.py.
- Updates to the existing transformer handler to integrate the new FSDP2 backend, including renaming of functions and configuration changes.
- Enhancements to SLURM cluster setup and worker initialization with new job directives and configuration parameters (see the sketch below).
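The SLURM setup code itself is not quoted in this review excerpt. Below is a hedged sketch of what SLURM-backed worker setup can look like with dask-jobqueue, an assumption inferred from the Dask-style `self.client.run(...)` call quoted later on this page; the cluster parameters and directives are illustrative, not the PR's actual values:

```python
from dask.distributed import Client
from dask_jobqueue import SLURMCluster

# Illustrative values only; the PR's actual directives/parameters may differ.
cluster = SLURMCluster(
    cores=8,
    memory="16GB",  # cf. the memory_per_worker parsing reviewed below
    processes=1,
    job_extra_directives=["--gpus-per-node=1"],  # example SLURM job directive
)
cluster.scale(jobs=4)  # request 4 SLURM-managed workers
client = Client(cluster)
```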
Reviewed Changes
Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.
| File | Description |
|---|---|
| src/ldp/nn/handlers/transformer_handler_fsdp2.py | New module providing FSDP2-based parallel transformer handling. |
| src/ldp/nn/handlers/transformer_handler.py | Updated to support both Accelerator and FSDP2 backends; configuration and function renames. |
| src/ldp/nn/graph/llm_call_op.py, src/ldp/nn/agent/simple_local_agent.py, src/ldp/nn/__init__.py | Updated to propagate the new ParallelizationStrategy configuration. |
| pyproject.toml | Updated torch version requirement to 2.6. |
f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank" | ||
f" {os.environ.get('RANK')}" | ||
) | ||
if model.training: |
The generate method now forces the model into evaluation mode when it is in training mode. Consider documenting the implications of changing the model's mode here so that callers are aware of the side effects.
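One way to make the side effect explicit is a small wrapper that restores the caller's mode afterward. This is a sketch of that option, not the PR's implementation, assuming a Hugging Face-style `model.generate`:

```python
import torch

def generate_in_eval_mode(model, **kwargs):
    # Sketch only: run generation in eval mode (disabling dropout etc.),
    # then restore whatever mode the caller had set.
    was_training = model.training
    if was_training:
        model.eval()
    try:
        with torch.no_grad():
            return model.generate(**kwargs)
    finally:
        if was_training:
            model.train()  # undo the side effect so callers see no change
```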
```python
memory_per_worker = parallel_mode_config.memory_per_worker
MEMORY_UNIT_LENGTH = 2  # Memory units are typically 2 chars (e.g. "GB", "MB")
value = int(
    memory_per_worker[:-MEMORY_UNIT_LENGTH]
)  # Get numeric value by removing last 2 chars (e.g. "GB")
unit = memory_per_worker[-MEMORY_UNIT_LENGTH:]  # Get unit (e.g. "GB")
assert len(unit) == MEMORY_UNIT_LENGTH, (
    f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}"
)
```
[nitpick] The logic for parsing the memory_per_worker value assumes a fixed 2-character memory unit. Consider either using a more robust parsing method or documenting this assumption clearly.
Suggested change:

```python
import re

memory_per_worker = parallel_mode_config.memory_per_worker
match = re.match(r"^(\d+)([a-zA-Z]+)$", memory_per_worker)
if not match:
    raise ValueError(
        f"Invalid memory_per_worker format: {memory_per_worker}."
        " Expected format: <value><unit> (e.g., '16GB')."
    )
value, unit = match.groups()
value = int(value)  # Convert numeric part to integer
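```

If a normalized quantity is ever needed downstream, a possible follow-up (illustrative only, not part of the suggestion) is converting the parsed pair to bytes:

```python
# Illustrative unit table; extend as needed for the units SLURM accepts.
UNIT_MULTIPLIERS = {"KB": 1024, "MB": 1024**2, "GB": 1024**3, "TB": 1024**4}
memory_bytes = value * UNIT_MULTIPLIERS[unit.upper()]
```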
```python
# return int(device)
# return None

# worker_to_cuda_device = self.client.run(get_cuda_visible_devices)
```
[nitpick] There is a sizeable block of commented-out code for determining CUDA devices. Consider removing it if it is no longer needed, or adding a comment explaining its purpose and future use, to improve code clarity.
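For context, the surviving fragments suggest the dead code was roughly the following helper. This is a hypothetical reconstruction, not the PR's actual code:

```python
import os

def get_cuda_visible_devices() -> int | None:
    # Hypothetical reconstruction: report this worker's first visible CUDA
    # device index, or None if no device is assigned.
    devices = os.environ.get("CUDA_VISIBLE_DEVICES", "")
    device = devices.split(",")[0] if devices else ""
    if device:
        return int(device)
    return None
```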
This PR does two main things: it adds vanilla FSDP2 support for transformer models, and it adds SLURM-based parallelization.
Still a WIP, but mostly done; you can start reviewing it for general comments.