|
10 | 10 |
|
11 | 11 | # for some reason importing functional collectives after dynamo breaks collectives handling! |
12 | 12 | import torch.distributed._functional_collectives as _functional_collectives |
| 13 | +import torch.fx as fx |
13 | 14 | from torch._C import FileCheck |
14 | 15 | from torch._dynamo.utils import counters, same |
15 | 16 | from torch._inductor.utils import run_and_get_code, run_and_get_triton_code |
@@ -888,6 +889,168 @@ def func(a, b, c, d, *, ranks): |
888 | 889 | self.assertTrue(same(test_out, correct)) |
889 | 890 |
|
890 | 891 |
|
| 892 | +def get_toy_model(device_type: str): |
| 893 | + """ |
| 894 | +    Helper to construct a small multi-layer ToyModel.
| 895 | + """ |
| 896 | + |
| 897 | + class ToyBlock(torch.nn.Module): |
| 898 | + def __init__(self): |
| 899 | + super().__init__() |
| 900 | + self.wq = torch.nn.Linear(4, 4) |
| 901 | + self.wk = torch.nn.Linear(4, 4) |
| 902 | + self.proj = torch.nn.Linear(4, 4) |
| 903 | + |
| 904 | + def forward(self, x): |
| 905 | + attn = self.wq(x) + self.wk(x) |
| 906 | + return self.proj(torch.nn.functional.relu(attn)) |
| 907 | + |
| 908 | + class ToyModel(torch.nn.Module): |
| 909 | + def __init__(self): |
| 910 | + super().__init__() |
| 911 | + self.layers = torch.nn.ModuleList([ToyBlock() for _ in range(2)]) |
| 912 | + self.norm = torch.nn.LayerNorm(4) |
| 913 | + |
| 914 | + def forward(self, x): |
| 915 | + for blk in self.layers: |
| 916 | + x = blk(x) |
| 917 | + return self.norm(x) |
| 918 | + |
| 919 | + model = ToyModel().to(device_type) |
| 920 | + return model |
| 921 | + |
| 922 | + |
| 923 | +def apply_manual_reordering_and_get_graph(graph, module_bucket_plans, out_li) -> None: |
| 924 | + gm = graph.owning_module |
| 925 | + from torch._inductor.fx_passes.overlap_manual_scheduling import ( |
| 926 | + ManualOverlapScheduler, |
| 927 | + ) |
| 928 | + |
| 929 | +    # Tag the all-gather and wait nodes with a synthetic nn_module_stack so
| 930 | +    # the scheduler can assign them to the "module_1" / "module_2" buckets.
| 931 | + for node in list(gm.graph.nodes): |
| 932 | +        if node.name in (
| 933 | +            "all_gather_into_tensor",
| 934 | +            "all_gather_into_tensor_1",
| 935 | +            "wait_tensor",
| 936 | +            "wait_tensor_1",
| 937 | +        ):
| 938 | +            node.meta["nn_module_stack"] = {"test": ["module_1", ""]}
| 939 | +        if node.name in (
| 940 | +            "all_gather_into_tensor_2",
| 941 | +            "all_gather_into_tensor_3",
| 942 | +            "wait_tensor_2",
| 943 | +            "wait_tensor_3",
| 944 | +        ):
| 945 | +            node.meta["nn_module_stack"] = {"test": ["module_2", ""]}
| 946 | + |
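|  | +    # Run the scheduler over the requested module buckets, then lint and
|  | +    # record the reordered graph so the test can inspect it.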
| 947 | + overlapped_gm = ManualOverlapScheduler(gm, module_bucket_plans).run() |
| 948 | + overlapped_gm.graph.lint() |
| 949 | + out_li.append(overlapped_gm.graph) |
| 950 | + |
| 951 | + |
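|  | +# Compile and run `fn` with the manual reordering pass installed as Inductor's
|  | +# post_grad_custom_post_pass; returns the output and the captured aten graph.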
| 952 | +def run_and_get_manual_aten_graph(fn, *inputs): |
| 953 | + li = [] |
| 954 | + apply = functools.partial( |
| 955 | + apply_manual_reordering_and_get_graph, |
| 956 | + module_bucket_plans=["module_1", "module_2"], |
| 957 | + out_li=li, |
| 958 | + ) |
| 959 | + with torch._inductor.config.patch(post_grad_custom_post_pass=apply): |
| 960 | + out = fn(*inputs) |
| 961 | + |
| 962 | + return out, li[0] |
| 963 | + |
| 964 | + |
| 965 | +class TestManualOverlapBucketing(TestComputeCommReorderingMultiProc): |
| 966 | + """ |
| 967 | + Tests for manual overlap scheduling and subgraph utilities. |
| 968 | + """ |
| 969 | + |
| 970 | + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") |
| 971 | + def test_make_graph_view_and_get_subgraph_by_path(self): |
| 972 | + from torch._inductor.fx_passes.overlap_manual_scheduling import ( |
| 973 | + get_subgraph_by_path, |
| 974 | + make_graph_view, |
| 975 | + ) |
| 976 | + |
| 977 | + model = get_toy_model(device_type) |
| 978 | + gm = fx.symbolic_trace(model) |
| 979 | + graph_view = make_graph_view(gm.graph) |
| 980 | + # Fetch subgraph for first transformer layer |
| 981 | + sub_nodes = get_subgraph_by_path(graph_view, "layers.0.wq") |
| 982 | + self.assertEqual([n.name for n in sub_nodes], ["layers_0_wq"]) |
| 983 | + |
| 984 | + # Fetch multiple paths at once |
| 985 | + multi_nodes = get_subgraph_by_path(graph_view, ["layers.0.wq", "layers.0.proj"]) |
| 986 | + self.assertEqual( |
| 987 | + [n.name for n in multi_nodes], ["layers_0_wq", "layers_0_proj"] |
| 988 | + ) |
| 989 | + |
| 990 | + @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") |
| 991 | + def test_manual_reordering_bucketing_pass( |
| 992 | + self, |
| 993 | + ): |
| 994 | + def func(a, b, c, d, *, ranks): |
| 995 | + # All 4 all-gathers are independent - COULD be bucketed together |
| 996 | + ag1 = _functional_collectives.all_gather_tensor(a, 0, ranks) |
| 997 | + ag2 = _functional_collectives.all_gather_tensor(b, 0, ranks) |
| 998 | + ag3 = _functional_collectives.all_gather_tensor(c[:4], 0, ranks) |
| 999 | + ag4 = _functional_collectives.all_gather_tensor(d[:4], 0, ranks) |
| 1000 | + |
| 1001 | + # First compute - can hide ag1 and ag2 |
| 1002 | + e = a * 5 # Use a to avoid fusion |
| 1003 | + mm1 = torch.matmul(e, e.T) |
| 1004 | + |
| 1005 | + # Force ag1/ag2 to complete before mm2 (but ag3/ag4 can still be deferred) |
| 1006 | + # Use first 8x8 elements to match mm1's shape |
| 1007 | + intermediate = ag1[:8, :8] + ag2[:8, :8] |
| 1008 | + |
| 1009 | + # Second compute - depends on ag1/ag2 through intermediate, can hide ag3/ag4 |
| 1010 | + mm2 = torch.matmul(mm1 + intermediate, c[:8]) |
| 1011 | + |
| 1012 | + # Use all results |
| 1013 | + result = ( |
| 1014 | + ag1.sum() * 1.1 |
| 1015 | + + ag2.sum() * 1.2 |
| 1016 | + + ag3.sum() * 1.3 |
| 1017 | + + ag4.sum() * 1.4 |
| 1018 | + + mm1.sum() |
| 1019 | + + mm2.sum() |
| 1020 | + ) |
| 1021 | + return result |
| 1022 | + |
| 1023 | + with _dynamo_dist_per_rank_init( |
| 1024 | + self.rank, |
| 1025 | + self.world_size, |
| 1026 | + self.backend(device_type), |
| 1027 | + fake_pg=not at_least_x_gpu(2), |
| 1028 | + ): |
| 1029 | + a = torch.ones(8, 8, dtype=torch.float, device=device_type) |
| 1030 | + b = torch.ones(8, 8, dtype=torch.float, device=device_type) * 2 |
| 1031 | + c = torch.ones(8, 8, dtype=torch.float, device=device_type) * 3 |
| 1032 | + d = torch.ones(8, 8, dtype=torch.float, device=device_type) * 4 |
| 1033 | + ranks = list(range(self.world_size)) |
| 1034 | + |
| 1035 | + func_c = functools.partial(func, ranks=ranks) |
| 1036 | + compiled = torch.compile(func_c) |
| 1037 | + out, aten_graph = run_and_get_manual_aten_graph(compiled, a, b, c, d) |
| 1038 | + |
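|  | +            # The reordered graph should contain two bucketed all-gathers,
|  | +            # each preceded by its pre-bucket op, followed by their waits.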
| 1039 | + ( |
| 1040 | + FileCheck() |
| 1041 | + .check("_pre_bucket_all_gather") |
| 1042 | + .check("all_gather_into_tensor_out") |
| 1043 | + .check("_pre_bucket_all_gather_1") |
| 1044 | + .check("all_gather_into_tensor_out_1") |
| 1045 | + .check("wait_tensor_4") |
| 1046 | + .check("wait_tensor_5") |
| 1047 | + .run(str(aten_graph)) |
| 1048 | + ) |
| 1049 | + |
| 1050 | + correct = func(a, b, c, d, ranks=ranks) |
| 1051 | + self.assertTrue(same(out, correct)) |
| 1052 | + |
| 1053 | + |
891 | 1054 | if __name__ == "__main__": |
892 | 1055 | from torch._dynamo.test_case import run_tests |
893 | 1056 |
|
|