Commit 7c327f4

Author: Boyko Borisov (committed)
Adds initial code for unquantised MoE
1 parent 38a5fe4 commit 7c327f4

File tree

4 files changed: +606 -139 lines changed


scratchpad/nn/models/__init__.py

Lines changed: 2 additions & 1 deletion
@@ -7,7 +7,8 @@
 
 _GENERATION_MODELS = {
     "LlamaForCausalLM": ("llama", "LlamaForCausalLM"),
-    "LlamaNaiveMoEForCausalLM": ("llama_naive_moe", "LlamaNaiveMoEForCausalLM"),
+    "LlamaNaiveQuantisedMoEForCausalLM": ("llama_naive_moe", "LlamaNaiveQuantisedMoEForCausalLM"),
+    "LlamaQuantisedMoEForCausalLM": ("llama_quant_moe", "LlamaQuantisedMoEForCausalLM"),
     "LlamaMoEForCausalLM": ("llama_moe", "LlamaMoEForCausalLM")
 }
 
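Note: the registry above maps an architecture name from a checkpoint's config to a (module, class) pair. Below is a minimal sketch of how such an entry could be resolved; the actual scratchpad model loader is not part of this diff, so the resolve_architecture helper and the package path are assumptions.

    # Hypothetical resolver, not part of this commit: turns a registry entry
    # into a model class via importlib.
    import importlib

    _GENERATION_MODELS = {
        "LlamaMoEForCausalLM": ("llama_moe", "LlamaMoEForCausalLM"),
    }

    def resolve_architecture(arch: str):
        module_name, class_name = _GENERATION_MODELS[arch]
        # Package path assumed from the file locations shown in this diff.
        module = importlib.import_module(f"scratchpad.nn.models.{module_name}")
        return getattr(module, class_name)

    # e.g. resolve_architecture("LlamaMoEForCausalLM") returns the model class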

scratchpad/nn/models/llama_moe.py

Lines changed: 27 additions & 127 deletions
@@ -26,8 +26,6 @@
 from scratchpad.nn.utils import apply_torchao_config_
 from scratchpad.scheduler.schedule_batch import global_args
 from scratchpad.model_executor.forward_info import ForwardBatch
-from triteia.python.nn.linear import sparse_low_precision_linear
-from triteia.python.ops.matmul.sbmm import sbmm_4bit_2_4_native, sbmm_4bit_2_4_multilaunch, sbmm_4bit_2_4_forloop
 
 class LlamaMLP(nn.Module):
     def __init__(
@@ -61,78 +59,11 @@ def __init__(
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
-        gate_up, _ = self.gate_up_proj(x)
-        x = self.act_fn(gate_up)
+        x, _ = self.gate_up_proj(x)
+        x = self.act_fn(x)
         x, _ = self.down_proj(x)
         return x
 
-class LLamaSBmm(nn.Module):
-    def __init__(self, num_experts, infeatures, outfeatures, sbmm_type="naive", groupsize=-1):
-        super().__init__()
-        if groupsize == -1:
-            groupsize = infeatures
-        self.infeatures = infeatures
-        self.outfeatures = outfeatures
-        self.groupsize = groupsize
-        self.qweight = nn.Parameter(torch.empty((num_experts, self.infeatures // 32, self.outfeatures * 16 // 8), dtype=torch.int32), False)
-        self.meta = nn.Parameter(torch.empty((num_experts, self.outfeatures, self.infeatures // 16), dtype=torch.int16), False)
-        self.scales = nn.Parameter(torch.empty((num_experts, self.infeatures // groupsize, self.outfeatures), dtype=torch.float16), False)
-        self.workspace = nn.Parameter(torch.zeros(num_experts, self.outfeatures // 128 * 16, dtype=torch.int32), False)
-        if sbmm_type == "naive":
-            self.sbmm_func = sbmm_4bit_2_4_native
-        elif sbmm_type == "multilaunch":
-            self.sbmm_func = sbmm_4bit_2_4_multilaunch
-        elif sbmm_type == "forloop":
-            self.sbmm_func = sbmm_4bit_2_4_forloop
-        else:
-            raise NotImplementedError
-
-    def forward(self, x, indices):
-        return self.sbmm_func(
-            qweights=self.qweight.data,
-            xs=x,
-            metas=self.meta.data,
-            ss=self.scales.data,
-            indices=indices)
-
-
-class LlamaCompressedMLP(nn.Module):
-    def __init__(
-        self,
-        hidden_size: int,
-        intermediate_size: int,
-        hidden_act: str,
-        num_experts: int,
-        sbmm_type: str,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-    ) -> None:
-        super().__init__()
-        self.intermediate_size = intermediate_size
-        self.hidden_size = hidden_size
-        self.gate_up_proj = LLamaSBmm(
-            num_experts=num_experts,
-            infeatures=hidden_size,
-            outfeatures=intermediate_size * 2,
-            sbmm_type=sbmm_type,
-        )
-        self.down_proj = LLamaSBmm(
-            num_experts=num_experts,
-            infeatures=intermediate_size,
-            outfeatures=hidden_size,
-            sbmm_type=sbmm_type,
-        )
-
-    def forward(self, x, indices):
-        # assert not x.isnan().any()
-        gate_up = self.gate_up_proj(x, indices)
-        # assert not gate_up.isnan().any()
-        d = x.shape[-1] // 2
-        x = F.silu(x[..., :d]) * x[..., d:]
-        # assert not x.isnan().any()
-        x = self.down_proj(x, indices)
-        # assert not x.isnan().any()
-        return x
 
 class LlamaMoE(nn.Module):
     def __init__(
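For reference, the act_fn used above is SiluAndMul: gate_up_proj produces the gate and up projections concatenated along the last dimension, and the activation splits that tensor in half, applying SiLU to the gate half before multiplying. A minimal functional sketch, mirroring the explicit split kept in llama_naive_quant_moe.py further down:

    import torch
    import torch.nn.functional as F

    def silu_and_mul(gate_up: torch.Tensor) -> torch.Tensor:
        # gate_up has shape (..., 2 * intermediate_size):
        # first half is the gate projection, second half is the up projection.
        d = gate_up.shape[-1] // 2
        return F.silu(gate_up[..., :d]) * gate_up[..., d:]

    x = torch.randn(4, 16)              # toy input with 2 * intermediate_size = 16
    assert silu_and_mul(x).shape == (4, 8)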
@@ -142,75 +73,58 @@ def __init__(
         hidden_act: str,
         num_experts: int,
         experts_per_token: int,
-        sbmm_type: str,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ) -> None:
         super().__init__()
         self.experts_per_token = experts_per_token
         self.num_experts = num_experts
-        self.base_mlp = LlamaMLP(
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            hidden_act=hidden_act,
-            quant_config=quant_config,
-            prefix=f"{prefix}.mlp.EXPERT_ID",
-        )
-
-        self.mlp = LlamaCompressedMLP(
-            num_experts=num_experts,
+        self.mlp = nn.ModuleList([
+            LlamaMLP(
                 hidden_size=hidden_size,
                 intermediate_size=intermediate_size,
                 hidden_act=hidden_act,
                 quant_config=quant_config,
-            sbmm_type=sbmm_type,
-            prefix=f"{prefix}.mlp."
-        )
+                prefix=f"{prefix}.mlp.{i}"
+            ) for i in range(num_experts)
+        ])
         self.gate = nn.Linear(hidden_size, num_experts, bias=False)
 
     def forward(self, x):
-        base_y = self.base_mlp(x)
         original_shape = x.shape
         x = x.view(1, *x.shape) if x.dim() == 2 else x
         batch_size, sequence_length, hidden_dim = x.shape
-
         x = x.view(-1, hidden_dim)
         router_logits = self.gate(x)
         routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)
         routing_weights, selected_experts = torch.topk(routing_weights, self.experts_per_token, dim=-1)
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
-        routing_weights = routing_weights.to(x.dtype).T
+        routing_weights = routing_weights.to(x.dtype)
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=x.dtype, device=x.device
         )
-        sort_selected_experts, argsort_selected_experts = torch.sort(selected_experts.T, dim=-1)
-        for k in range(self.experts_per_token):
-            current_selected_experts = sort_selected_experts[k]
-            current_routing_weights = routing_weights[k].view(-1, 1)
-            current_argsort_selected_experts = argsort_selected_experts[k]
-            sort_x = x[current_argsort_selected_experts]
-            current_hidden_states = self.mlp(sort_x, current_selected_experts)[current_argsort_selected_experts] * current_routing_weights
-            final_hidden_states += current_hidden_states
-
+        expert_mask = torch.nn.functional.one_hot(selected_experts, num_classes=self.num_experts).permute(2, 1, 0).contiguous()
+
+        for expert_idx in range(self.num_experts):
+            expert_layer = self.mlp[expert_idx]
+            current_mask = expert_mask[expert_idx]
+            idx, top_x = torch.where(current_mask)
+            current_state = x[None, top_x].reshape(-1, hidden_dim)
+            if current_state.nelement() != 0:
+                current_hidden_states = expert_layer(current_state)
+                current_hidden_states *= routing_weights[top_x, idx, None]
+                final_hidden_states.index_add_(0, top_x, current_hidden_states.to(final_hidden_states.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         final_hidden_states = final_hidden_states.view(original_shape)
-
-        final_hidden_states = final_hidden_states.contiguous()
-        base_y = base_y.contiguous()
-
         # For debugging
         # assert final_hidden_states.is_contiguous(), "final_hidden_states is not contiguous"
-        # assert base_y.is_contiguous(), "base_y is not contiguous"
-        # assert final_hidden_states.device == base_y.device, "Tensors are on different devices"
-        # assert final_hidden_states.dtype == base_y.dtype, "Tensors have different dtypes"
+        # print(final_hidden_states.device)
+        # print(final_hidden_states.shape)
+        # print(final_hidden_states.dtype)
+        # print(final_hidden_states)
         # assert not torch.isnan(final_hidden_states).any(), "NaN found in final_hidden_states"
-        # assert not torch.isnan(base_y).any(), "NaN found in base_y"
         # assert not torch.isinf(final_hidden_states).any(), "Inf found in final_hidden_states"
-        # assert not torch.isinf(base_y).any(), "Inf found in base_y"
-        # assert final_hidden_states.shape == base_y.shape, "Tensors have different shapes"
-        # torch.cuda.synchronize()
-        result = final_hidden_states + base_y
-        return result
+        return final_hidden_states
 
 class LlamaAttention(nn.Module):
     def __init__(
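The rewritten LlamaMoE.forward above follows the common dense top-k dispatch pattern: softmax over the gate logits, top-k selection with renormalisation, then one gather/compute/scatter pass per expert using a one-hot expert mask and index_add_. Below is a self-contained sketch of the same routing with plain nn.Linear experts; this is a simplification for illustration only, while the real module builds LlamaMLP experts with the prefixes and quant_config shown in the diff.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ToyMoE(nn.Module):
        """Simplified stand-in for LlamaMoE: dense top-k routing over small MLP experts."""

        def __init__(self, hidden: int, inter: int, num_experts: int, top_k: int):
            super().__init__()
            self.num_experts, self.top_k = num_experts, top_k
            self.mlp = nn.ModuleList(
                nn.Sequential(nn.Linear(hidden, inter), nn.SiLU(), nn.Linear(inter, hidden))
                for _ in range(num_experts)
            )
            self.gate = nn.Linear(hidden, num_experts, bias=False)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # x: (num_tokens, hidden)
            weights = F.softmax(self.gate(x), dim=-1, dtype=torch.float)
            weights, selected = torch.topk(weights, self.top_k, dim=-1)
            weights = (weights / weights.sum(dim=-1, keepdim=True)).to(x.dtype)
            out = torch.zeros_like(x)
            # expert_mask[e, k, t] == 1 iff token t routed to expert e in its k-th slot.
            expert_mask = F.one_hot(selected, num_classes=self.num_experts).permute(2, 1, 0)
            for e in range(self.num_experts):
                k_slot, token_idx = torch.where(expert_mask[e])
                if token_idx.numel() == 0:
                    continue  # this expert received no tokens
                expert_out = self.mlp[e](x[token_idx])
                out.index_add_(0, token_idx, expert_out * weights[token_idx, k_slot, None])
            return out

    moe = ToyMoE(hidden=8, inter=16, num_experts=4, top_k=2)
    y = moe(torch.randn(5, 8))
    assert y.shape == (5, 8)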
@@ -340,7 +254,6 @@ def __init__(
             quant_config=quant_config,
             num_experts=config.num_experts,
             experts_per_token=config.experts_per_token,
-            sbmm_type=config.sbmm_type,
             prefix=f"{prefix}.moe",
         )
         self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
@@ -505,13 +418,9 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         ]
         params_dict = dict(self.named_parameters())
 
-        name_transformations = [
-            ("down_proj.0", "down_proj"),
-            ("gate_up_proj.0", "gate_up_proj"),
-            ("mlp.EXPERT_ID", "base_mlp")
-        ]
         for name, loaded_weight in weights:
-            assert not loaded_weight.isnan().any()
+            # print(name)
+            # assert not loaded_weight.isnan().any()
             # continue
             if "rotary_emb.inv_freq" in name or "projector" in name:
                 continue
@@ -525,28 +434,19 @@
             for param_name, weight_name, shard_id in stacked_params_mapping:
                 if weight_name not in name:
                     continue
+                print(name, name.replace(weight_name, param_name), shard_id)
                 name = name.replace(weight_name, param_name)
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                for transformation in name_transformations:
-                    if transformation[0] in name:
-                        name = name.replace(transformation[0], transformation[1])
                 param = params_dict[name]
                 weight_loader = param.weight_loader
                 weight_loader(param, loaded_weight, shard_id)
                 break
             else:
-                if name == "lm_head.0.weight":
-                    continue
-                if name == "model.embed_tokens.0.weight":
-                    continue
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-                for transformation in name_transformations:
-                    if transformation[0] in name:
-                        name = name.replace(transformation[0], transformation[1])
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
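The loading loop above relies on Python's for ... else construct: the else branch runs only when no break fired, that is, when none of the stacked-parameter patterns matched and the weight is loaded directly. A short illustration of that control flow, with made-up mapping entries and weight name rather than the module's actual ones:

    stacked_params_mapping = [
        # (fused_param_name, checkpoint_weight_name, shard_id) (illustrative entries only)
        ("qkv_proj", "q_proj", "q"),
        ("qkv_proj", "k_proj", "k"),
        ("gate_up_proj", "gate_proj", 0),
    ]

    name = "model.layers.0.self_attn.k_proj.weight"
    for param_name, weight_name, shard_id in stacked_params_mapping:
        if weight_name not in name:
            continue
        print("fused load:", name.replace(weight_name, param_name), "shard", shard_id)
        break
    else:
        # No pattern matched: fall back to a plain, unfused load.
        print("direct load:", name)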

scratchpad/nn/models/llama_naive_moe.py renamed to scratchpad/nn/models/llama_naive_quant_moe.py

Lines changed: 11 additions & 11 deletions
@@ -88,14 +88,14 @@ def __init__(
         )
 
     def forward(self, x):
-        assert not x.isnan().any()
-        gate_up = self.gate_up_proj(x)
-        assert not gate_up.isnan().any()
+        # assert not x.isnan().any()
+        x = self.gate_up_proj(x)
+        # assert not gate_up.isnan().any()
         d = x.shape[-1] // 2
         x = F.silu(x[..., :d]) * x[..., d:]
-        assert not x.isnan().any()
+        # assert not x.isnan().any()
         x = self.down_proj(x)
-        assert not x.isnan().any()
+        # assert not x.isnan().any()
         return x
 
 class LlamaMoE(nn.Module):
@@ -148,7 +148,7 @@ def forward(self, x):
         routing_weights /= routing_weights.sum(dim=-1, keepdim=True)
         # we cast back to the input dtype
         routing_weights = routing_weights.to(x.dtype)
-        assert not routing_weights.isnan().any(), "routing weights have nan"
+        # assert not routing_weights.isnan().any(), "routing weights have nan"
         final_hidden_states = torch.zeros(
             (batch_size * sequence_length, hidden_dim), dtype=x.dtype, device=x.device
         )
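As a concrete illustration of the routing-weight arithmetic kept in this hunk (softmax in float32, top-k selection, renormalisation, then the cast back to the input dtype), consider a single token routed over four experts with top-2:

    import torch
    import torch.nn.functional as F

    router_logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])        # one token, four experts
    probs = F.softmax(router_logits, dim=1, dtype=torch.float)    # ~[0.61, 0.22, 0.14, 0.03]
    weights, selected = torch.topk(probs, k=2, dim=-1)            # picks experts 0 and 1
    weights = weights / weights.sum(dim=-1, keepdim=True)         # renormalise to ~[0.73, 0.27]
    weights = weights.to(torch.float16)                           # cast back to the model dtype
    print(selected.tolist(), weights.tolist())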
@@ -160,11 +160,11 @@ def forward(self, x):
             current_mask = current_mask[expert_idx]
             idx, top_x = torch.where(current_mask)
             current_state = x[None, top_x].reshape(-1, hidden_dim)
-            assert not torch.isnan(current_state).any(), "current input state has nan"
+            # assert not torch.isnan(current_state).any(), "current input state has nan"
             current_hidden_states = expert_layer(current_state)
-            assert not torch.isnan(current_hidden_states).any(), "current hidden state has nan"
-            current_hidden_states *= routing_weights[top_x, idx, None]
+            # assert not torch.isnan(current_hidden_states).any(), "current hidden state has nan"
             if current_hidden_states.nelement() != 0:
+                current_hidden_states *= routing_weights[top_x, idx, None]
                 final_hidden_states.index_add_(0, top_x, current_hidden_states.to(x.dtype))
         final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)
         final_hidden_states = final_hidden_states.view(original_shape)
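The reordering here matters because an expert may receive no tokens in a given batch: the one-hot mask then yields empty index tensors, and both the routing-weight multiply and the index_add_ scatter are skipped. A small demonstration of the mask bookkeeping, with hypothetical routing of three tokens over four experts (top-2):

    import torch
    import torch.nn.functional as F

    # selected_experts[t] lists the two experts chosen for token t.
    selected_experts = torch.tensor([[0, 2], [2, 3], [0, 3]])
    expert_mask = F.one_hot(selected_experts, num_classes=4).permute(2, 1, 0)

    for expert_idx in range(4):
        idx, top_x = torch.where(expert_mask[expert_idx])
        # idx: which top-k slot picked this expert, top_x: which tokens to gather.
        print(expert_idx, top_x.tolist(), idx.tolist())
    # Expert 1 receives no tokens, so both index tensors are empty and the
    # weighted index_add_ update is skipped for it.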
@@ -394,7 +394,7 @@ def forward(
         return hidden_states
 
 
-class LlamaNaiveMoEForCausalLM(nn.Module):
+class LlamaNaiveQuantisedMoEForCausalLM(nn.Module):
     def __init__(
         self,
         config: LlamaConfig,
@@ -537,4 +537,4 @@ def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
 
 
 
-EntryClass = [LlamaNaiveMoEForCausalLM]
+EntryClass = [LlamaNaiveQuantisedMoEForCausalLM]
