
SLURM + FSDP2 support #272


Open · wants to merge 42 commits into main

Conversation

kwanUm (Collaborator)

@kwanUm kwanUm commented Apr 7, 2025

This PR does two main things:

  • Adds an alternative backend implementation for TransformerHandler that supports vanilla FSDP2 (without Accelerator).
  • Adds back support for TransformerHandler over a SLURMCluster, for SLURM deployments that do not support the --gres option.

Still a WIP, but mostly done; you can start reviewing it for general comments.

codeflash-ai bot added a commit that referenced this pull request Apr 7, 2025
…4__robusttraining`)

To optimize the existing code for speed, we can use more efficient tensor handling and avoid unnecessary list operations within the function. Here is the rewritten program.
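
A minimal sketch of what the described rewrite could look like, assuming a standalone function with a hypothetical signature; the dim=0 split axis and reusing the first chunk as the dummy payload are also assumptions, not the actual `TensorChunker._split_value` implementation:

```python
import torch


def _split_value(
    tensor: torch.Tensor, num_chunks: int
) -> tuple[list[torch.Tensor], list[bool]]:
    # torch.chunk returns a tuple and may yield fewer than num_chunks chunks
    # when the split dimension is small.
    chunks = torch.chunk(tensor, num_chunks, dim=0)
    num_real_chunks = len(chunks)
    num_dummy_chunks = num_chunks - num_real_chunks

    # Build the flags in one shot instead of appending inside a loop.
    dummy_chunk_flags = [False] * num_real_chunks + [True] * num_dummy_chunks

    # Pad with dummy chunks via tuple concatenation; the first chunk is reused
    # here purely as a placeholder payload.
    if num_dummy_chunks:
        chunks = chunks + (chunks[0],) * num_dummy_chunks

    # Convert to a list only once, just before returning, to keep the same
    # return type as before.
    return list(chunks), dummy_chunk_flags
```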



### Changes Made
1. Directly used the `torch.chunk` function to split the tensor and handle the resulting chunks as a tuple.
2. Precomputed the number of real chunks and initialized the `dummy_chunk_flags` list with appropriate lengths to avoid list appends in a loop.
3. Used tuple concatenation to efficiently add the necessary dummy chunks.
4. Converted the chunks to a list only once, just before returning, to maintain the same return type as before.

These changes ensure that the operations, particularly list appending and tensor manipulations, are as efficient as possible.

codeflash-ai bot commented Apr 7, 2025

⚡️ Codeflash found optimizations for this PR

📄 89% (0.89x) speedup for TensorChunker._split_value in src/ldp/nn/handlers/chunking.py

⏱️ Runtime : 2.60 milliseconds → 1.38 milliseconds (best of 82 runs)

I created a new dependent PR with the suggested changes. Please review:

If you approve, it will be merged into this PR (branch 14__robusttraining).

Ori Kabeli added 3 commits April 7, 2025 13:31
codeflash-ai bot added a commit that referenced this pull request Apr 7, 2025
…4__robusttraining`)

Sure, I can make the given code more efficient. Here are the main improvements:
1. Simplify the chunk splitting and dummy chunk creation to use fewer operations.
2. Avoid repetitive appending in a loop by pre-determining the length and constructing the final list accordingly.

Here is the optimized version of the provided code.



Improvements made:
1. Instead of using a conditional and loop to append dummy chunks, I pre-determine the number of necessary dummy chunks and extend the list in one operation.
2. Created the `dummy_chunk_flags` list in one go, thus avoiding repeated appending operations.

With these changes, the function should run faster while maintaining the intended behavior.

codeflash-ai bot commented Apr 7, 2025

⚡️ Codeflash found optimizations for this PR

📄 12% (0.12x) speedup for TensorChunker._split_value in src/ldp/nn/handlers/chunking.py

⏱️ Runtime : 670 microseconds → 600 microseconds (best of 103 runs)

I created a new dependent PR with the suggested changes. Please review:

If you approve, it will be merged into this PR (branch 14__robusttraining).

@kwanUm kwanUm changed the base branch from main to 13__nvmlsupport April 8, 2025 11:00
@kwanUm kwanUm requested a review from sidnarayanan April 8, 2025 11:01
Base automatically changed from 13__nvmlsupport to main April 23, 2025 12:16
@kwanUm kwanUm self-assigned this Apr 30, 2025
@kwanUm kwanUm marked this pull request as ready for review April 30, 2025 13:19
@Copilot Copilot AI review requested due to automatic review settings April 30, 2025 13:19
@Copilot Copilot AI (Contributor) left a comment


Pull Request Overview

This PR introduces support for vanilla FSDP2 and SLURM-based parallelization for transformer models. Key changes include:

  • A new backend implementation for transformer handling using FSDP2 in src/ldp/nn/handlers/transformer_handler_fsdp2.py.
  • Updates to the existing transformer handler to integrate the new FSDP2 backend, including renaming of functions and configuration changes.
  • Enhancements to SLURM cluster setup and worker initialization with new job directives and configuration parameters.

Reviewed Changes

Copilot reviewed 7 out of 7 changed files in this pull request and generated 3 comments.

  • src/ldp/nn/handlers/transformer_handler_fsdp2.py: New module providing FSDP2-based parallel transformer handling.
  • src/ldp/nn/handlers/transformer_handler.py: Updated to support both Accelerator and FSDP2 backends; configuration and function renames.
  • src/ldp/nn/graph/llm_call_op.py, src/ldp/nn/agent/simple_local_agent.py, src/ldp/nn/__init__.py: Updated to propagate the new ParallelizationStrategy configuration.
  • pyproject.toml: Updated torch version requirement to 2.6.

f"model.generate() input_ids shape: {kwargs['input_ids'].shape}, rank"
f" {os.environ.get('RANK')}"
)
if model.training:

Copilot AI Apr 30, 2025


The generate method now forces the model into evaluation mode when it is in training mode. Consider documenting the implications of changing the model's mode here so that callers are aware of the side effects.
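
For illustration, one way to keep the switch from leaking to callers is to remember and restore the caller's mode. The sketch below is a hypothetical wrapper (the name and the HF-style `model.generate(**kwargs)` call are assumptions, not the handler's actual code):

```python
def generate_preserving_mode(model, **kwargs):
    """Hypothetical wrapper: run generation in eval mode, then restore the caller's mode."""
    was_training = model.training  # remember whether the caller had the model in train mode
    if was_training:
        model.eval()  # dropout etc. should be off during generation
    try:
        return model.generate(**kwargs)
    finally:
        if was_training:
            model.train()  # undo the mode switch so callers see no lasting side effect
```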


Comment on lines +596 to +604
memory_per_worker = parallel_mode_config.memory_per_worker
MEMORY_UNIT_LENGTH = 2  # Memory units are typically 2 chars (e.g. "GB", "MB")
value = int(
    memory_per_worker[:-MEMORY_UNIT_LENGTH]
)  # Get numeric value by removing last 2 chars (e.g. "GB")
unit = memory_per_worker[-MEMORY_UNIT_LENGTH:]  # Get unit (e.g. "GB")
assert len(unit) == MEMORY_UNIT_LENGTH, (
    f"Memory unit must be {MEMORY_UNIT_LENGTH} characters long, got {unit}"
)

Copilot AI Apr 30, 2025


[nitpick] The logic for parsing the memory_per_worker value assumes a fixed 2-character memory unit. Consider either using a more robust parsing method or documenting this assumption clearly.

Suggested change (replace the fixed-length unit parsing above with a regex-based parse):

    import re

    memory_per_worker = parallel_mode_config.memory_per_worker
    match = re.match(r"^(\d+)([a-zA-Z]+)$", memory_per_worker)
    if not match:
        raise ValueError(
            f"Invalid memory_per_worker format: {memory_per_worker}. Expected format: <value><unit> (e.g., '16GB')."
        )
    value, unit = match.groups()
    value = int(value)  # Convert numeric part to integer
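
For reference, a quick sanity check of the suggested pattern on illustrative inputs (the sample values below are assumptions, not values taken from the PR):

```python
import re

pattern = re.compile(r"^(\d+)([a-zA-Z]+)$")
assert pattern.match("16GB").groups() == ("16", "GB")
assert pattern.match("512MB").groups() == ("512", "MB")
assert pattern.match("16 GB") is None  # whitespace between value and unit is rejected
```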


# return int(device)
# return None

# worker_to_cuda_device = self.client.run(get_cuda_visible_devices)

Copilot AI Apr 30, 2025


[nitpick] There is a sizeable block of commented-out code for determining CUDA devices. Consider removing it if no longer needed or adding a comment to explain its purpose and future usage to improve code clarity.

