5 files changed, +20 −17 lines; touched paths include paddlenlp/experimental/transformers.

File 1 of 5 (a prediction script):

@@ -18,7 +18,7 @@
 
 import paddle
 from paddle.distributed import fleet
-from predict.predictor import ModelArgument, PredictorArgument, create_predictor
+from predictor import ModelArgument, PredictorArgument, create_predictor
 
 from paddlenlp.trainer import PdArgumentParser
 from paddlenlp.utils import llm_utils
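The dotted form only resolves when the script runs as part of a `predict` package; the flat form resolves whenever `predictor.py` sits beside the script being executed, because Python prepends the script's own directory to `sys.path`. A minimal sketch of that resolution rule (the directory layout is assumed for illustration, not shown in the diff):

    import sys

    # When invoked as `python some_script.py`, the script's directory is put
    # at the front of sys.path, so a sibling predictor.py is importable by
    # bare name. Hypothetical layout:
    #
    #   predict/
    #     predictor.py    <- ModelArgument, PredictorArgument, create_predictor
    #     some_script.py  <- `from predictor import ...` now resolves
    print(sys.path[0])  # the directory of the executed script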
File 2 of 5 (fused transformer layers):

@@ -50,16 +50,19 @@
     from paddlenlp_ops import cutlass_fp8_fp8_half_gemm_fused as fp8_gemm_fused
 else:
     from paddle.linalg import fp8_fp8_half_gemm_fused as fp8_gemm_fused
-from paddlenlp_ops import (
-    dequant_int8,
-    encode_rotary_qk,
-    gemm_dequant,
-    qkv_transpose_split,
-    quant_int8,
-    rebuild_padding,
-    transpose_remove_padding,
-    write_cache_kv,
-)
+try:
+    from paddlenlp_ops import (
+        dequant_int8,
+        encode_rotary_qk,
+        gemm_dequant,
+        qkv_transpose_split,
+        quant_int8,
+        rebuild_padding,
+        transpose_remove_padding,
+        write_cache_kv,
+    )
+except:
+    pass
 
 __all__ = [
     "MoeConfig",
File 3 of 5 (Llama modeling):

@@ -674,7 +674,7 @@ def __init__(self, config: LlamaConfig):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
         )
 
         self.set_transformer_block(transformer_config)
@@ -861,7 +861,7 @@ def set_state_dict(self, state_dict):
             unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
                 "llama.layers.{}.self_attn.v_proj.weight".format(idx)
             ]
-            if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
+            if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                 concated_qkv_weight = np.concatenate(
                     [
                         unfused_state_dict["self_attn.q_proj.weight"],
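The functional change in all three modeling files is the same: the ROCm a8w8 special-casing (skipping the QKV weight transpose and concatenating the raw q/k/v projection weights) previously fired only for the exact string "a8w8". The membership test extends it to any quant type containing that token, such as the "a8w8c8" cache-quantized variant visible in the Qwen2 hunk below. A quick check of the predicate (weight_only_int8 is included only as a contrasting value, it does not appear in this diff):

    # The new condition, evaluated against a few quant_type strings.
    for quant_type in ("a8w8", "a8w8c8", "weight_only_int8"):
        print(quant_type, "->", "a8w8" in quant_type)
    # a8w8 -> True
    # a8w8c8 -> True
    # weight_only_int8 -> False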
File 4 of 5 (Mixtral modeling):

@@ -338,7 +338,7 @@ def __init__(self, config: MixtralConfig):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
             moe_config=moe_config,
         )
@@ -527,7 +527,7 @@ def set_state_dict(self, state_dict):
             unfused_state_dict["self_attn.v_proj.weight"] = state_dict[
                 "mixtral.layers.{}.self_attn.v_proj.weight".format(idx)
             ]
-            if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8":
+            if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                 concated_qkv_weight = np.concatenate(
                     [
                         unfused_state_dict["self_attn.q_proj.weight"],
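These are the same two edits applied to Mixtral. The guarded branch they control builds a fused QKV weight by concatenating the per-projection matrices with np.concatenate, as the context lines show. A toy illustration of that fusion (shapes and axis are assumptions for illustration; the real code also handles tensor-parallel sharding):

    import numpy as np

    hidden = 4  # toy hidden size
    q = np.ones((hidden, hidden))
    k = np.ones((hidden, hidden)) * 2
    v = np.ones((hidden, hidden)) * 3

    # Fuse the three projection weights into one QKV matrix along the output
    # dimension; with trans_qkvw=False the [in, out] layout is kept as-is.
    qkv = np.concatenate([q, k, v], axis=-1)
    print(qkv.shape)  # (4, 12)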
File 5 of 5 (Qwen2 modeling):

@@ -372,7 +372,7 @@ def __init__(self, config: Qwen2Config):
             use_neox_rotary_style=self.use_neox,
             cachekv_int8_type=config.cachekv_int8_type,
             rank_id=config.tensor_parallel_rank,
-            trans_qkvw=(False if paddle.is_compiled_with_rocm() and self.quant_type == "a8w8" else True),
+            trans_qkvw=(False if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type else True),
         )
 
         self.set_transformer_block(transformer_config)
@@ -433,7 +433,7 @@ def set_state_dict(self, state_dict):
             unfused_state_dict["qwen2.self_attn.v_proj.weight"] = state_dict[
                 "qwen2.layers.{}.self_attn.v_proj.weight".format(idx)
             ]
-            if paddle.is_compiled_with_rocm() and (self.quant_type == "a8w8" or self.quant_type == "a8w8c8"):
+            if paddle.is_compiled_with_rocm() and "a8w8" in self.quant_type:
                 concated_qkv_weight = np.concatenate(
                     [
                         unfused_state_dict["self_attn.q_proj.weight"],
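Qwen2 is the one file whose old condition already listed "a8w8c8" explicitly; the substring test folds both terms into a single expression that now matches the other models. Equivalence over those two values is easy to sanity-check (the set-based spelling is an alternative shown for contrast, not what the PR uses):

    for quant_type in ("a8w8", "a8w8c8"):
        old = quant_type == "a8w8" or quant_type == "a8w8c8"
        new = "a8w8" in quant_type
        alt = quant_type in {"a8w8", "a8w8c8"}  # exact-match alternative
        assert old == new == alt  # all True for both values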