Commit 73a3db9

[DCU] fix DCU w8a8c8 GEMM shape (#9115)
1 parent: ba9c345

File tree: 5 files changed, +20 -17 lines


llm/predict/export_model.py

Lines changed: 1 addition & 1 deletion
@@ -18,7 +18,7 @@
 
 import paddle
 from paddle.distributed import fleet
-from predict.predictor import ModelArgument, PredictorArgument, create_predictor
+from predictor import ModelArgument, PredictorArgument, create_predictor
 
 from paddlenlp.trainer import PdArgumentParser
 from paddlenlp.utils import llm_utils

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 13 additions & 10 deletions
@@ -50,16 +50,19 @@
     from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused
 else:
     from paddle.linalg import fp8_fp8_half_gemm_fused as fp8_gemm_fused
-from paddlenlp_ops import (
-    dequant_int8,
-    encode_rotary_qk,
-    gemm_dequant,
-    qkv_transpose_split,
-    quant_int8,
-    rebuild_padding,
-    transpose_remove_padding,
-    write_cache_kv,
-)
+try:
+    from paddlenlp_ops import (
+        dequant_int8,
+        encode_rotary_qk,
+        gemm_dequant,
+        qkv_transpose_split,
+        quant_int8,
+        rebuild_padding,
+        transpose_remove_padding,
+        write_cache_kv,
+    )
+except:
+    pass
 
 __all__ = [
     "MoeConfig",

paddlenlp/experimental/transformers/llama/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -674,7 +674,7 @@ def __init__(self, config: LlamaConfig):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
         )
 
         self.set_transformer_block(transformer_config)
@@ -861,7 +861,7 @@ def set_state_dict(self, state_dict):
                 unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
                     "llama.layers.{}.self_attn.v_proj.weight".format(idx)
                 ]
-                if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
+                if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                     concated_qkv_weight = np.concatenate(
                         [
                             unfused_state_dict["self_attn.q_proj.weight"],
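
The condition change above is the core of the fix: the old equality test matched only the plain "a8w8" quantization type, so the "a8w8c8" variant (the w8a8c8 of the commit title, which additionally quantizes the cache KV to int8) fell through to the non-DCU path and produced a mismatched GEMM weight shape. The substring test covers both. A quick sketch of the two predicates over a few illustrative quant_type values:

    # Old vs. new DCU quantization check ("weight_only_int8" is just an
    # illustrative non-matching value).
    for quant_type in ["a8w8", "a8w8c8", "weight_only_int8"]:
        old = quant_type == "a8w8"
        new = "a8w8" in quant_type
        print(f"{quant_type:<18} old={old!s:<5} new={new}")
    # a8w8               old=True  new=True
    # a8w8c8             old=False new=True   <- previously missed on DCU
    # weight_only_int8   old=False new=False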

paddlenlp/experimental/transformers/mixtral/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -338,7 +338,7 @@ def __init__(self, config: MixtralConfig):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
             moe_config=moe_config,
         )
 
@@ -527,7 +527,7 @@ def set_state_dict(self, state_dict):
                 unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
                     "mixtral.layers.{}.self_attn.v_proj.weight".format(idx)
                 ]
-                if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
+                if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                     concated_qkv_weight = np.concatenate(
                         [
                             unfused_state_dict["self_attn.q_proj.weight"],

paddlenlp/experimental/transformers/qwen2/modeling.py

Lines changed: 2 additions & 2 deletions
@@ -372,7 +372,7 @@ def __init__(self, config: Qwen2Config):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
         )
 
         self.set_transformer_block(transformer_config)
@@ -433,7 +433,7 @@ def set_state_dict(self, state_dict):
                 unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[
                     "qwen2.layers.{}.self_attn.v_proj.weight".format(idx)
                 ]
-                if paddle.is_compiled_with_rocm() and (self.quant_type == "a8w8" or self.quant_type == "a8w8c8"):
+                if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                     concated_qkv_weight = np.concatenate(
                         [
                             unfused_state_dict["self_attn.q_proj.weight"],
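
Across llama, mixtral, and qwen2 the same pair of changes appears: trans_qkvw is forced to False only on ROCm with an a8w8-style quant_type, and the matching set_state_dict branch concatenates the Q/K/V projection weights rather than transposing them. Judging by the flag name (the fused layer's internal layout is not shown in this diff, so this is an inference), trans_qkvw controls whether the fused attention op stores the QKV weight transposed; keeping it untransposed is what gives the DCU int8 GEMM the shape it expects. A rough numpy illustration of the two layouts, with toy sizes and an illustrative axis choice:

    import numpy as np

    hidden = 4  # toy size; real models use e.g. 4096
    q = np.ones((hidden, hidden))
    k = np.ones_like(q)
    v = np.ones_like(q)

    qkv = np.concatenate([q, k, v], axis=-1)  # [hidden, 3*hidden]
    print(qkv.shape)    # (4, 12) -- the trans_qkvw=False layout
    print(qkv.T.shape)  # (12, 4) -- the trans_qkvw=True layout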
