
Commit ebb26f0

[gaudi] Deepseek v2 mla and add ep to unquantized moe (#3287)
Signed-off-by: Wang, Yi A <[email protected]>
1 parent: 778b61c

8 files changed (+171 −107 lines)

Dockerfile_gaudi

Lines changed: 3 additions & 3 deletions

@@ -118,9 +118,9 @@ ENTRYPOINT ["./entrypoint.sh"]
 # Final image
 FROM base

-ENV HF_HUB_ENABLE_HF_TRANSFER 1
-ENV HABANA_VISIBLE_DEVICES all
-ENV OMPI_MCA_btl_vader_single_copy_mechanism NONE
+ENV HF_HUB_ENABLE_HF_TRANSFER=1
+ENV HABANA_VISIBLE_DEVICES=all
+ENV OMPI_MCA_btl_vader_single_copy_mechanism=NONE

 COPY backends/gaudi/tgi-entrypoint.sh /tgi-entrypoint.sh
 RUN chmod +x /tgi-entrypoint.sh

backends/gaudi/server/text_generation_server/layers/moe/fp8.py

Lines changed: 5 additions & 3 deletions

@@ -51,10 +51,12 @@ def __init__(
         self.rank = weights.process_group.rank()
         self.ep_rank = self.rank
         self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
-
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
         if self.use_ep:
-            n_experts = (n_experts + self.world_size - 1) // self.world_size
-            self.ep_offset = self.ep_rank * n_experts
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.ep_rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
         else:
             self.ep_offset = 0
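The added guard disables expert parallelism when ceil(n_experts / world_size) would leave fewer than 4 experts per rank, and the min() keeps the last rank from over-counting when world_size does not divide n_experts evenly. A minimal standalone sketch of that arithmetic (the example sizes are illustrative assumptions, not taken from this commit):

# Standalone sketch of the expert-parallel sharding arithmetic in the diff above.
def shard_experts(n_experts: int, world_size: int, rank: int):
    n_experts_per_rank = (n_experts + world_size - 1) // world_size  # ceil division
    use_ep = n_experts_per_rank >= 4  # too few experts per rank -> keep tensor parallel
    if not use_ep:
        return use_ep, 0, n_experts
    ep_offset = rank * n_experts_per_rank
    # min() trims the last rank when world_size does not divide n_experts.
    n_local = min(n_experts_per_rank, n_experts - ep_offset)
    return use_ep, ep_offset, n_local

print(shard_experts(160, 8, 7))    # (True, 140, 20)  -> 20 experts on every rank
print(shard_experts(160, 12, 11))  # (True, 154, 6)   -> last rank holds the remainder
print(shard_experts(16, 8, 0))     # (False, 0, 16)   -> only 2 per rank, EP disabled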

backends/gaudi/server/text_generation_server/layers/moe/unquantized.py

Lines changed: 45 additions & 7 deletions

@@ -7,6 +7,7 @@
 from vllm_hpu_extension.ops import VllmMixtureOfExpertsOp
 import habana_frameworks.torch as htorch
 import torch.nn.functional as F
+import os


 class UnquantizedSparseMoELayer(nn.Module):
@@ -39,23 +40,42 @@ def __init__(
         self.weight_block_size = weights.weights_loader.weight_block_size
         self.scoring_func = scoring_func
         self.e_score_correction_bias = e_score_correction_bias
+        self.rank = weights.process_group.rank()
+        self.world_size = weights.process_group.size()
+        self.use_ep = os.getenv("USE_EXPERT_PARALLEL", "true").lower() == "true"
+        if (n_experts + self.world_size - 1) // self.world_size < 4:
+            self.use_ep = False
+        if self.use_ep:
+            n_experts_per_rank = (n_experts + self.world_size - 1) // self.world_size
+            self.ep_offset = self.rank * n_experts_per_rank
+            n_experts = min(n_experts_per_rank, n_experts - self.ep_offset)
+            experts_min = self.ep_offset
+            experts_max = self.ep_offset + n_experts - 1
+        else:
+            self.ep_offset = 0
+            experts_min = 0
+            experts_max = n_experts - 1

         self.gate_up_proj = _load_expert_multi_weights_col(
             prefix=prefix,
             n_experts=n_experts,
             gate_proj_name=gate_proj_name,
             up_proj_name=up_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )

         self.down_proj = _load_expert_weights_row(
             prefix=prefix,
             n_experts=n_experts,
             name=down_proj_name,
             weights=weights,
+            use_ep=self.use_ep,
+            ep_offset=self.ep_offset,
         )

-        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, 0, n_experts - 1)
+        self.MoeOp = VllmMixtureOfExpertsOp(n_experts, experts_min, experts_max)
         for i in range(n_experts):
             self.MoeOp.w13_list[i].set_weight(self.gate_up_proj[i])
             self.MoeOp.w2_list[i].set_weight(self.down_proj[i])
@@ -87,12 +107,23 @@ def _load_expert_multi_weights_col(
     gate_proj_name: str,
     up_proj_name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_multi_weights_col(
-            [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
-        )
+        if not use_ep:
+            weight = weights.get_multi_weights_col(
+                [f"{prefix}.{i}.{gate_proj_name}", f"{prefix}.{i}.{up_proj_name}"], 0
+            )
+        else:
+            weight = weights.get_multi_weights(
+                [
+                    f"{prefix}.{i+ep_offset}.{gate_proj_name}",
+                    f"{prefix}.{i+ep_offset}.{up_proj_name}",
+                ],
+                0,
+            )

         assert isinstance(weight, UnquantizedWeight)

@@ -116,12 +147,19 @@ def _load_expert_weights_row(
     n_experts: int,
     name: str,
     weights: Weights,
+    use_ep: bool = False,
+    ep_offset: int = 0,
 ) -> torch.Tensor:
     all_weight = None
     for i in range(n_experts):
-        weight = weights.get_weights_row(
-            f"{prefix}.{i}.{name}",
-        )
+        if not use_ep:
+            weight = weights.get_weights_row(
+                f"{prefix}.{i}.{name}",
+            )
+        else:
+            weight = weights.get_weights(
+                f"{prefix}.{i+ep_offset}.{name}",
+            )

         assert isinstance(weight, UnquantizedWeight)
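With expert parallelism enabled, each rank loads only its own slice of experts and indexes the checkpoint by the global expert id `i + ep_offset`, reading each expert's matrices whole (get_multi_weights / get_weights) instead of tensor-parallel slices; VllmMixtureOfExpertsOp is then constructed with the matching global range [experts_min, experts_max]. A rough sketch of which tensors a rank ends up reading (the prefix and the ".weight" suffix are hypothetical, for illustration only):

# Hypothetical illustration of per-rank expert tensor names under EP.
def expert_tensor_names(prefix, n_local_experts, ep_offset, use_ep,
                        gate="gate_proj", up="up_proj", down="down_proj"):
    names = []
    for i in range(n_local_experts):
        e = i + ep_offset if use_ep else i  # global expert index in the checkpoint
        names.append([f"{prefix}.{e}.{gate}.weight",
                      f"{prefix}.{e}.{up}.weight",
                      f"{prefix}.{e}.{down}.weight"])
    return names

# e.g. rank 3 of 8 with 160 experts owns experts 60..79 and reads them unsharded;
# without EP it would load all 160 experts, each split across the ranks instead.
print(expert_tensor_names("model.layers.1.mlp.experts", 20, 60, True)[0])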

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_deepseek_v2_modeling.py

Lines changed: 110 additions & 42 deletions

@@ -28,11 +28,12 @@
     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     set_block_mapping,
     HPUPagedAttentionMetadata,
 )
@@ -44,6 +45,18 @@
 import habana_frameworks.torch as htorch


+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
         self,
@@ -246,6 +259,45 @@ def __init__(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -258,28 +310,28 @@ def forward(
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-        query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )

         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]

-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)

         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -294,33 +346,47 @@ def forward(
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
-
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
         )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))

         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )

-        # Prefill
         if cu_seqlen_prefill is not None:
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
@@ -331,24 +397,26 @@ def forward(
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output


 class DeepseekV2MLP(nn.Module):
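The decode path now uses MLA "weight absorption": kv_b_proj is unpacked once at load time into per-head W_UK_T (N, P, L) and W_UV (N, L, V), decode queries are projected down into the kv_lora_rank latent space, and paged_attention_mla runs directly against the cached latent vectors of width kv_lora_rank + qk_rope_head_dim, while prefill keeps the original padded flash-attention path. A shape-only sketch with assumed DeepSeek-V2-like sizes (not values read from the model config):

# Shape check for the decode-path weight absorption above (sizes are assumptions).
import torch

N, B, P, L, V, R = 16, 4, 128, 512, 128, 64  # heads, tokens, nope dim, kv_lora_rank, value dim, rope dim

W_UK_T = torch.randn(N, P, L)   # per-head key up-projection, pre-transposed
W_UV = torch.randn(N, L, V)     # per-head value up-projection

q_nope = torch.randn(B, N, P)
# _q_proj_and_k_up_proj: fold W_UK into the query -> latent-space queries
ql_nope = torch.bmm(q_nope.transpose(0, 1), W_UK_T).transpose(0, 1)      # (B, N, L)

# paged_attention_mla then attends over cached latent vectors of width L + R = 576
attn_out = torch.randn(B, N, L)  # stand-in for the attention output
# _v_up_proj_and_o_proj: lift the latent output back to per-head values
out = torch.bmm(attn_out.transpose(0, 1), W_UV).transpose(0, 1).reshape(B, N * V)
print(ql_nope.shape, out.shape)  # torch.Size([4, 16, 512]) torch.Size([4, 2048])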

backends/gaudi/server/text_generation_server/models/custom_modeling/flash_qwen3_moe_modeling.py

Lines changed: 5 additions & 0 deletions

@@ -21,6 +21,7 @@
 from text_generation_server.layers.attention import (
     attention,
     paged_attention,
+    set_block_mapping,
     Seqlen,
     HPUPagedAttentionMetadata,
 )
@@ -466,6 +467,10 @@ def forward(
         seqlen: Seqlen,
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ) -> torch.Tensor:
+        if hpu_attention_meta is not None:
+            hpu_attention_meta = set_block_mapping(
+                hpu_attention_meta, inputs_embeds.shape[0]
+            )

         hidden_states = inputs_embeds

backends/gaudi/server/text_generation_server/models/flash_causal_lm.py

Lines changed: 2 additions & 2 deletions

@@ -1606,7 +1606,7 @@ def init_kv_cache(
     ):
         self.kv_cache = []
         empty_cache()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             self.kv_cache = [
                 KVCompressCache(
                     num_blocks=num_blocks,
@@ -1646,7 +1646,7 @@ def warmup(
         # Inspired by the original implementation in [vllm](https://github.com/vllm-project/vllm)
         # Calculate the number of blocks that can be allocated with the free memory
        dtype_size = torch.tensor([], dtype=self.kv_cache_dtype).element_size()
-        if self.config.model_type == "deepseek_v3":
+        if self.config.model_type in ["deepseek_v3", "deepseek_v2"]:
             cache_block_size = BLOCK_SIZE * (
                 self.config.kv_lora_rank + self.config.qk_rope_head_dim
             )
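Routing deepseek_v2 through KVCompressCache means each cached token stores a single latent vector shared by all heads, so the per-block cache size is BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim) elements. A back-of-the-envelope comparison, assuming DeepSeek-V2's published kv_lora_rank=512 and qk_rope_head_dim=64 plus an illustrative block size and head layout (these numbers are assumptions, not read from TGI):

# Rough cache-size comparison: compressed (MLA) layout vs. a per-head KV cache.
BLOCK_SIZE = 128              # illustrative paged-attention block size
kv_lora_rank = 512            # assumed DeepSeek-V2 config value
qk_rope_head_dim = 64         # assumed DeepSeek-V2 config value

mla_elems = BLOCK_SIZE * (kv_lora_rank + qk_rope_head_dim)    # one latent vector per token

num_kv_heads, head_dim = 128, 128                             # illustrative per-head layout
std_elems = 2 * BLOCK_SIZE * num_kv_heads * head_dim          # K and V per head per token

print(mla_elems, std_elems, round(std_elems / mla_elems, 1))  # 73728 4194304 56.9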
