18 changes: 10 additions & 8 deletions torchtitan/experiments/simple_fsdp/README.md
@@ -52,14 +52,16 @@ SimpleFSDP relies on compiler backend to perform optimizations (i.e., bucketing
1. no optimization: default torch.compile backends (e.g., "inductor", "aot_eager", "eager")

2. auto optimization: perform auto-bucketing & reordering without user inputs. **Note: it is not guaranteed that users will get the most optimized training performance**
- "aot_eager_autobucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend.


users can specify the pass (e.g., "aot_eager_autobucketing") via additional configs:

```bash
--compile.model_backend_override "aot_eager_autobucketing"
```
- "auto_bucketing": perform autobucketing at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend).
```bash
--compile.backend "aot_eager" --compile.graph_passes "auto_bucketing"
```

3. manual optimization: perform manual bucketing & reordering with user FQN inputs.
- "transformer_block_bucketing": perform manual bucketing at aten fx-level, and perform code execution with aot_eager backend. (We also support `inductor` backend).
```bash
--compile.backend "aot_eager" --compile.graph_passes "transformer_block_bucketing"
```

### Citation

135 changes: 109 additions & 26 deletions torchtitan/experiments/simple_fsdp/backend.py
@@ -8,49 +8,132 @@

import torch
import torch._functorch.config as functorch_config
from torchtitan.tools.logging import logger

from .job_config import Compile as CompileConfig

from .reshard_after_forward import annotate_fsdp_all_gather


def get_compile_backend(
backend_name: str, fsdp_reshard_after_forward: bool
def get_compile_backend_with_passes(
compile_config: CompileConfig,
fsdp_reshard_after_forward: bool,
fsdp_manual_buckets: list[list[str] | str] | None,
) -> callable:
# return the compile backends used in SimpleFSDP training
# Step1: check if backend_name is inside available torch.compile backends
# Step2: check if the backend_name has been registered as a customized backend
available_torch_backend = torch._dynamo.list_backends(exclude_tags=())

if backend_name in available_torch_backend:
backend = torch._dynamo.lookup_backend(backend_name)
elif backend_name == "aot_eager_autobucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
"""
Apply compile backend and additional graph passes.
Args:
compile_config: compile configs to apply torch.compile.
fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP,
which is implemented via a customized AC graph pass.
fsdp_manual_buckets: used in transformer_block_bucketing to define which modules should be bucketed.
Returns:
compile backend with applied graph passes.
"""
backend = torch._dynamo.lookup_backend(compile_config.backend)

# Apply bucketing and overlapping pass on fwd and bwd graph separately
if compile_config.graph_passes == "auto_bucketing":
# Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend
# The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
from torch._inductor.config import aten_distributed_optimizations as dist_opts
from torch._inductor.fx_passes.overlap_scheduling import (
schedule_overlap_bucketing,
)

dist_opts.collective_bucketing = True
dist_opts.insert_overlap_deps = False
torch._inductor.config.allow_buffer_reuse = False

def aten_autobucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
schedule_overlap_bucketing(gm)
gm.recompile()
return gm

backend = aot_autograd_backend(
fw_compiler=aten_autobucketing_reordering_pass,
bw_compiler=aten_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
if compile_config.backend == "aot_eager":
from torch._dynamo.backends.common import (
aot_autograd as aot_autograd_backend,
)

def aot_eager_autobucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
schedule_overlap_bucketing(gm)
gm.recompile()
return gm

dist_opts.insert_overlap_deps = False
backend = aot_autograd_backend(
fw_compiler=aot_eager_autobucketing_reordering_pass,
bw_compiler=aot_eager_autobucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":

Contributor:
The reason you have such an if-else depending on different backends is purely because of an API limitation? Since we are always applying fx graph passes, I somehow thought there'd be a way to unify the passes UX and just use different backends.

Contributor Author:
It's because aot_eager and inductor handle the pass differently... I'm not really sure if there is a way to unify them, but that would be something very nice to have. Basically, aot_eager registers the pass as a customized compiler backend on top of aot_eager and runs it with fwd_compiler & bwd_compiler, while inductor hooks the pass into post_grad_pass and manipulates the graph traced at fx-level before lowering it to inductor IRs.

cc. @ezyang

Contributor:
My mental model was

1. forward graph capture
2. joint graph generation
3. joint graph passes
4. fw / bw graph partitioning
5. fw / bw graph passes
6. inductor lowering & fusion
7. inductor passes
8. codegen

I feel aot_eager and inductor share this up to step 5, and the bucketing passes sit at step 5 (AC passes at step 3?), so theoretically they could be combined?

Contributor Author (ruisizhang123, Nov 11, 2025):
That was what Angela's PR was doing: manipulate the post_grad graph and run the code eagerly: #1785. However, there was some performance regression (detail) that we never tracked down the cause of, and that's why we are using a customized backend instead for aot_eager 😢
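
Editor's note: to make the two hook points in this thread concrete, here is a minimal sketch (not part of the PR; `my_pass` is a hypothetical no-op placeholder). The aot_eager path wraps the pass as the forward/backward compiler of a custom `aot_autograd` backend, while the inductor path registers the pass as inductor's post-grad custom pass, which receives a `torch.fx.Graph` rather than a `GraphModule`.

```python
from typing import Any

import torch
import torch._inductor.config as inductor_config
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


def my_pass(gm: torch.fx.GraphModule, example_inputs: Any = None) -> torch.fx.GraphModule:
    # Placeholder for a bucketing/reordering pass: inspect or rewrite gm.graph here.
    return gm


# Path 1 ("aot_eager"-style): the pass runs as the fw/bw compiler of a custom backend.
aot_eager_with_pass = aot_autograd_backend(
    fw_compiler=my_pass,
    bw_compiler=my_pass,
    keep_inference_input_mutations=True,
)
# model = torch.compile(model, backend=aot_eager_with_pass, fullgraph=True)

# Path 2 ("inductor"-style): the pass is hooked into inductor's post-grad stage.
inductor_config.post_grad_custom_post_pass = lambda g: my_pass(g.owning_module)
# model = torch.compile(model, backend="inductor", fullgraph=True)
```

Whether these two registration surfaces can be unified into one pass UX is exactly the open question discussed above.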


def inductor_autobucketing_reordering_pass(
gm: torch.fx.Graph,
) -> torch.fx.GraphModule:
return schedule_overlap_bucketing(gm.owning_module)

dist_opts.insert_overlap_deps = True
torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = (
inductor_autobucketing_reordering_pass
)
else:
raise ValueError(
f"Unsupported backend {compile_config.backend} for auto_bucketing pass"
)
logger.info("Auto bucketing pass is applied")

elif compile_config.graph_passes == "transformer_block_bucketing":
# Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend
# The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487
from functools import partial

from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
from torch._inductor.fx_passes.overlap_manual_scheduling import (
manual_overlap_bucketing,
)

torch._inductor.config.allow_buffer_reuse = False

Member:
What happens by default?

Contributor Author:
In bucketing, we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers will reuse previous buffers, which messes up the copied-out data values and makes the loss NaN.

Contributor:
Aren't we doing passes on the fx graph / in the aot_eager backend? Why does it have anything to do with inductor? In fact, I have this confusion for all the other torch._inductor fields.

Contributor Author (ruisizhang123, Nov 11, 2025):
The passes live in the torch/_inductor/fx_passes/ folder. It is a bit counter-intuitive that fx graph passes live under _inductor, but for legacy reasons the pass was originally a post-grad pass in inductor rather than an aot_eager fx pass. That's why these configs have torch._inductor fields -- they control the pass via inductor's config.

manual_overlap_bucketing = partial(
manual_overlap_bucketing,
module_bucket_plans=fsdp_manual_buckets,
)

if compile_config.backend == "aot_eager":

def aot_eager_transformer_block_bucketing_reordering_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
manual_overlap_bucketing(gm, insert_overlap_deps=False)
return gm

backend = aot_autograd_backend(
fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
keep_inference_input_mutations=True,
)
elif compile_config.backend == "inductor":

def inductor_transformer_block_bucketing_reordering_pass(
gm: torch.fx.Graph,
) -> torch.fx.GraphModule:
return manual_overlap_bucketing(
gm.owning_module, insert_overlap_deps=True
)

torch._inductor.config.reorder_for_peak_memory = False
torch._inductor.config.reorder_for_compute_comm_overlap = False
torch._inductor.config.post_grad_custom_post_pass = (
inductor_transformer_block_bucketing_reordering_pass
)
else:
raise ValueError(
f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass"
)
logger.info("Transformer block bucketing pass is applied")

else:
raise AssertionError(f"Unsupported customized backend: {backend_name}")
logger.info("No bucketing and overlapping pass is applied")

# Apply activation checkpointing on joint graph before partitioner
def joint_ac_pass(
gm: torch.fx.GraphModule, example_inputs: Any
) -> torch.fx.GraphModule:
36 changes: 31 additions & 5 deletions torchtitan/experiments/simple_fsdp/deepseek_v3/parallelize.py
@@ -19,10 +19,35 @@
)
from torchtitan.tools.logging import logger

from ..backend import get_compile_backend
from ..backend import get_compile_backend_with_passes

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy


def get_transformer_block_buckets(model) -> list[list[str] | str]:
module_list = [
model.tok_embeddings,
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
# [TODO](ruisizhang123) add EP support for transformer block bucketing
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
else:
result.append(module_to_fqn_mapping.get(m, None))
return result

module_to_name = {m: n for n, m in model.named_modules()}
module_fqns = convert_modules_to_fqns(module_list, module_to_name)
return module_fqns
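
Editor's note: the helper above maps module objects to their fully qualified names (FQNs), producing the nested bucket plan that `fsdp_manual_buckets` expects. A small self-contained illustration (not from the PR; the toy model below is hypothetical):

```python
import torch.nn as nn


class ToyModel(nn.Module):
    # Hypothetical stand-in with the same attribute names the helper relies on.
    def __init__(self):
        super().__init__()
        self.tok_embeddings = nn.Embedding(8, 4)
        self.layers = nn.ModuleDict({"0": nn.Linear(4, 4), "1": nn.Linear(4, 4)})
        self.norm = nn.LayerNorm(4)
        self.output = nn.Linear(4, 8)


model = ToyModel()
# Same shape of bucket plan as above: embeddings alone, [norm, output] together,
# then one bucket per transformer block.
module_list = [model.tok_embeddings, [model.norm, model.output]] + list(model.layers.values())

module_to_name = {m: n for n, m in model.named_modules()}


def convert_modules_to_fqns(modules, module_to_fqn_mapping):
    """Convert a (possibly nested) list of modules to FQN strings."""
    result = []
    for m in modules:
        if isinstance(m, list):
            result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
        else:
            result.append(module_to_fqn_mapping.get(m, None))
    return result


print(convert_modules_to_fqns(module_list, module_to_name))
# ['tok_embeddings', ['norm', 'output'], 'layers.0', 'layers.1']
```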


# Adapted from llama4/infra/parallelize.py
def parallelize_deepseekv3(
model: nn.Module,
@@ -177,13 +202,14 @@ def parallelize_deepseekv3(
f"Invalid reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
backend = get_compile_backend_with_passes(
job_config.compile,
fsdp_reshard_after_forward,
get_transformer_block_buckets(model),
)
model = torch.compile(
model,
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
backend=backend,
fullgraph=True,
)

8 changes: 6 additions & 2 deletions torchtitan/experiments/simple_fsdp/job_config.py
@@ -5,12 +5,16 @@
# LICENSE file in the root directory of this source tree.

from dataclasses import dataclass, field
from typing import Literal


@dataclass
class Compile:
model_backend_override: str | None = None
"""Override backend to compile in simplefsdp. Additional backend includes aot_eager_autobucketing"""
graph_passes: Literal["auto_bucketing", "transformer_block_bucketing", None] = None
"""
Bucketing and overlapping passes in simplefsdp. Additional passes include:
auto_bucketing, transformer_block_bucketing
"""


@dataclass
34 changes: 29 additions & 5 deletions torchtitan/experiments/simple_fsdp/llama3/parallelize.py
@@ -14,7 +14,7 @@
from torchtitan.models.llama3.infra.parallelize import apply_tp
from torchtitan.tools.logging import logger

from ..backend import get_compile_backend
from ..backend import get_compile_backend_with_passes

from ..simple_fsdp import data_parallel, MixedPrecisionPolicy

@@ -33,6 +33,29 @@
}


def get_transformer_block_buckets(model) -> list[list[str] | str]:
module_list = [
model.tok_embeddings,
[model.norm, model.output],
]
for layer_id, transformer_block in model.layers.items():
module_list.append(transformer_block)

def convert_modules_to_fqns(modules, module_to_fqn_mapping):
"""Convert a (possibly nested) list of modules to FQN strings."""
result = []
for m in modules:
if isinstance(m, list):
result.append(convert_modules_to_fqns(m, module_to_fqn_mapping))
else:
result.append(module_to_fqn_mapping.get(m, None))
return result

module_to_name = {m: n for n, m in model.named_modules()}
module_fqns = convert_modules_to_fqns(module_list, module_to_name)
return module_fqns


def parallelize_llama(
model: nn.Module,
parallel_dims: ParallelDims,
@@ -139,13 +162,14 @@ def parallelize_llama(
f"Invalid fsdp_reshard_after_forward_policy: {job_config.parallelism.fsdp_reshard_after_forward}."
)

backend = (
getattr(job_config.compile, "model_backend_override", None)
or job_config.compile.backend
backend = get_compile_backend_with_passes(
job_config.compile,
fsdp_reshard_after_forward,
get_transformer_block_buckets(model),
)
model = torch.compile(
model,
backend=get_compile_backend(backend, fsdp_reshard_after_forward),
backend=backend,
fullgraph=True,
)

20 changes: 17 additions & 3 deletions torchtitan/experiments/simple_fsdp/tests/integration_tests.py
@@ -35,11 +35,25 @@ def build_simple_fsdp_test_list() -> list[OverrideDefinitions]:
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.model_backend_override aot_eager_autobucketing",
"--compile.backend aot_eager",
"--compile.graph_passes auto_bucketing",
],
],
"1D+aot_eager_autobucketing",
"1d_aot_eager_autobucketing",
"1D+autobucketing",
"1d_autobucketing",
),
OverrideDefinitions(
[
[
"--model.name simple_fsdp.llama3",
"--compile.enable",
"--job.custom_config_module=torchtitan.experiments.simple_fsdp.job_config",
"--compile.backend aot_eager",
"--compile.graph_passes transformer_block_bucketing",
],
],
"1D+transformer_block_bucketing",
"1d_transformer_block_bucketing",
),
OverrideDefinitions(
[