@@ -23,76 +23,27 @@ def __init__(
     ):
         super().__init__()
         self.num_experts = num_experts
-        self.w1 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
-        self.w2 = nn.Parameter(torch.empty(num_experts, hidden_dim, dim))
-        self.w3 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
+        # Combine w1 and w3 into a single tensor so that `x @ w1` and `x @ w3`
+        # can be fused into a single grouped mm.
+        self.w13 = nn.Parameter(torch.empty(num_experts, hidden_dim * 2, dim))
+        self.w2 = nn.Parameter(torch.empty(num_experts, dim, hidden_dim))
         self.use_grouped_mm = use_grouped_mm

     def forward(
         self,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
-        if self.use_grouped_mm:
-            return GroupedExperts._run_experts_grouped_mm(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-        else:
-            return GroupedExperts._run_experts_for_loop(
-                self.w1, self.w2, self.w3, x, num_tokens_per_expert
-            )
-
-    # TODO: keeping this for-loop implementation for comparison
-    # and readability, may remove later
-    @expert_parallel
-    @staticmethod
-    def _run_experts_for_loop(
-        w1: torch.Tensor,
-        w2: torch.Tensor,
-        w3: torch.Tensor,
-        x: torch.Tensor,
-        num_tokens_per_expert: torch.Tensor | None = None,
-    ) -> torch.Tensor:
-        if num_tokens_per_expert is not None:
-            # NOTE: this would incur a synchronization between device and host
-            num_tokens_per_expert = num_tokens_per_expert.tolist()
-
-            # side-effect code due to the usage of generate_permute_indices
-            num_padding = x.shape[0] - sum(num_tokens_per_expert)
-
-            # a tuple of tensors indexed by experts
-            # each with shape (tokens_per_expert(varying), dim)
-            x = torch.split(
-                x[: sum(num_tokens_per_expert)],
-                split_size_or_sections=num_tokens_per_expert,
-                dim=0,
-            )
-            out_experts_splits = []
-            for expert_idx, x_expert in enumerate(x):
-                h = F.silu(torch.matmul(x_expert, w1[expert_idx]))
-                h = h * torch.matmul(x_expert, w3[expert_idx])
-                h = torch.matmul(h, w2[expert_idx])
-                # h shape (tokens_per_expert(varying), dim)
-                out_experts_splits.append(h)
-            out = torch.cat(out_experts_splits, dim=0)
-
-            # side-effect code due to the usage of generate_permute_indices
-            out = torch.vstack((out, out.new_zeros((num_padding, out.shape[-1]))))
-        else:
-            # x shape (num_experts, tokens_per_expert, dim)
-            h = F.silu(torch.bmm(x, w1))
-            h = h * torch.bmm(x, w3)
-            # out shape (num_experts, tokens_per_expert, dim)
-            out = torch.bmm(h, w2)
+        return GroupedExperts._run_experts_grouped_mm(
+            self.w13, self.w2, x, num_tokens_per_expert
+        )

-        return out

     @expert_parallel
     @staticmethod
     def _run_experts_grouped_mm(
-        w1: torch.Tensor,
+        w13: torch.Tensor,
         w2: torch.Tensor,
-        w3: torch.Tensor,
         x: torch.Tensor,
         num_tokens_per_expert: torch.Tensor | None = None,
     ) -> torch.Tensor:
@@ -105,16 +56,14 @@ def _run_experts_grouped_mm(
             # fall back to regular bmm between 3D tensors
             assert x.dim() == 3

-        h = F.silu(torch._grouped_mm(x.bfloat16(), w1.bfloat16(), offs=offsets))
-        h = h * torch._grouped_mm(x.bfloat16(), w3.bfloat16(), offs=offsets)
-        out = torch._grouped_mm(h, w2.bfloat16(), offs=offsets).type_as(x)
-
+        x1, x3 = torch._grouped_mm(x, w13.transpose(-2, -1), offs=offsets).chunk(2, dim=-1)
+        y = F.silu(x1) * x3
+        out = torch._grouped_mm(y, w2.transpose(-2, -1), offs=offsets).type_as(x)
         return out

     def init_weights(self, init_std: float):
-        nn.init.trunc_normal_(self.w1, mean=0.0, std=0.02)
+        nn.init.trunc_normal_(self.w13, mean=0.0, std=0.02)
         nn.init.trunc_normal_(self.w2, mean=0.0, std=init_std)
-        nn.init.trunc_normal_(self.w3, mean=0.0, std=init_std)


 class TokenChoiceTopKRouter(nn.Module):
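As a quick aside (not part of the commit): the chunk in _run_experts_grouped_mm implies that w13 stores the gate and up projections concatenated along the output (hidden) dimension, so a single matmul followed by a split reproduces the two separate projections. Below is a minimal, hypothetical sanity check of that equivalence for one expert, using plain matmuls instead of torch._grouped_mm and made-up sizes.

# Hypothetical sketch (not from the commit): fusing w1 and w3 into one weight
# and chunking the result matches the two separate projections.
import torch
import torch.nn.functional as F

dim, hidden_dim, num_tokens = 16, 32, 8          # made-up sizes
w1 = torch.randn(hidden_dim, dim)                # gate projection, (out, in) layout
w3 = torch.randn(hidden_dim, dim)                # up projection
w2 = torch.randn(dim, hidden_dim)                # down projection
x = torch.randn(num_tokens, dim)

# Reference path: two separate projections, as in the removed code path.
ref = (F.silu(x @ w1.T) * (x @ w3.T)) @ w2.T

# Fused path: one matmul against the concatenated weight, then split in two.
w13 = torch.cat([w1, w3], dim=0)                 # (2 * hidden_dim, dim)
x1, x3 = (x @ w13.T).chunk(2, dim=-1)            # each (num_tokens, hidden_dim)
fused = (F.silu(x1) * x3) @ w2.T

torch.testing.assert_close(ref, fused)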
@@ -299,7 +248,8 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:

         # shared expert
         if self.shared_expert is not None:
-            out = self.shared_expert(x.reshape(1, bs * slen, dim)).reshape(
+            out = self.shared_expert(x.reshape(1, bs * slen, dim))
+            out = out.reshape(
                 bs * slen, dim
             )
         else: