-
Notifications
You must be signed in to change notification settings - Fork 25.9k
[simplefsdp] add manual bucketing pass #165487
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Conversation
🔗 Helpful Links🧪 See artifacts and rendered test results at hud.pytorch.org/pr/165487
Note: Links to docs will display an error until the docs builds have been completed. ❌ 3 New FailuresAs of commit 7bf56fc with merge base 87d17e9 ( NEW FAILURES - The following jobs have failed:
This comment was automatically generated by Dr. CI and updates every 15 minutes. |
dfb1e41 to
d94f913
Compare
083bd96 to
028bf2b
Compare
028bf2b to
a4d1ed8
Compare
a4d1ed8 to
f31507d
Compare
|
Can't land this without some tests. Would it be better for this to live in torchtitan? I can't tell... but since simple fsdp is in torchtitan and not in core, maybe it should live there! |
|
|
||
| def make_graph_view(graph: fx.Graph) -> Container: | ||
| """ | ||
| Code from: https://github.com/meta-pytorch/autoparallel/pull/158 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also have Claude generate some simple unit tests for it
| The manual overlapping consists of two steps: | ||
| Step 1: bucket all-gather/reduce-scatter in each module in module_bucket_plans | ||
| Step 2: reorder all-gather to overlap with last module_bucket & | ||
| reorder reduce-scatter to overlap with next module_bucket |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The reordering here is just the SimpleFSDP strategy that you used to have in Inductor IR, right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes!
This is the starting of manual optimization pass that is independent of simplefsdp, which I think should live in core. Currently, it uses similar reordering strategy, but I have a list of todo items to make it general here: pytorch/torchtitan#1881. I will add them in followup PRs. I will add tests to it to make sure it can live in core :) |
|
Well, @eellison has the really generic one, so I am not sure how much I want to push on this outside of simple fsdp generated all gathers ;) |
hmmm a better way to put it would be: you cannot guarantee this automated overlapping would squeeze the best perf out of box. In some cases, users may want to control how things are bucketed & reordered either for better overlapping (they found some parts are not overlapped and want to manually do the overlapping) or for bit-wise loss equivalence (the fsdp2 bucketing thing we are comparing rn). In both cases, having a controllable overlapping pass would be helpful. |
|
Bucketing, absolutely. Overlapping, I think you potentially may need something even more manual than this. |
yes, we need more explicit overlapping API for sure. On the other hand, fsdp2's I believe this fx-level overlapping would give us more freedom to move things. This PR is just a starting point. |
| ) | ||
| from torch._inductor.fx_passes.overlap_preserving_bucketer import ( | ||
| bucket_key, | ||
| OverlapPreservingBucketer, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@eellison to comment on inheritance here; note that we have some internal code also inheriting from this class in a similar way too.
f31507d to
3eaadd0
Compare
|
ehhhh just realized it's hard to add end-to-end manual bucketing & overlapping test in pytorch since simplefsdp is in torchtitan 😅 Another reason to consider moving simplefsdp to pytorch when more features like this are coming in lolll. I guess there is not another way other than simplefsdp that [define a model & manual bucketing FQNs -> FSDP sharding -> trace full graph -> do bucketing & overlapping] |
47b1982 to
c472083
Compare
| return name | ||
|
|
||
|
|
||
| def _find_key_nodes(nodes: list[fx.Node]) -> tuple[list[fx.Node], list[fx.Node]]: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
i'd appreciate a docstring at least, or a renamed function, that makes it more intuitive what this does. (what is a "key" node?)
| return root, outputs | ||
|
|
||
|
|
||
| def _make_subgraph(nodes: list[fx.Node]) -> fx.Graph: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
similar comment as above: intuitively a function named "make_subgraph" i'd expect some input arg that helps identify what part of the graph to include in the subgraph. In this case it sounds like it makes the 'subgraph containing key nodes', which, is still opaque to me.
| Get subgraph by path(s). | ||
| Args: | ||
| graph_view (object): Root graph view object. | ||
| paths (str or list of str): Path(s) to subgraph. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
any requirements here? overlapping paths ok? disjoint paths ok?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
there is no requirement. this function will collect all of the nodes as long as they belong to the path(s). This is for the case that we may want to put nodes from multiple paths in one bucket. The disjoint graph assertion bucket should happen here to ensure two buckets have no overlap. Will add an assertion in _obtain_nodes_in_subgraph func.
| return stack == "" | ||
|
|
||
|
|
||
| def make_graph_view(graph: fx.Graph) -> Container: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
its a little odd to me that this function returns a 'Container' type, and then on the Container you can further call .graph_view() to get an fx graphmodule type. I might not be understanding the design yet, but i wonder if it would be cleaner to have
- Container class renamed to GraphView class
- functions like
get_subgraph_by_pathmoved to class methods of GraphView - make_graph_view moves to GraphView.init
- GraphView.graph_view() still doesn't make sense to me
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We are using Container class to have a hierarchical mapping between nodes & their module names. I agree it would be more intuitive to rename Container to Graph_view, will do this. We can also rename GraphView.graph_view() to sth like GraphView.obtain_subgraph()?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I realized GraphView.graph_view() is not actually used in our bucketing & reordering pass to get subgraphs. Instead, I decided to use get_subgraph_by_path func, which is more friendly to input FQNs as strings from torchtitan. Thus, I removed the legacy functions including_make_subgraph, _find_key_nodes and graph_view in the new code.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
ok removing those functions will address a lot of my questions.
193315c to
ad08041
Compare
| return stack == "" | ||
|
|
||
|
|
||
| def make_graph_view(graph: fx.Graph) -> GraphView: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Would it make sense to put this in a differnet file?
|
|
||
| return seen_target_op == 1 | ||
|
|
||
| def _bucket_group(self, coll_nodes: list[fx.Node]) -> None: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
reason you had to define this, just curious ?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It is a bit different than _apply_bucket FWIW
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I wanted to add more info to bucketed nodes to help reordering: (1) tag newly bucketed AG/RS's metadata, which helps to identify FSDP related comms in reordering; (2) keep track of newly bucketed wait and its mapping to bucketed AG/RS as self.wait_to_node_map. It would help adding dependencies in reordering.
| return node.data | ||
|
|
||
|
|
||
| class ManualOverlapPreservingBucketer(OverlapPreservingBucketer): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What parts of the existing class are you using, just for my understanding?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
for OverlapPreservingBucketer, I'm only using __init__ func to get self.collective_info and self.node_ancestors. The manual bucketing logic is much simpler -- you basically put things in a bucket and call bucket helper functions.
| Scheduler that manual buckets and reorders collective nodes based on module_bucket_plans | ||
| """ | ||
|
|
||
| def __init__(self, gm: fx.GraphModule, module_bucket_plans: list[list[str] | str]): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Same question here ? I guess this is heapq, mostly? anything else?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, mostly helper variables (e.g., collective_info) and heapq to help reordering & bucketing.
|
|
||
|
|
||
| def _get_module_stack(node: fx.Node) -> list[tuple[str, type[Any]]]: | ||
| if node.meta.get("nn_module_stack", "") == "": |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
err why not just test if it's None?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
updated
| .replace("_modules['", "") | ||
| .replace("['", ".") | ||
| .replace("']", "") | ||
| ) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is kind of terrible; also not sure if there's a preexisting utility for this
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I can rewrite it with regex function that extracts MODULE inside ['MODULE']. yeah, this one is a bit fragile.
ezyang
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
this has been stuck in review hell for a while, stamping to unblock, it's really low risk to put in
ad08041 to
8900629
Compare
8900629 to
7bf56fc
Compare
As titled, this PR adds manual bucketing pass to SimpleFSDP. Users will need to parse FQNs they wanted to bucket together using
module_bucket_plans. Then,_manual_bucket_collectiveswill get the node of the subgraphs correspond to eachbucket_module, and bucket bucketable (FSDP-style) AG/RS together._manual_reorder_graphreorders them for overlapping.For detailed performance, see this torchtitan PR: pytorch/torchtitan#1881.
There are a few todo items isted in torchtitan PR. Let's start with this PR that implements FSDP+TP+llama3 manual bucketing. I will fix/add the rest in follow up PRs.
cc @H-Huang @awgu @wanchaol @fegin @fduwjj @wz337 @wconstab @d4l3k @pragupta @msaroufim @dcci @voznesenskym @penguinwu @EikanWang @jgong5 @Guobing-Chen @XiaobingSuper @zhuhaozhe @blzheng @wenzhe-nrv @jiayisunx @ipiszy @chenyang78 @kadeng @muchulee8 @amjames @chauhang @aakhundov @coconutruben