
Commit b37d043

Author: Boyko Borisov (committed)
Message: integrate with triteia sbmm
1 parent 4517c6b · commit b37d043

File tree

4 files changed: +643 -38 lines changed


scratchpad/nn/models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@

 _GENERATION_MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
+    "LlamaNaiveMoEForCausalLM": ("llama_naive_moe", "LlamaNaiveMoEForCausalLM"),
     "LlamaMoEForCausalLM": ("llama_moe", "LlamaMoEForCausalLM")
 }

scratchpad/nn/models/llama_moe.py

Lines changed: 68 additions & 38 deletions
@@ -27,6 +27,7 @@
 from scratchpad.scheduler.schedule_batch import global_args
 from scratchpad.model_executor.forward_info import ForwardBatch
 from triteia.python.nn.linear import sparse_low_precision_linear
+from triteia.python.ops.matmul.sbmm import sbmm_4bit_2_4_native, sbmm_4bit_2_4_multilaunch, sbmm_4bit_2_4_forloop

 class LlamaMLP(nn.Module):
     def __init__(
@@ -65,37 +66,72 @@ def forward(self, x):
         x, _ = self.down_proj(x)
         return x

+class LLamaSBmm(nn.Module):
+    def __init__(self, num_experts, infeatures, outfeatures, sbmm_type="naive", groupsize=-1):
+        super().__init__()
+        if groupsize == -1:
+            groupsize = infeatures
+        self.infeatures = infeatures
+        self.outfeatures = outfeatures
+        self.groupsize = groupsize
+        self.qweight = nn.Parameter(torch.empty((num_experts, self.infeatures // 32, self.outfeatures * 16 // 8), dtype=torch.int32), False)
+        self.meta = nn.Parameter(torch.empty((num_experts, self.outfeatures, self.infeatures // 16), dtype=torch.int16), False)
+        self.scales = nn.Parameter(torch.empty((num_experts, self.infeatures // groupsize, self.outfeatures), dtype=torch.float16), False)
+        self.workspace = nn.Parameter(torch.zeros(num_experts, self.outfeatures // 128 * 16, dtype=torch.int32), False)
+        if sbmm_type == "naive":
+            self.sbmm_func = sbmm_4bit_2_4_native
+        elif sbmm_type == "multilaunch":
+            self.sbmm_func = sbmm_4bit_2_4_multilaunch
+        elif sbmm_type == "forloop":
+            self.sbmm_func = sbmm_4bit_2_4_forloop
+        else:
+            raise NotImplementedError
+
+    def forward(self, x, indices):
+        return self.sbmm_func(
+            qweights=self.qweight.data,
+            xs=x,
+            metas=self.meta.data,
+            ss=self.scales.data,
+            indices=indices)
+

 class LlamaCompressedMLP(nn.Module):
     def __init__(
         self,
         hidden_size: int,
         intermediate_size: int,
         hidden_act: str,
+        num_experts: int,
+        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.intermediate_size = intermediate_size
         self.hidden_size = hidden_size
-        self.gate_up_proj = sparse_low_precision_linear(
-            hidden_size,
-            intermediate_size * 2,
+        self.gate_up_proj = LLamaSBmm(
+            num_experts=num_experts,
+            infeatures=hidden_size,
+            outfeatures=intermediate_size * 2,
+            sbmm_type=sbmm_type,
         )
-        self.down_proj = sparse_low_precision_linear(
-            intermediate_size,
-            hidden_size,
+        self.down_proj = LLamaSBmm(
+            num_experts=num_experts,
+            infeatures=intermediate_size,
+            outfeatures=hidden_size,
+            sbmm_type=sbmm_type,
         )

-    def forward(self, x):
-        assert not x.isnan().any()
-        gate_up = self.gate_up_proj(x)
-        assert not gate_up.isnan().any()
+    def forward(self, x, indices):
+        # assert not x.isnan().any()
+        gate_up = self.gate_up_proj(x, indices)
+        # assert not gate_up.isnan().any()
         d = x.shape[-1] // 2
         x = F.silu(x[..., :d]) * x[..., d:]
-        assert not x.isnan().any()
-        x = self.down_proj(x)
-        assert not x.isnan().any()
+        # assert not x.isnan().any()
+        x = self.down_proj(x, indices)
+        # assert not x.isnan().any()
         return x

 class LlamaMoE(nn.Module):
@@ -106,6 +142,7 @@ def __init__(
         hidden_act: str,
         num_experts: int,
         experts_per_token: int,
+        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
@@ -119,49 +156,42 @@ def __init__(
             quant_config=quant_config,
             prefix=f"{prefix}.mlp.EXPERT_ID",
         )
-        self.mlp = nn.ModuleList([
-            LlamaCompressedMLP(
+
+        self.mlp = LlamaCompressedMLP(
+            num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             hidden_act=hidden_act,
             quant_config=quant_config,
-            prefix=f"{prefix}.mlp.{i}"
-        ) for i in range(num_experts)
-        ])
+            sbmm_type=sbmm_type,
+            prefix=f"{prefix}.mlp."
+        )
         self.gate = nn.Linear(hidden_size, num_experts, bias=False)

     def forward(self, x):
-
         base_y = self.base_mlp(x)
         original_shape = x.shape
         x = x.view(1, *x.shape) if x.dim() == 2 else x
         batch_size, sequence_length, hidden_dim = x.shape
+
         x = x.view(-1, hidden_dim)
         router_logits = self.gate(x)
-
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.experts_per_token, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        # we cast back to the input dtype
-        routing_weights = routing_weights.to(x.dtype)
-        assert not routing_weights.isnan().any(), "routing weights have nan"
+        routing_weights = routing_weights.to(x.dtype).T
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=x.dtype, device=x.device
         )
-        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).contiguous()
-
-        for expert_idx in range(self.num_experts):
-            expert_layer = self.mlp[expert_idx]
-            current_mask = expert_mask
-            current_mask = current_mask[expert_idx]
-            idx, top_x = torch.where(current_mask)
-            current_state = x[None, top_x].reshape(-1, hidden_dim)
-            assert not torch.isnan(current_state).any(), "current input state has nan"
-            current_hidden_states = expert_layer(current_state)
-            assert not torch.isnan(current_hidden_states).any(), "current hidden state has nan"
-            current_hidden_states *= routing_weights[top_x, idx, None]
-            if current_hidden_states.nelement() != 0:
-                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(x.dtype))
+        sort_selected_experts, argsort_selected_experts = torch.sort(selected_experts.T, dim=-1)
+        for k in range(self.experts_per_token):
+            current_selected_experts = sort_selected_experts[k]
+            current_routing_weights = routing_weights[k].view(-1, 1)
+            current_argsort_selected_experts = argsort_selected_experts[k]
+            sort_x = x[current_argsort_selected_experts]
+            current_hidden_states = self.mlp(sort_x, current_selected_experts)[current_argsort_selected_experts] * current_routing_weights
+            final_hidden_states += current_hidden_states
+
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         final_hidden_states = final_hidden_states.view(original_shape)

@@ -310,6 +340,7 @@ def __init__(
             quant_config=quant_config,
             num_experts=config.num_experts,
             experts_per_token=config.experts_per_token,
+            sbmm_type=config.sbmm_type,
             prefix=f"{prefix}.moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -480,7 +511,6 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
             ("mlp.EXPERT_ID", "base_mlp")
         ]
         for name, loaded_weight in weights:
-            # print(name)
             assert not loaded_weight.isnan().any()
             # continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
