Skip to content

SLURM + FSDP2 support #272

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 42 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
42 commits
Select commit Hold shift + click to select a range
38912bd
Rollout manager with exception counter & chunking fix; .gitignore update
Feb 18, 2025
13c9622
nits
Feb 18, 2025
a7c1bf2
nits
Feb 18, 2025
ae71669
nits
Feb 23, 2025
39b0fc0
nits
Feb 23, 2025
96e870d
nits
Feb 23, 2025
da92fbf
nits
Feb 23, 2025
2ba5898
nits
Feb 24, 2025
6297786
nits
Mar 3, 2025
e4ccb45
nits
Mar 4, 2025
d81307d
Merge remote-tracking branch 'origin/main' into 13__nvmlsupport
Mar 4, 2025
6ed30bd
nits
Mar 4, 2025
d245e3d
nits
Mar 4, 2025
51210df
nits
Mar 4, 2025
8376e28
Merge remote-tracking branch 'origin/main' into 13__nvmlsupport
Mar 10, 2025
3d7c20e
Refactor Dask handling in transformer handler for improved exception …
Mar 12, 2025
15b936d
Remove test OutOfMemoryError raise in AsyncTransformerInterface
Mar 12, 2025
ad37501
Refactor exception handling in ParallelAsyncTransformer for improved …
Mar 12, 2025
62f42d1
nits
Mar 12, 2025
8f9a976
Merge remote-tracking branch 'origin/main' into 13__nvmlsupport
Mar 12, 2025
1b486cf
nits code review
Mar 19, 2025
08ab5b6
Merge remote-tracking branch 'origin/main' into 13__nvmlsupport
Mar 19, 2025
bddaa92
nits
Mar 19, 2025
310131e
nits
Mar 19, 2025
4e40908
nits
Mar 19, 2025
5fb76d4
nit
Mar 19, 2025
5cd42d8
Merge remote-tracking branch 'origin/main' into 13__nvmlsupport
Mar 31, 2025
e6ee090
nits comments fix
Mar 31, 2025
5651e54
Update src/ldp/nn/agent/simple_local_agent.py
kwanUm Mar 31, 2025
eb24519
nits
Mar 31, 2025
f7a8247
nits
Mar 31, 2025
decced0
fsdp2 support + slurm support
Apr 7, 2025
caf9edf
SLURM + FSDP2 support
Apr 7, 2025
94fba95
nit
Apr 7, 2025
c809dc3
nit
Apr 7, 2025
62a499f
nit
Apr 7, 2025
37b1d27
Merge main branch into 14__robusttraining
Apr 23, 2025
0cdba00
nits
Apr 23, 2025
770bc06
update fsdp2 backend
Apr 29, 2025
97ae0c0
Merge remote-tracking branch 'origin/main' into 14__robusttraining
Apr 29, 2025
76ee37f
nits
Apr 29, 2025
877ad6b
nits
Apr 30, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ nn = [
"dask-jobqueue",
"dask[distributed]",
"tokenizers>0.20",
"torch>=2.5,<2.7", # Temporarily pin <2.6 until someone fixes our CI with torch 2.6
"torch>=2.6,<2.7", # Use torch 2.6
"transformers>=4.46",
"wandb",
]
Expand Down Expand Up @@ -153,7 +153,7 @@ explicit_package_bases = true
mypy_path = "$MYPY_CONFIG_FILE_DIR/src,$MYPY_CONFIG_FILE_DIR/packages/lmi/src"
# Specifies the OS platform for the target program, for example darwin or win32
# (meaning OS X or Windows, respectively). The default is the current platform
# as revealed by Pythons sys.platform variable.
# as revealed by Python's sys.platform variable.
platform = "linux"
# Comma-separated list of mypy plugins.
plugins = ["pydantic.mypy"]
Expand Down
2 changes: 2 additions & 0 deletions src/ldp/nn/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
ParallelTransformerHandler,
TransformerHandler,
TransformerHandlerConfig,
ParallelizationStrategy,
collate_fn_transformer_left_pad,
collate_fn_transformer_right_pad,
decollate_fn_transformer_decoder,
Expand All @@ -35,6 +36,7 @@
"TorchDType",
"TransformerHandler",
"TransformerHandlerConfig",
"ParallelizationStrategy",
"collate_fn_transformer_left_pad",
"collate_fn_transformer_right_pad",
"decollate_fn_transformer_decoder",
Expand Down
4 changes: 3 additions & 1 deletion src/ldp/nn/agent/simple_local_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from ldp.nn.handlers.chunking import TensorChunker
from ldp.nn.handlers.transformer_handler import (
ParallelModeConfig,
ParallelizationStrategy,
logits_to_logprobs,
)
from ldp.nn.lm_config import LMConfig as _LMConfig
Expand All @@ -31,7 +32,7 @@ class AgentLMConfig(_LMConfig):

# distribution
parallel_mode: ParallelModeConfig | None = None

parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR
# sampling parameters
temperature: float = 1.0
max_new_tokens: int = 50
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
batch_size=self.llm_model.batch_size,
max_wait_interval=self.llm_model.max_wait_interval,
parallel_mode_config=self.llm_model.parallel_mode,
parallel_strategy=self.llm_model.parallel_strategy,
)

async def init_state(self, tools: list[Tool]) -> SimpleAgentState:
Expand Down
3 changes: 3 additions & 0 deletions src/ldp/nn/graph/llm_call_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
LMType,
ParallelModeConfig,
TransformerHandlerConfig,
ParallelizationStrategy,
collate_fn_transformer_left_pad,
decollate_fn_transformer_decoder,
)
Expand All @@ -40,6 +41,7 @@ def __init__(
batch_size: int = 1,
max_wait_interval: float = 0.1,
parallel_mode_config: ParallelModeConfig | None = None,
parallel_strategy: ParallelizationStrategy = ParallelizationStrategy.ACCELERATOR,
) -> None:
super().__init__()

Expand All @@ -51,6 +53,7 @@ def __init__(
batch_size=batch_size,
max_wait_interval=max_wait_interval,
parallel_mode_config=parallel_mode_config,
parallel_strategy=parallel_strategy,
# constant configuration
lm_type=LMType.GENERATION,
module_call_fn=AsyncTransformerInterface.model_generate,
Expand Down
Loading
Loading