[SimpleFSDP] add manual bucketing pass #1881
Changes from all commits
```diff
@@ -8,49 +8,132 @@
 import torch
 import torch._functorch.config as functorch_config
 from torchtitan.tools.logging import logger
 
+from .job_config import Compile as CompileConfig
 from .reshard_after_forward import annotate_fsdp_all_gather
 
 
-def get_compile_backend(
-    backend_name: str, fsdp_reshard_after_forward: bool
+def get_compile_backend_with_passes(
+    compile_config: CompileConfig,
+    fsdp_reshard_after_forward: bool,
+    fsdp_manual_buckets: list[list[str] | str] | None,
 ) -> callable:
-    # return the compile backends used in SimpleFSDP training
-    # Step1: check if backend_name is inside available torch.compile backends
-    # Step2: check if the backend_name has been registered as a customized backend
-    available_torch_backend = torch._dynamo.list_backends(exclude_tags=())
-
-    if backend_name in available_torch_backend:
-        backend = torch._dynamo.lookup_backend(backend_name)
-    elif backend_name == "aot_eager_autobucketing":
-        # Perform auto optimization in aten fx-level and execute code in aot_eager backend
-        # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
-        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+    """
+    Apply compile backend and additional graph passes.
+    Args:
+        compile_config: compile configs to apply torch.compile.
+        fsdp_reshard_after_forward: whether to enable reshard_after_forward in SimpleFSDP,
+            which is implemented via a customized AC graph pass.
+        fsdp_manual_buckets: used in transformer_block_bucketing to define which modules should be bucketed.
+    Returns:
+        compile backend with applied graph passes.
+    """
+    backend = torch._dynamo.lookup_backend(compile_config.backend)
+
+    # Apply bucketing and overlapping pass on fwd and bwd graph separately
+    if compile_config.graph_passes == "auto_bucketing":
+        # Perform auto optimization in aten fx-level and execute code in aot_eager/inductor backend
+        # The autobucketing logic is here: https://github.com/pytorch/pytorch/pull/163960
         from torch._inductor.config import aten_distributed_optimizations as dist_opts
         from torch._inductor.fx_passes.overlap_scheduling import (
             schedule_overlap_bucketing,
         )
 
         dist_opts.collective_bucketing = True
-        dist_opts.insert_overlap_deps = False
         torch._inductor.config.allow_buffer_reuse = False
 
-        def aten_autobucketing_reordering_pass(
-            gm: torch.fx.GraphModule, example_inputs: Any
-        ) -> torch.fx.GraphModule:
-            schedule_overlap_bucketing(gm)
-            gm.recompile()
-            return gm
-
-        backend = aot_autograd_backend(
-            fw_compiler=aten_autobucketing_reordering_pass,
-            bw_compiler=aten_autobucketing_reordering_pass,
-            keep_inference_input_mutations=True,
+        if compile_config.backend == "aot_eager":
+            from torch._dynamo.backends.common import (
+                aot_autograd as aot_autograd_backend,
+            )
+
+            def aot_eager_autobucketing_reordering_pass(
+                gm: torch.fx.GraphModule, example_inputs: Any
+            ) -> torch.fx.GraphModule:
+                schedule_overlap_bucketing(gm)
+                gm.recompile()
+                return gm
+
+            dist_opts.insert_overlap_deps = False
+            backend = aot_autograd_backend(
+                fw_compiler=aot_eager_autobucketing_reordering_pass,
+                bw_compiler=aot_eager_autobucketing_reordering_pass,
+                keep_inference_input_mutations=True,
+            )
+        elif compile_config.backend == "inductor":
+
+            def inductor_autobucketing_reordering_pass(
+                gm: torch.fx.Graph,
+            ) -> torch.fx.GraphModule:
+                return schedule_overlap_bucketing(gm.owning_module)
+
+            dist_opts.insert_overlap_deps = True
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            torch._inductor.config.post_grad_custom_post_pass = (
+                inductor_autobucketing_reordering_pass
+            )
+        else:
+            raise ValueError(
+                f"Unsupported backend {compile_config.backend} for auto_bucketing pass"
+            )
+        logger.info("Auto bucketing pass is applied")
+
+    elif compile_config.graph_passes == "transformer_block_bucketing":
+        # Perform manual optimization in aten fx-level and execute code in aot_eager/inductor backend
+        # The manualbucketing logic is here: https://github.com/pytorch/pytorch/pull/165487
+        from functools import partial
+
+        from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            manual_overlap_bucketing,
+        )
+
+        torch._inductor.config.allow_buffer_reuse = False
```
Member: What happens by default?

Contributor (Author): In bucketing we shouldn't allow buffer reuse; otherwise the newly created comm copy-in/copy-out buffers will reuse a previous buffer, which corrupts the copied-out data and makes the loss NaN.

Contributor: Aren't we doing the passes on the fx graph / in the aot_eager backend? Why does it have anything to do with Inductor? In fact, I have this confusion for all the other Inductor config settings here.

Contributor (Author): The passes live in `torch._inductor.fx_passes`.
```diff
+        manual_overlap_bucketing = partial(
+            manual_overlap_bucketing,
+            module_bucket_plans=fsdp_manual_buckets,
+        )
+
+        if compile_config.backend == "aot_eager":
+
+            def aot_eager_transformer_block_bucketing_reordering_pass(
+                gm: torch.fx.GraphModule, example_inputs: Any
+            ) -> torch.fx.GraphModule:
+                manual_overlap_bucketing(gm, insert_overlap_deps=False)
+                return gm
+
+            backend = aot_autograd_backend(
+                fw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
+                bw_compiler=aot_eager_transformer_block_bucketing_reordering_pass,
+                keep_inference_input_mutations=True,
+            )
+        elif compile_config.backend == "inductor":
+
+            def inductor_transformer_block_bucketing_reordering_pass(
+                gm: torch.fx.Graph,
+            ) -> torch.fx.GraphModule:
+                return manual_overlap_bucketing(
+                    gm.owning_module, insert_overlap_deps=True
+                )
+
+            torch._inductor.config.reorder_for_peak_memory = False
+            torch._inductor.config.reorder_for_compute_comm_overlap = False
+            torch._inductor.config.post_grad_custom_post_pass = (
+                inductor_transformer_block_bucketing_reordering_pass
+            )
+        else:
+            raise ValueError(
+                f"Unsupported backend {compile_config.backend} for transformer_block_bucketing pass"
+            )
+        logger.info("Transformer block bucketing pass is applied")
+
     else:
-        raise AssertionError(f"Unsupported customized backend: {backend_name}")
+        logger.info("No bucketing and overlapping pass is applied")
 
     # Apply activation checkpointing on joint graph before partitioner
     def joint_ac_pass(
         gm: torch.fx.GraphModule, example_inputs: Any
     ) -> torch.fx.GraphModule:
```
Is the reason you have this if-else on the different backends purely an API limitation? Since we are always applying fx graph passes, I somehow thought there'd be a way to unify the passes UX and just use different backends.

It's because aot_eager and inductor handle the pass differently. I'm not really sure there is a way to unify them, but that would be something very nice to have. Basically, aot_eager registers the pass as a customized compiler backend on top of aot_eager and runs it with fw_compiler and bw_compiler, while inductor hooks the pass into the post-grad pass and manipulates the fx-level traced graph before lowering it to inductor IRs.
cc @ezyang
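A condensed sketch of the two registration paths described in the reply above, using only hooks that already appear in this diff (aot_autograd's fw_compiler/bw_compiler wrappers vs. Inductor's post_grad_custom_post_pass). The names my_bucketing_pass, fw_bw_compiler, and post_grad_pass are illustrative stand-ins, not code from the PR:

```python
from typing import Any

import torch
import torch._inductor.config  # ensure the Inductor config module is loaded
from torch._dynamo.backends.common import aot_autograd as aot_autograd_backend


# Stand-in for schedule_overlap_bucketing / manual_overlap_bucketing.
def my_bucketing_pass(gm: torch.fx.GraphModule) -> torch.fx.GraphModule:
    return gm


# Path 1 (aot_eager): wrap the pass as the fw/bw compilers of a custom backend,
# so it runs once on the forward graph and once on the backward graph.
def fw_bw_compiler(gm: torch.fx.GraphModule, example_inputs: Any) -> torch.fx.GraphModule:
    my_bucketing_pass(gm)
    gm.recompile()
    return gm


aot_eager_backend_with_pass = aot_autograd_backend(
    fw_compiler=fw_bw_compiler,
    bw_compiler=fw_bw_compiler,
    keep_inference_input_mutations=True,
)


# Path 2 (inductor): hook the pass into Inductor's post-grad custom pass,
# which runs on the traced fx graph before lowering to Inductor IR.
def post_grad_pass(graph: torch.fx.Graph) -> None:
    my_bucketing_pass(graph.owning_module)


torch._inductor.config.post_grad_custom_post_pass = post_grad_pass
```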
My mental model was:
I feel aot_eager and inductor share this up to step 5, and the bucketing passes happen at step 5 (AC passes at step 3?), so theoretically they can be combined?

That was what Angela's PR (#1785) was doing: manipulate the post-grad graph and run the code eagerly. However, there was some performance regression (detail) whose cause we didn't figure out, and that's why we are using a customized backend for aot_eager instead 😢
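For orientation, a hedged usage sketch of the new entry point. Only the function signature and the config fields (backend, graph_passes) come from the diff; the import paths, the CompileConfig construction, the module names in the bucket plan, and the toy model are illustrative assumptions:

```python
import torch
import torch.nn as nn

# Assumed import paths, for illustration only; the real ones are the SimpleFSDP
# experiment package that this PR modifies.
from torchtitan.experiments.simple_fsdp.backend import get_compile_backend_with_passes
from torchtitan.experiments.simple_fsdp.job_config import Compile as CompileConfig

# Hypothetical manual bucket plan: each entry is either a single module FQN or a
# list of FQNs to bucket together (module names here are made up).
fsdp_manual_buckets: list[list[str] | str] = [
    "tok_embeddings",
    ["layers.0", "layers.1"],
    "norm",
    "output",
]

# Field names come from the diff (compile_config.backend / compile_config.graph_passes);
# constructing CompileConfig with keyword arguments is an assumption about the config class.
compile_config = CompileConfig(backend="inductor", graph_passes="transformer_block_bucketing")

backend = get_compile_backend_with_passes(
    compile_config=compile_config,
    fsdp_reshard_after_forward=True,
    fsdp_manual_buckets=fsdp_manual_buckets,
)

# Toy stand-in model; the real flow passes the SimpleFSDP-parallelized model.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
compiled_model = torch.compile(model, backend=backend, fullgraph=True)
```

The intent, per the docstring in the diff, is that each plan entry describes one bucket of modules (typically a transformer block) whose FSDP collectives the pass groups and overlaps together.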