@ruisizhang123
Contributor

@ruisizhang123 ruisizhang123 commented Oct 14, 2025

As titled, this PR adds a manual bucketing pass to SimpleFSDP. Users specify the FQNs they want to bucket together via module_bucket_plans. Then, _manual_bucket_collectives gets the nodes of the subgraph corresponding to each bucket module and buckets the bucketable (FSDP-style) AG/RS together. _manual_reorder_graph reorders them for overlapping.
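
For illustration, a bucket plan might look like the list below. This is a minimal sketch in plain Python, not the PR's code: the FQNs and the helper `assign_to_buckets` are hypothetical, and only the name `module_bucket_plans` and its shape (`list[list[str] | str]`, taken from the constructor signature quoted later in this thread) come from the PR.

```python
# Hypothetical FQNs. Each entry is one bucket: a single FQN or a list of FQNs.
module_bucket_plans = [
    ["tok_embeddings", "layers.0"],
    "layers.1",
    ["norm", "output"],
]


def assign_to_buckets(param_fqns, plans):
    """Map each parameter FQN to the index of the bucket whose module owns it."""
    assignment = {}
    for idx, entry in enumerate(plans):
        modules = [entry] if isinstance(entry, str) else entry
        for fqn in param_fqns:
            # Match on whole path components so "layers.1" does not claim "layers.10".
            if any(fqn == m or fqn.startswith(m + ".") for m in modules):
                assignment[fqn] = idx
    return assignment


params = [
    "tok_embeddings.weight",
    "layers.0.attention.wq.weight",
    "layers.1.attention.wq.weight",
    "layers.10.feed_forward.w1.weight",
    "output.weight",
]
buckets = assign_to_buckets(params, module_bucket_plans)
```

Parameters whose module is not named in any plan entry (here `layers.10...`) are simply left out of `buckets` in this sketch; how the actual pass treats unlisted modules is not stated in the description.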

For detailed performance, see this torchtitan PR: pytorch/torchtitan#1881.

There are a few todo items listed in the torchtitan PR. Let's start with this PR, which implements FSDP+TP+llama3 manual bucketing. I will fix/add the rest in follow-up PRs.

cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben

@pytorch-bot

pytorch-bot bot commented Oct 14, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165487

Note: Links to docs will display an error until the docs builds have been completed.

❌ 3 New Failures

As of commit 7bf56fc with merge base 87d17e9:

NEW FAILURES - The following jobs have failed:

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@ruisizhang123 ruisizhang123 marked this pull request as draft October 14, 2025 22:56
@ruisizhang123 ruisizhang123 added the topic: not user facing topic category label Oct 14, 2025
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 4 times, most recently from dfb1e41 to d94f913 Compare October 17, 2025 16:03
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 4 times, most recently from 083bd96 to 028bf2b Compare October 28, 2025 06:57
@ruisizhang123 ruisizhang123 changed the title [WIP] add manual bucketing pass [simplefsdp] add manual bucketing pass Oct 28, 2025
@ruisizhang123 ruisizhang123 marked this pull request as ready for review October 28, 2025 07:05
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 028bf2b to a4d1ed8 Compare October 28, 2025 15:42
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from a4d1ed8 to f31507d Compare October 28, 2025 17:45
@ezyang
Contributor

ezyang commented Oct 29, 2025

Can't land this without some tests. Would it be better for this to live in torchtitan? I can't tell... but since simple fsdp is in torchtitan and not in core, maybe it should live there!


def make_graph_view(graph: fx.Graph) -> Container:
    """
    Code from: https://github.com/meta-pytorch/autoparallel/pull/158
Contributor

If we're putting this in core, let's do it for real. Dedicated module, docs, tests. cc @fmassa @eellison. I haven't used this API yet so I don't have a real world opinion on it.

Contributor

Also have Claude generate some simple unit tests for it

The manual overlapping consists of two steps:
Step 1: bucket all-gather/reduce-scatter in each module in module_bucket_plans
Step 2: reorder all-gather to overlap with last module_bucket &
reorder reduce-scatter to overlap with next module_bucket
Contributor

The reordering here is just the SimpleFSDP strategy that you used to have in Inductor IR, right?

Contributor Author

yes!
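
The two steps in the docstring quoted above can be sketched as a toy schedule. This is an interpretation in plain Python, assuming one bucketed all-gather and one bucketed reduce-scatter per module bucket; the op-name strings and both functions are hypothetical, and the thread does not show how `_manual_reorder_graph` actually places nodes.

```python
def forward_schedule(buckets):
    """Start the all-gather of bucket i+1 before computing bucket i (prefetch)."""
    if not buckets:
        return []
    order = [f"ag_start({buckets[0]})"]
    for i, name in enumerate(buckets):
        order.append(f"ag_wait({name})")
        if i + 1 < len(buckets):
            # Issued before the compute below, so the two can overlap.
            order.append(f"ag_start({buckets[i + 1]})")
        order.append(f"compute({name})")
    return order


def backward_schedule(buckets):
    """Wait on the reduce-scatter of a bucket only after the next bucket's compute."""
    order = []
    pending = None
    for name in reversed(buckets):
        order.append(f"compute_grad({name})")
        if pending is not None:
            # The previous bucket's reduce-scatter overlapped with this compute.
            order.append(f"rs_wait({pending})")
        order.append(f"rs_start({name})")
        pending = name
    if pending is not None:
        order.append(f"rs_wait({pending})")
    return order


fwd = forward_schedule(["layers.0", "layers.1"])
bwd = backward_schedule(["layers.0", "layers.1"])
```

In both directions the first or last collective has nothing to overlap with, which is why its start and wait end up adjacent.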

@ruisizhang123
Contributor Author

> Can't land this without some tests. Would it be better for this to live in torchtitan? I can't tell... but since simple fsdp is in torchtitan and not in core, maybe it should live there!

This is the start of a manual optimization pass that is independent of simplefsdp, which I think should live in core. Currently, it uses a similar reordering strategy, but I have a list of todo items to make it general here: pytorch/torchtitan#1881. I will add them in follow-up PRs.

I will add tests to it to make sure it can live in core :)

@ezyang
Contributor

ezyang commented Oct 30, 2025

Well, @eellison has the really generic one, so I am not sure how much I want to push on this outside of simple fsdp generated all gathers ;)

@ruisizhang123
Contributor Author

> Well, @eellison has the really generic one, so I am not sure how much I want to push on this outside of simple fsdp generated all gathers ;)

hmmm a better way to put it would be: you cannot guarantee this automated overlapping would squeeze out the best perf out of the box. In some cases, users may want to control how things are bucketed & reordered, either for better overlapping (they found some parts are not overlapped and want to overlap them manually) or for bit-wise loss equivalence (the fsdp2 bucketing thing we are comparing right now).

In both cases, having a controllable overlapping pass would be helpful.

@ezyang
Contributor

ezyang commented Oct 30, 2025

Bucketing, absolutely. Overlapping, I think you potentially may need something even more manual than this.

@ruisizhang123
Contributor Author

> Bucketing, absolutely. Overlapping, I think you potentially may need something even more manual than this.

yes, we need more explicit overlapping API for sure. On the other hand, fsdp2's fully_shard API also does similar things, where users specify which module they want to bucket, and the overlapping happens under the hood.

I believe this fx-level overlapping would give us more freedom to move things. This PR is just a starting point.

)
from torch._inductor.fx_passes.overlap_preserving_bucketer import (
    bucket_key,
    OverlapPreservingBucketer,
Contributor

@eellison to comment on inheritance here; note that we have some internal code also inheriting from this class in a similar way too.

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from f31507d to 3eaadd0 Compare November 4, 2025 07:24
@pytorch-bot pytorch-bot bot added the oncall: distributed Add this issue/PR to distributed oncall triage queue label Nov 4, 2025
@ruisizhang123
Contributor Author

ruisizhang123 commented Nov 4, 2025

ehhhh just realized it's hard to add end-to-end manual bucketing & overlapping test in pytorch since simplefsdp is in torchtitan 😅

Another reason to consider moving simplefsdp to pytorch when more features like this are coming in lolll. I guess there is no way other than simplefsdp to [define a model & manual bucketing FQNs -> FSDP sharding -> trace full graph -> do bucketing & overlapping]

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 2 times, most recently from 47b1982 to c472083 Compare November 5, 2025 03:59
return name


def _find_key_nodes(nodes: list[fx.Node]) -> tuple[list[fx.Node], list[fx.Node]]:
Contributor

i'd appreciate a docstring at least, or a renamed function, that makes it more intuitive what this does. (what is a "key" node?)

return root, outputs


def _make_subgraph(nodes: list[fx.Node]) -> fx.Graph:
Contributor

similar comment as above: intuitively a function named "make_subgraph" i'd expect some input arg that helps identify what part of the graph to include in the subgraph. In this case it sounds like it makes the 'subgraph containing key nodes', which, is still opaque to me.

Get subgraph by path(s).
Args:
    graph_view (object): Root graph view object.
    paths (str or list of str): Path(s) to subgraph.
Contributor

any requirements here? overlapping paths ok? disjoint paths ok?

Contributor Author

@ruisizhang123 ruisizhang123 Nov 6, 2025

there is no requirement. this function collects all of the nodes that belong to the given path(s). This covers the case where we want to put nodes from multiple paths in one bucket. The check that buckets are disjoint should happen here, to ensure no two buckets overlap. Will add an assertion in the _obtain_nodes_in_subgraph func.
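
A minimal sketch of the two behaviours described in this reply: collecting nodes from one or more paths into a single bucket, and asserting that separate buckets do not share nodes. It uses plain strings in place of fx nodes; `collect_nodes`, `assert_buckets_disjoint`, and the sample FQNs are hypothetical and are not the PR's `_obtain_nodes_in_subgraph`.

```python
def collect_nodes(nodes_by_fqn, paths):
    """Gather every node whose module FQN falls under any of the given paths."""
    if isinstance(paths, str):
        paths = [paths]
    collected = []
    for fqn, nodes in nodes_by_fqn.items():
        if any(fqn == p or fqn.startswith(p + ".") for p in paths):
            for node in nodes:
                if node not in collected:
                    collected.append(node)
    return collected


def assert_buckets_disjoint(buckets):
    """Raise AssertionError if any node appears in more than one bucket."""
    owner = {}
    for idx, nodes in enumerate(buckets):
        for node in nodes:
            if node in owner:
                raise AssertionError(
                    f"node {node!r} is in buckets {owner[node]} and {idx}"
                )
            owner[node] = idx


nodes_by_fqn = {
    "layers.0.attention": ["ag0"],
    "layers.0.feed_forward": ["ag1"],
    "layers.1.attention": ["ag2"],
}
bucket_a = collect_nodes(nodes_by_fqn, "layers.0")
bucket_b = collect_nodes(nodes_by_fqn, ["layers.1"])
assert_buckets_disjoint([bucket_a, bucket_b])  # passes: no shared nodes
```

Passing overlapping paths in separate buckets, for example `"layers"` and `"layers.0"`, makes the assertion fire in this sketch.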

return stack == ""


def make_graph_view(graph: fx.Graph) -> Container:
Contributor

it's a little odd to me that this function returns a 'Container' type, and then on the Container you can further call .graph_view() to get an fx graphmodule type. I might not be understanding the design yet, but i wonder if it would be cleaner to have

  • Container class renamed to GraphView class
  • functions like get_subgraph_by_path moved to class methods of GraphView
  • make_graph_view moves to GraphView.__init__
  • GraphView.graph_view() still doesn't make sense to me

Contributor Author

@ruisizhang123 ruisizhang123 Nov 6, 2025

We are using the Container class to keep a hierarchical mapping between nodes & their module names. I agree it would be more intuitive to rename Container to GraphView, will do this. We can also rename GraphView.graph_view() to something like GraphView.obtain_subgraph()?

Contributor Author

I realized GraphView.graph_view() is not actually used in our bucketing & reordering pass to get subgraphs. Instead, I decided to use the get_subgraph_by_path func, which is friendlier to FQNs passed in as strings from torchtitan. Thus, I removed the legacy functions, including _make_subgraph, _find_key_nodes and graph_view, in the new code.

Contributor

ok removing those functions will address a lot of my questions.
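
To make the design discussed in this thread concrete, here is a toy version of a hierarchical view keyed by module name, with lookup by FQN string. It is a sketch under assumptions, not the PR's implementation: the class body, the `add` and `all_nodes` methods, and the string "nodes" are hypothetical; only the names `GraphView` and `get_subgraph_by_path` appear in the discussion.

```python
class GraphView:
    """Tree of modules; each level keeps the nodes that belong directly to it."""

    def __init__(self, name=""):
        self.name = name
        self.children = {}  # child module name -> GraphView
        self.data = []      # nodes owned directly by this module

    def add(self, fqn, node):
        view = self
        for part in filter(None, fqn.split(".")):
            view = view.children.setdefault(part, GraphView(part))
        view.data.append(node)

    def all_nodes(self):
        nodes = list(self.data)
        for child in self.children.values():
            nodes.extend(child.all_nodes())
        return nodes


def get_subgraph_by_path(root, paths):
    """Collect every node under the given FQN path(s)."""
    if isinstance(paths, str):
        paths = [paths]
    collected = []
    for path in paths:
        view = root
        for part in path.split("."):
            view = view.children[part]  # KeyError if the path does not exist
        collected.extend(view.all_nodes())
    return collected


root = GraphView()
root.add("layers.0.attention", "all_gather_wq")
root.add("layers.0.feed_forward", "all_gather_w1")
root.add("layers.1.attention", "all_gather_wq_1")
layer0 = get_subgraph_by_path(root, "layers.0")
mixed = get_subgraph_by_path(root, ["layers.0.attention", "layers.1"])
```

Accepting either a string or a list of strings mirrors the `paths (str or list of str)` argument in the docstring quoted earlier.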

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch 2 times, most recently from 193315c to ad08041 Compare November 7, 2025 02:05
return stack == ""


def make_graph_view(graph: fx.Graph) -> GraphView:
Contributor

Would it make sense to put this in a different file?


return seen_target_op == 1

def _bucket_group(self, coll_nodes: list[fx.Node]) -> None:
Contributor

reason you had to define this, just curious ?

Contributor

It is a bit different than _apply_bucket FWIW

Contributor Author

I wanted to add more info to the bucketed nodes to help reordering: (1) tag the newly bucketed AG/RS's metadata, which helps identify FSDP-related comms during reordering; (2) keep track of each newly bucketed wait and its mapping to the bucketed AG/RS as self.wait_to_node_map. This helps when adding dependencies during reordering.
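
The bookkeeping described in this reply could look roughly like the sketch below. It uses a stand-in node class rather than `fx.Node`; `FakeNode`, `BucketBookkeeping`, `record`, and the metadata key `"bucketed_kind"` are all hypothetical, and only the attribute name `wait_to_node_map` comes from the comment.

```python
from dataclasses import dataclass, field


@dataclass(eq=False)  # identity-based hashing, so instances can be dict keys
class FakeNode:
    name: str
    meta: dict = field(default_factory=dict)


class BucketBookkeeping:
    def __init__(self):
        # Maps each bucketed wait node to the bucketed collective it waits on.
        self.wait_to_node_map = {}

    def record(self, coll, wait, kind):
        # (1) Tag the bucketed collective so a later pass can recognise it.
        coll.meta["bucketed_kind"] = kind  # hypothetical metadata key
        # (2) Remember which collective this wait belongs to.
        self.wait_to_node_map[wait] = coll


ag = FakeNode("bucketed_all_gather")
ag_wait = FakeNode("bucketed_all_gather_wait")
book = BucketBookkeeping()
book.record(ag, ag_wait, kind="all_gather")
```

A reordering step can then go from a wait node to its collective with a dictionary lookup instead of walking the graph.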

return node.data


class ManualOverlapPreservingBucketer(OverlapPreservingBucketer):
Contributor

What parts of the existing class are you using, just for my understanding?

Contributor Author

for OverlapPreservingBucketer, I'm only using __init__ func to get self.collective_info and self.node_ancestors. The manual bucketing logic is much simpler -- you basically put things in a bucket and call bucket helper functions.

Scheduler that manual buckets and reorders collective nodes based on module_bucket_plans
"""

def __init__(self, gm: fx.GraphModule, module_bucket_plans: list[list[str] | str]):
Contributor

Same question here ? I guess this is heapq, mostly? anything else?

Contributor Author

yes, mostly helper variables (e.g., collective_info) and heapq to help reordering & bucketing.



def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]:
    if node.meta.get("nn_module_stack", "") == "":
Contributor

err why not just test if it's None?

Contributor Author

@ruisizhang123 ruisizhang123 Nov 11, 2025

updated

        .replace("_modules['", "")
        .replace("['", ".")
        .replace("']", "")
    )
Contributor

This is kind of terrible; also not sure if there's a preexisting utility for this

Contributor Author

I can rewrite it with regex function that extracts MODULE inside ['MODULE']. yeah, this one is a bit fragile.
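
A regex along the lines suggested here might look like the following. This is only a sketch: the function name and the sample input string are hypothetical, the exact format of the keys being cleaned is assumed rather than taken from the PR, and the output is not claimed to match the chained `.replace` calls character for character.

```python
import re

# Captures MODULE from every ['MODULE'] occurrence in the string.
_BRACKET_KEY = re.compile(r"\['([^']+)'\]")


def module_path_from_stack_key(key: str) -> str:
    """Join all bracketed names in `key` into a dotted path."""
    return ".".join(_BRACKET_KEY.findall(key))


# Hypothetical input shaped like an attribute-access string.
path = module_path_from_stack_key("L['self']._modules['layers']._modules['0']")
```

Unlike the chained replacements, this ignores whatever sits between the brackets, so stray prefixes such as `_modules` do not need to be stripped one by one.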

Contributor

@ezyang ezyang left a comment

this has been stuck in review hell for a while, stamping to unblock, it's really low risk to put in

@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from ad08041 to 8900629 Compare November 11, 2025 06:49
@ruisizhang123 ruisizhang123 force-pushed the ruisi/manual_bucket_pass branch from 8900629 to 7bf56fc Compare November 11, 2025 06:51

Labels

ciflow/inductor module: inductor oncall: distributed Add this issue/PR to distributed oncall triage queue topic: not user facing topic category

5 participants