     TensorParallelEmbedding,
     TensorParallelRowLinear,
     get_linear,
+    Fp8Linear,
 )
 from text_generation_server.layers.attention import (
     Seqlen,
     attention,
-    paged_attention,
+    paged_attention_mla,
     set_block_mapping,
     HPUPagedAttentionMetadata,
 )
 import habana_frameworks.torch as htorch


+def get_and_maybe_dequant_weights(layer: torch.nn.Module) -> torch.Tensor:
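+    # Fp8Linear keeps its weight quantized; running a bf16 identity matrix
+    # through the layer reuses its own dequantizing matmul, so the full-precision
+    # weight is recovered without touching quantizer internals.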
+    if isinstance(layer, Fp8Linear):
+        eye = torch.eye(
+            layer.qweight.shape[-1], dtype=torch.bfloat16, device=layer.qweight.device
+        )
+        dequant_weights = layer(eye)
+        del eye
+        # standardize to (output, input)
+        return dequant_weights.T
+    return layer.weight
+
+
 class DeepseekV2Config(PretrainedConfig):
     def __init__(
         self,
@@ -246,6 +259,45 @@ def __init__(
             0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
         ).repeat_interleave(self.num_groups)

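+        # Decompose the kv_b_proj weight into per-head up-projections W_UK (keys)
+        # and W_UV (values) so that, at decode time, they can be absorbed into the
+        # query and output paths and attention can run in the latent space.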
+        kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj.linear).T
+        kv_b_proj_weight = kv_b_proj_weight.view(
+            self.kv_lora_rank,
+            self.num_heads,
+            self.qk_nope_head_dim + self.value_head_size,
+        )
+
+        W_UK, W_UV = kv_b_proj_weight.split(
+            [self.qk_nope_head_dim, self.value_head_size], dim=-1
+        )
+        # Convert from (L, N, V) to (N, L, V)
+        self.W_UV = W_UV.transpose(0, 1)
+        # Convert from (L, N, P) to (N, P, L)
+        self.W_UK_T = W_UK.permute(1, 2, 0)
+
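+    # Decode-path query projection: the absorbed W_UK^T maps the nope part of
+    # the query directly into the kv_lora_rank latent space, so keys never need
+    # to be up-projected per token.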
+    def _q_proj_and_k_up_proj(self, x):
+        q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+        q_nope, q_pe = (
+            q_proj(x)
+            .view(-1, self.num_heads, self.head_size)
+            .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        )
+
+        # Convert from (B, N, P) to (N, B, P)
+        q_nope = q_nope.transpose(0, 1)
+        # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+        ql_nope = torch.bmm(q_nope, self.W_UK_T)
+        # Convert from (N, B, L) to (B, N, L)
+        return ql_nope.transpose(0, 1), q_pe
+
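+    # Decode-path output projection: attention output arrives in the latent
+    # space and is mapped back per head through the absorbed W_UV before o_proj.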
+    def _v_up_proj_and_o_proj(self, x):
+        # Convert from (B, N, L) to (N, B, L)
+        x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+        # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+        x = torch.bmm(x, self.W_UV)
+        # Convert from (N, B, V) to (B, N * V)
+        x = x.transpose(0, 1).reshape(-1, self.num_heads * self.value_head_size)
+        return self.o_proj(x)
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -258,28 +310,28 @@ def forward(
         hpu_attention_meta: Optional[HPUPagedAttentionMetadata],
     ):
         if self.q_lora_rank is None:
-            query = self.q_proj(hidden_states)
+            hidden_states_or_q_c = hidden_states
         else:
-            query = self.q_b_proj(self.q_a_layernorm(self.q_a_proj(hidden_states))[0])
-            query = query.view(-1, self.num_heads, self.head_size)
-
-        _, query_pe = torch.split(
-            query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
-        )
+            hidden_states_or_q_c = self.q_a_layernorm(self.q_a_proj(hidden_states))[0]

         compressed_kv = self.kv_a_proj_with_mqa(hidden_states)
         compressed_kv, key_pe = torch.split(
             compressed_kv, [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
         )

         key_pe = key_pe.view(-1, 1, self.qk_rope_head_dim)
-        kv = self.kv_b_proj(self.kv_a_layernorm(compressed_kv.contiguous())[0]).view(
-            -1, self.num_key_value_heads, self.qk_nope_head_dim + self.value_head_size
-        )
+        kv_c_normed = self.kv_a_layernorm(compressed_kv.contiguous())[0]

-        key_nope, value = torch.split(
-            kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
-        )
+        # Prefill
+        if cu_seqlen_prefill is not None:
+            q_proj = self.q_proj if self.q_lora_rank is None else self.q_b_proj
+            query = q_proj(hidden_states_or_q_c)
+            query = query.view(-1, self.num_heads, self.head_size)
+            query_nope, query_pe = torch.split(
+                query, [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+        else:
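+            # Decode: queries are projected straight into the latent space via
+            # the absorbed W_UK^T (see _q_proj_and_k_up_proj).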
+            query_nope, query_pe = self._q_proj_and_k_up_proj(hidden_states_or_q_c)

         batch_size, heads, head_dim = query_pe.shape
         query_pe = (
@@ -294,33 +346,47 @@ def forward(
             .reshape(batch_size, heads, head_dim)
         )
         self.rotary_emb(query_pe, key_pe, cos, sin)
-
-        query[..., self.qk_nope_head_dim :] = query_pe
-        key = torch.empty_like(query)
-        key[..., : self.qk_nope_head_dim] = key_nope
-        key[..., self.qk_nope_head_dim :] = key_pe
-
-        # We need to pad the heads because Flash Attention does not support
-        # qk and v with different head sizes.
-        query = torch.nn.functional.pad(
-            query, (0, self.head_pad_size - self.head_size), value=0
-        )
-        key = torch.nn.functional.pad(
-            key, (0, self.head_pad_size - self.head_size), value=0
-        )
-        value = torch.nn.functional.pad(
-            value, (0, self.head_pad_size - self.value_head_size), value=0
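+        # MLA caches a single compressed entry per token: the normalized latent
+        # (kv_lora_rank) concatenated with the rotary key (qk_rope_head_dim);
+        # no separate value cache is written.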
+        latent_vec_k = torch.concat(
+            (kv_c_normed, key_pe.view(-1, self.qk_rope_head_dim)), dim=-1
         )
+        latent_vec_k = latent_vec_k.view(-1, self.qk_rope_head_dim + self.kv_lora_rank)
+
+        latent_vec_k = latent_vec_k.unflatten(0, (slots.size(0), -1))

         kv_cache.store(
-            key=key,
-            value=value,
+            key=latent_vec_k,
+            value=None,
             slots=slots,
             kv_scales=self.kv_scales,
         )

-        # Prefill
         if cu_seqlen_prefill is not None:
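+            # Prefill: up-project the latent back into full per-head keys and
+            # values and run regular flash attention (heads padded to one size).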
+            kv = self.kv_b_proj(kv_c_normed).view(
+                -1,
+                self.num_key_value_heads,
+                self.qk_nope_head_dim + self.value_head_size,
+            )
+
+            key_nope, value = torch.split(
+                kv, [self.qk_nope_head_dim, self.value_head_size], dim=-1
+            )
+            query[..., self.qk_nope_head_dim :] = query_pe
+            key = torch.empty_like(query)
+            key[..., : self.qk_nope_head_dim] = key_nope
+            key[..., self.qk_nope_head_dim :] = key_pe
+
+            # We need to pad the heads because Flash Attention does not support
+            # qk and v with different head sizes.
+            query = torch.nn.functional.pad(
+                query, (0, self.head_pad_size - self.head_size), value=0
+            )
+            key = torch.nn.functional.pad(
+                key, (0, self.head_pad_size - self.head_size), value=0
+            )
+            value = torch.nn.functional.pad(
+                value, (0, self.head_pad_size - self.value_head_size), value=0
+            )
+
             # flash attention
             attn_output = attention(
                 query=query,
@@ -331,24 +397,26 @@ def forward(
                 seqlen=seqlen,
                 softmax_scale=self.softmax_scale,
             )
-        # Decode
+            attn_output = attn_output[..., : self.value_head_size]
+
+            return self.o_proj(
+                attn_output.reshape(-1, self.num_heads * self.value_head_size)
+            )
         else:
-            attn_output = paged_attention(
+            # Decode
+            query = torch.cat([query_nope, query_pe], dim=-1)
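+            # Decode attention runs against the compressed latent cache; its
+            # output stays in the latent space and is mapped back through
+            # _v_up_proj_and_o_proj below.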
+            attn_output = paged_attention_mla(
                 query,
                 kv_cache,
                 self.kv_head_mapping,
                 self.softmax_scale,
                 seqlen,
                 kv_scales=self.kv_scales,
                 hpu_attention_meta=hpu_attention_meta,
+                kv_lora_rank=self.kv_lora_rank,
             )
-
-        # Remove padding.
-        attn_output = attn_output[..., : self.value_head_size]
-
-        return self.o_proj(
-            attn_output.reshape(-1, self.num_heads * self.value_head_size)
-        )
+            attn_output = self._v_up_proj_and_o_proj(attn_output)
+            return attn_output


 class DeepseekV2MLP(nn.Module):