
Conversation

@yuanlehome
Collaborator

@yuanlehome yuanlehome commented Aug 6, 2024

PR types

Others

PR changes

Others

Description

Refactor some code and fix several bugs.

One important note: if src_length and max_length are specified for static-graph inference, the same values must also be specified when exporting the model (dynamic-to-static conversion).
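To make the constraint concrete, here is an illustrative sketch (the argument names mirror the ones used in the test changes further down; the dicts simply stand in for the CLI flags of the export and inference scripts):

```python
# Illustration only: the lengths used at static-graph inference time must equal the
# ones used when the model was exported (dynamic-to-static), since the exported
# program is built around those values.
export_args = {"inference_model": True, "block_attn": True, "src_length": 1024, "max_length": 48}
infer_args = {"inference_model": True, "block_attn": True, "src_length": 1024, "max_length": 48}

assert (export_args["src_length"], export_args["max_length"]) == (
    infer_args["src_length"],
    infer_args["max_length"],
), "src_length/max_length at static-graph inference must match the export values"
```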

@paddle-bot

paddle-bot bot commented Aug 6, 2024

Thanks for your contribution!

@yuanlehome yuanlehome mentioned this pull request Aug 6, 2024
@codecov

codecov bot commented Aug 6, 2024

Codecov Report

Attention: Patch coverage is 5.55556% with 17 lines in your changes missing coverage. Please review.

Project coverage is 55.40%. Comparing base (678843e) to head (88ea827).
Report is 246 commits behind head on develop.

| Files with missing lines | Patch % | Lines |
| --- | --- | --- |
| ...dlenlp/experimental/transformers/llama/modeling.py | 0.00% | 11 Missing ⚠️ |
| paddlenlp/experimental/model_utils.py | 0.00% | 6 Missing ⚠️ |
Additional details and impacted files
@@             Coverage Diff             @@
##           develop    #8879      +/-   ##
===========================================
+ Coverage    55.29%   55.40%   +0.10%     
===========================================
  Files          631      632       +1     
  Lines        98888    99762     +874     
===========================================
+ Hits         54681    55271     +590     
- Misses       44207    44491     +284     

☔ View full report in Codecov by Sentry.

@yuanlehome yuanlehome changed the title from TEMP to [LLM Inference] Refactor BlockInferencePredictor Aug 8, 2024
)
predictor.model.config.save_pretrained(export_args.output_path)
predictor.model.generation_config.save_pretrained(export_args.output_path)
if predictor.generation_config is not None:
Collaborator Author

Fix a bug where generation_config.json was not saved correctly.
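A hedged sketch of the corrected behaviour (the predictor/export_args names come from the hunk above; the fallback branch is my assumption, not copied from the diff):

```python
def save_inference_configs(predictor, output_path):
    # Save the model config as before, then prefer the generation config the predictor
    # actually loaded; only fall back to the model's own generation config if none exists.
    predictor.model.config.save_pretrained(output_path)
    if predictor.generation_config is not None:
        predictor.generation_config.save_pretrained(output_path)
    else:
        predictor.model.generation_config.save_pretrained(output_path)
```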


 PD_BUILD_OP(get_padding_offset_v2)
-    .Inputs({"input_ids", "token_num", "cum_offsets", "seq_len"})
+    .Inputs({"input_ids", "cum_offsets", "token_num", "seq_len"})
Collaborator Author

Fix the operator's declared inputs being in the wrong order.

Contributor

Curious how this managed to run correctly before.

Collaborator Author

It only affected the order of the names; the tensor order was correct.
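A tiny pure-Python analogy (not the real custom-op machinery) of why this was harmless at runtime: the arguments are bound by position, so only the labels were wrong:

```python
# The caller passes (input_ids, cum_offsets, token_num, seq_len) positionally.
# With the old declaration the two middle *names* were swapped, but each parameter
# still receives the value passed at its position, so the computation was unaffected.
def op_with_swapped_names(input_ids, token_num, cum_offsets, seq_len):
    return input_ids, token_num, cum_offsets, seq_len

print(op_with_swapped_names("ids", "cum_offsets_data", "token_num_data", "seq_len_data"))
# -> ('ids', 'cum_offsets_data', 'token_num_data', 'seq_len_data'): data order unchanged,
#    even though the second and third parameter names no longer describe their contents.
```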

     seq_len * rotary_emb_dims,
     last_dim);
-    NeoXRotaryKernel<<<grid, BlockSize, 0, cu_stream>>>(
+    NeoXRotaryKernel<<<grid_k, BlockSize, 0, cu_stream>>>(
Collaborator Author

Fix a computation bug in the GQA kernel.
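My reading of the one-line change (grid -> grid_k), with assumed numbers: under grouped-query attention K/V have fewer heads than Q, so the rotary kernel launched for K needs a grid derived from the kv-head count:

```python
# Assumed head counts for illustration only; the real grid layout lives in the CUDA kernel.
num_q_heads, num_kv_heads = 32, 8   # GQA: several query heads share each kv head
batch_size, seq_len = 2, 128

grid = batch_size * num_q_heads * seq_len     # sized for rotating Q
grid_k = batch_size * num_kv_heads * seq_len  # what the K launch should use instead
print(grid, grid_k)  # 8192 2048: a grid sized for Q does not match the smaller K head dimension
```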

Comment on lines -34 to -35
-    get_default_max_decoding_length,
-    get_default_max_encoding_length,
Collaborator Author

Removed the logic that auto-configures src_length and max_length and gave them explicit default values instead. The auto-configuration was not robust: it is not compatible with all models and raised errors in some cases.

tokenizer.init_chat_template(chat_template_file)


def get_model_max_position_embeddings(config: PretrainedConfig) -> Optional[int]:
Contributor

Please don't delete these functions; they are still called elsewhere in the codebase, for example by the PPO code.

Collaborator Author

OK.

-class InferencePredictorMixin:
+class InferencePredictorMixin(BasePredictor):
     def __init__(self, config: PredictorArgument, tokenizer: PretrainedTokenizer):
+        BasePredictor.__init__(self, config, tokenizer)
Contributor

So generation_config is now produced through BasePredictor, right?

Collaborator Author

Yes.
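In other words, loading the generation config now lives in the shared base class. A rough sketch of the idea (the exact loading logic is an assumption, not copied from the diff):

```python
from paddlenlp.generation import GenerationConfig

class BasePredictor:
    def __init__(self, config, tokenizer):
        self.config = config
        self.tokenizer = tokenizer
        # Assumed behaviour: try to read generation_config.json from the model directory,
        # leaving None when it is absent so callers can test `generation_config is not None`.
        try:
            self.generation_config = GenerationConfig.from_pretrained(config.model_name_or_path)
        except Exception:
            self.generation_config = None
```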

self.input_ids = paddle.full(
shape=[config.batch_size, config.total_max_length], fill_value=self.tokenizer.pad_token_id, dtype="int64"
)
self.model_inputs = {}
Contributor

The variable was renamed from self.inputs to self.model_inputs; does that affect any logic?

Collaborator Author

No, it doesn't.

length = len(input_ids)
self.inputs["input_ids"][i : i + 1, :length] = input_ids
self.inputs["penalty_score"][i : i + 1] = self.config.repetition_penalty
self.inputs["frequency_score"][i : i + 1] = 0.0
Contributor

Don't these need to be kept?

Contributor

Or is it that init_model_inputs already sets these, so the duplicates are simply removed here?

Collaborator Author

Right, they don't need to be kept.
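So the per-request preprocessing loop no longer rewrites these fields; they are allocated once up front. A hedged sketch of what init_model_inputs presumably does for them (field names taken from the deleted lines; shapes are an assumption):

```python
import paddle

def init_model_inputs(self, config):
    # Assumed sketch: sampling scalars are created once per predictor instead of being
    # filled again for every sample in the preprocessing loop.
    bs = config.batch_size
    self.model_inputs["penalty_score"] = paddle.full([bs, 1], config.repetition_penalty, dtype="float32")
    self.model_inputs["frequency_score"] = paddle.full([bs, 1], 0.0, dtype="float32")
```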

ffn1_weight_tensor = paddle.to_tensor(concated_ffn1_weight)

-qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight)
+qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype())
Contributor

Why cast to the default dtype?

Collaborator Author

So that the same code path can run both bf16 and fp16.
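A self-contained illustration of the point (stand-in array; the real code casts the fused checkpoint weights):

```python
import numpy as np
import paddle

# Checkpoint weights may be stored as float32/float16; casting to the current default
# dtype lets the same weight-fusion code serve both fp16 and bf16 inference runs.
paddle.set_default_dtype("float16")  # the same pattern applies when the default is bfloat16
concated_qkv_weight = np.random.rand(8, 8).astype("float32")  # stand-in for the fused weight
qkv_weight_tensor = paddle.to_tensor(concated_qkv_weight).cast(paddle.get_default_dtype())
print(qkv_weight_tensor.dtype)  # paddle.float16
```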

-linear_weight_tensor = paddle.to_tensor(state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)])
+linear_weight_tensor = paddle.to_tensor(
+    state_dict["llama.layers.{}.self_attn.o_proj.weight".format(idx)]
+).cast(paddle.get_default_dtype())
Contributor

Same question here.

Collaborator Author

Same answer as above.

-src_length: int = field(default=None, metadata={"help": "The max length of source text."})
-max_length: int = field(default=None, metadata={"help": "the max length for decoding."})
+src_length: int = field(default=4096, metadata={"help": "The max length of source text."})
+min_length: int = field(default=1, metadata={"help": "the min length for decoding."})
Contributor

Please rename min_length and max_length to min_decode_length and max_decode_length; the current names are easy to misread.

Collaborator Author

Can't easily change them; these names are used in too many places.


def test_blha(self):
-    self.run_predictor({"inference_model": True, "block_attn": True})
+    self.run_predictor({"inference_model": True, "block_attn": True, "src_length": 1024, "max_length": 48})
Contributor

Change max_length to max_decode_length.

Collaborator Author

Same as above; can't easily change it.

shape=[config.batch_size, 1], fill_value=config.temperature, dtype="float32"
)
self.model_inputs["eos_token_id"] = paddle.to_tensor(
np.array(get_eos_token_id(self.tokenizer, self.generation_config)).reshape(-1, 1).astype("int64")
Contributor

Won't this be a problem when the batch size is greater than 1?

Collaborator Author

No. The kernel does not handle this per batch; it is applied uniformly when processing.
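For reference, a small shape sketch (token ids are made up): the eos list becomes a [num_eos, 1] tensor shared by every sequence, so no per-batch copy is needed:

```python
import numpy as np
import paddle

eos_ids = [2, 100273]  # hypothetical ids returned by get_eos_token_id(...)
eos_token_id = paddle.to_tensor(np.array(eos_ids).reshape(-1, 1).astype("int64"))
print(eos_token_id.shape)  # [2, 1]: one column of eos ids, no batch dimension
```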

Contributor

@wawltor wawltor left a comment

LGTM

@wawltor wawltor merged commit 5bc040a into PaddlePaddle:develop Aug 12, 2024