Commit ad08041

add manual bucketing
1 parent 544b443 commit ad08041

3 files changed: +695 −3 lines changed


test/distributed/test_aten_comm_compute_reordering.py

Lines changed: 163 additions & 0 deletions
@@ -10,6 +10,7 @@
 
 # for some reason importing functional collectives after dynamo breaks collectives handling!
 import torch.distributed._functional_collectives as _functional_collectives
+import torch.fx as fx
 from torch._C import FileCheck
 from torch._dynamo.utils import counters, same
 from torch._inductor.utils import run_and_get_code, run_and_get_triton_code
@@ -888,6 +889,168 @@ def func(a, b, c, d, *, ranks):
         self.assertTrue(same(test_out, correct))
 
 
+def get_toy_model(device_type: str):
+    """
+    Helper to construct a small multi-layer ToyModel
+    """
+
+    class ToyBlock(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.wq = torch.nn.Linear(4, 4)
+            self.wk = torch.nn.Linear(4, 4)
+            self.proj = torch.nn.Linear(4, 4)
+
+        def forward(self, x):
+            attn = self.wq(x) + self.wk(x)
+            return self.proj(torch.nn.functional.relu(attn))
+
+    class ToyModel(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+            self.layers = torch.nn.ModuleList([ToyBlock() for _ in range(2)])
+            self.norm = torch.nn.LayerNorm(4)
+
+        def forward(self, x):
+            for blk in self.layers:
+                x = blk(x)
+            return self.norm(x)
+
+    model = ToyModel().to(device_type)
+    return model
+
+
+def apply_manual_reordering_and_get_graph(graph, module_bucket_plans, out_li) -> None:
+    gm = graph.owning_module
+    from torch._inductor.fx_passes.overlap_manual_scheduling import (
+        ManualOverlapScheduler,
+    )
+
+    # Tag the collective and wait nodes with nn_module_stack metadata so the
+    # scheduler can assign them to the "module_1" / "module_2" buckets
+    for node in list(gm.graph.nodes):
+        if (
+            node.name == "all_gather_into_tensor"
+            or node.name == "all_gather_into_tensor_1"
+            or node.name == "wait_tensor"
+            or node.name == "wait_tensor_1"
+        ):
+            node.meta["nn_module_stack"] = {"test": ["module_1", ""]}
+        if (
+            node.name == "all_gather_into_tensor_2"
+            or node.name == "all_gather_into_tensor_3"
+            or node.name == "wait_tensor_2"
+            or node.name == "wait_tensor_3"
+        ):
+            node.meta["nn_module_stack"] = {"test": ["module_2", ""]}
+
+    overlapped_gm = ManualOverlapScheduler(gm, module_bucket_plans).run()
+    overlapped_gm.graph.lint()
+    out_li.append(overlapped_gm.graph)
+
+
+def run_and_get_manual_aten_graph(fn, *inputs):
+    li = []
+    apply = functools.partial(
+        apply_manual_reordering_and_get_graph,
+        module_bucket_plans=["module_1", "module_2"],
+        out_li=li,
+    )
+    with torch._inductor.config.patch(post_grad_custom_post_pass=apply):
+        out = fn(*inputs)
+
+    return out, li[0]
+
+
+class TestManualOverlapBucketing(TestComputeCommReorderingMultiProc):
+    """
+    Tests for manual overlap scheduling and subgraph utilities.
+    """
+
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    def test_make_graph_view_and_get_subgraph_by_path(self):
+        from torch._inductor.fx_passes.overlap_manual_scheduling import (
+            get_subgraph_by_path,
+            make_graph_view,
+        )
+
+        model = get_toy_model(device_type)
+        gm = fx.symbolic_trace(model)
+        graph_view = make_graph_view(gm.graph)
+        # Fetch subgraph for first transformer layer
+        sub_nodes = get_subgraph_by_path(graph_view, "layers.0.wq")
+        self.assertEqual([n.name for n in sub_nodes], ["layers_0_wq"])
+
+        # Fetch multiple paths at once
+        multi_nodes = get_subgraph_by_path(graph_view, ["layers.0.wq", "layers.0.proj"])
+        self.assertEqual(
+            [n.name for n in multi_nodes], ["layers_0_wq", "layers_0_proj"]
+        )
+
+    @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch")
+    def test_manual_reordering_bucketing_pass(
+        self,
+    ):
+        def func(a, b, c, d, *, ranks):
+            # All 4 all-gathers are independent - COULD be bucketed together
+            ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks)
+            ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks)
+            ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks)
+            ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks)
+
+            # First compute - can hide ag1 and ag2
+            e = a * 5  # Use a to avoid fusion
+            mm1 = torch.matmul(e, e.T)
+
+            # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred)
+            # Use first 8x8 elements to match mm1's shape
+            intermediate = ag1[:8, :8] + ag2[:8, :8]
+
+            # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4
+            mm2 = torch.matmul(mm1 + intermediate, c[:8])
+
+            # Use all results
+            result = (
+                ag1.sum() * 1.1
+                + ag2.sum() * 1.2
+                + ag3.sum() * 1.3
+                + ag4.sum() * 1.4
+                + mm1.sum()
+                + mm2.sum()
+            )
+            return result
+
+        with _dynamo_dist_per_rank_init(
+            self.rank,
+            self.world_size,
+            self.backend(device_type),
+            fake_pg=not at_least_x_gpu(2),
+        ):
+            a = torch.ones(8, 8, dtype=torch.float, device=device_type)
+            b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2
+            c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3
+            d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4
+            ranks = list(range(self.world_size))
+
+            func_c = functools.partial(func, ranks=ranks)
+            compiled = torch.compile(func_c)
+            out, aten_graph = run_and_get_manual_aten_graph(compiled, a, b, c, d)
+
+            (
+                FileCheck()
+                .check("_pre_bucket_all_gather")
+                .check("all_gather_into_tensor_out")
+                .check("_pre_bucket_all_gather_1")
+                .check("all_gather_into_tensor_out_1")
+                .check("wait_tensor_4")
+                .check("wait_tensor_5")
+                .run(str(aten_graph))
+            )
+
+            correct = func(a, b, c, d, ranks=ranks)
+            self.assertTrue(same(out, correct))
+
+
 if __name__ == "__main__":
     from torch._dynamo.test_case import run_tests

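Taken together, the helpers above show how the manual bucketing pass is driven: a post-grad custom pass tags collectives with nn_module_stack metadata and hands the graph to ManualOverlapScheduler with a list of module bucket plans. Below is a minimal sketch of that wiring outside the test harness; it mirrors apply_manual_reordering_and_get_graph above, and the manual_bucketing_pass name and the empty with-block body are illustrative placeholders, not part of this commit.

# Illustrative sketch (not part of this commit): wiring the manual bucketing
# pass into inductor via post_grad_custom_post_pass, as the test above does.
import functools

import torch
import torch._inductor.config
from torch._inductor.fx_passes.overlap_manual_scheduling import ManualOverlapScheduler


def manual_bucketing_pass(graph, module_bucket_plans):
    # The custom post pass receives the post-grad FX graph; the scheduler
    # buckets collectives whose nn_module_stack matches a plan entry.
    gm = graph.owning_module
    overlapped_gm = ManualOverlapScheduler(gm, module_bucket_plans).run()
    overlapped_gm.graph.lint()


bucketing_pass = functools.partial(
    manual_bucketing_pass, module_bucket_plans=["module_1", "module_2"]
)
with torch._inductor.config.patch(post_grad_custom_post_pass=bucketing_pass):
    pass  # torch.compile(...) calls placed inside this block pick up the pass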
torch/_inductor/fx_passes/bucketing.py

Lines changed: 3 additions & 3 deletions
@@ -117,9 +117,9 @@ def bucket_reduce_scatter(
 
 
 def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:  # type: ignore[arg-type]
-    return (
-        node.op == "call_function"
-        and node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+    return node.op == "call_function" and (
+        node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
+        or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default
     )
 
 
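The widened predicate now matches both the plain functional all-gather and the out-variant that the bucketing pass emits. As a hedged illustration, such a predicate can be applied directly to the nodes of a post-grad FX graph; the count_all_gathers helper below is not part of this commit.

# Illustrative helper (not part of this commit): use the widened predicate to
# count every all-gather, bucketed or not, in an FX graph.
import torch
import torch.fx


def is_all_gather_into_tensor(node: torch.fx.Node) -> bool:
    # Same check as bucketing.py after this change: plain op or out-variant.
    return node.op == "call_function" and (
        node.target == torch.ops._c10d_functional.all_gather_into_tensor.default
        or node.target == torch.ops._c10d_functional.all_gather_into_tensor_out.default
    )


def count_all_gathers(gm: torch.fx.GraphModule) -> int:
    # Booleans sum as 0/1, giving the number of matching nodes.
    return sum(is_all_gather_into_tensor(node) for node in gm.graph.nodes)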