llm/predict/export_model.py (13 changes: 8 additions & 5 deletions)
@@ -56,11 +56,14 @@ def validate_pdmodel(model_path, model_prefix, device):
os.path.join(model_path, model_prefix), exe
)

for block in net_program.blocks:
ops: list[paddle.framework.Operator] = block.ops
for op in tqdm(ops, desc="checking the validation of ops"):
if op.type.lower() == "print":
logger.warning(f"UNEXPECTED OP<{op.type}> which will reduce the performance of the static model")
if not paddle.framework.use_pir_api():
for block in net_program.blocks:
ops: list[paddle.framework.Operator] = block.ops
for op in tqdm(ops, desc="checking the validation of ops"):
if op.type.lower() == "print":
logger.warning(
f"UNEXPECTED OP<{op.type}> which will reduce the performace of the static model"
)


def main():
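For context, a minimal sketch of what the guarded scan above amounts to, written as a standalone helper; the name `check_for_print_ops` and the module-level `logger` are illustrative assumptions rather than code from this PR. The guard mirrors the one added above: the legacy block/op traversal is only run on the old program IR, and is skipped under PIR.

import logging

import paddle
from tqdm import tqdm

logger = logging.getLogger(__name__)


def check_for_print_ops(net_program):
    # Skip under PIR, where the legacy `block.ops` / `op.type` walk below does not apply.
    if paddle.framework.use_pir_api():
        return
    for block in net_program.blocks:
        for op in tqdm(block.ops, desc="checking the validation of ops"):
            if op.type.lower() == "print":
                logger.warning(
                    f"UNEXPECTED OP<{op.type}> which will reduce the performance of the static model"
                )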
paddlenlp/generation/utils.py (21 changes: 14 additions & 7 deletions)
@@ -594,6 +594,12 @@ def _build_fast(self, kwargs):
kwargs["use_fp16_decoding"] = True
self.prepare_fast_entry(kwargs)

def set_pad_token_id(self, pad_token_id, eos_token_id):
if pad_token_id is None and eos_token_id is not None:
print("Setting `pad_token_id` to `eos_token_id`:{} for " "open-end generation.".format(eos_token_id))
pad_token_id = eos_token_id
return pad_token_id

@paddle.no_grad()
def generate(
self,
@@ -869,9 +875,7 @@ def generate(
"`streamer` cannot be used with beam search (yet!). Make sure that `num_beams` is set to 1."
)

if pad_token_id is None and eos_token_id is not None:
print("Setting `pad_token_id` to `eos_token_id`:{} for " "open-end generation.".format(eos_token_id))
pad_token_id = eos_token_id
pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)

if generation_config.max_length != 0 and generation_config.max_new_tokens == DEFAULT_MAX_NEW_TOKENS:
logger.warning("`max_length` will be deprecated in future releases, use `max_new_tokens` instead.")
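The inline fallback removed above now lives in the `set_pad_token_id` helper introduced in the first hunk, so the same logic can be reused by `sample_d2s` below. A minimal, self-contained sketch of its behaviour, written as a free function for illustration (the PR defines it as a method on the generation mixin):

def set_pad_token_id(pad_token_id, eos_token_id):
    # Fall back to the EOS token as padding for open-ended generation.
    if pad_token_id is None and eos_token_id is not None:
        print("Setting `pad_token_id` to `eos_token_id`:{} for open-end generation.".format(eos_token_id))
        pad_token_id = eos_token_id
    return pad_token_id


assert set_pad_token_id(None, 2) == 2        # missing pad id falls back to the eos id
assert set_pad_token_id(0, 2) == 0           # an explicit pad id is left untouched
assert set_pad_token_id(None, None) is None  # nothing to fall back to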
@@ -1233,7 +1237,6 @@ def sample(
)

next_scores = paddle.index_sample(origin_probs, next_tokens)

if eos_token_id is not None:
next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))

@@ -1333,6 +1336,7 @@ def sample_d2s(
min_tokens_to_keep=1,
):

pad_token_id = self.set_pad_token_id(pad_token_id, eos_token_id)
logits_processors = logits_processors if logits_processors is not None else LogitsProcessorList()

if paddle.is_tensor(top_k) and not paddle.is_tensor(top_p):
@@ -1372,7 +1376,9 @@ def _forward_(**args):
del model_inputs["use_cache"]
return self(**model_inputs, **immutable)

def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, model_kwargs):
def _post_process_(
outputs, input_ids, cur_len, origin_len, scores, unfinished_flag, model_kwargs, pad_token_id
):
if isinstance(outputs, tuple):
logits = outputs[0]
elif isinstance(outputs, ModelOutput):
@@ -1405,7 +1411,6 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

next_scores = paddle.index_sample(origin_probs, next_tokens)
scores = self.update_scores_for_generation(scores, next_scores, cur_len - origin_len, unfinished_flag)

if eos_token_id is not None:
next_tokens = paddle.where(unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id))

@@ -1422,7 +1427,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f

outputs = _forward_(**model_kwargs)
input_ids, scores, unfinished_flag, model_kwargs = _post_process_(
outputs, input_ids, cur_len_gpu, origin_len_gpu, scores, unfinished_flag, model_kwargs
outputs, input_ids, cur_len_gpu, origin_len_gpu, scores, unfinished_flag, model_kwargs, pad_token_id
)

if hasattr(paddle.framework, "_no_check_dy2st_diff"):
@@ -1456,6 +1461,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
scores,
unfinished_flag,
model_kwargs,
pad_token_id,
)
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
@@ -1469,6 +1475,7 @@ def _post_process_(outputs, input_ids, cur_len, origin_len, scores, unfinished_f
scores,
unfinished_flag,
model_kwargs,
pad_token_id,
)
paddle.increment(cur_len)
paddle.increment(cur_len_gpu)
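The remaining changes in this file thread `pad_token_id` explicitly through `_post_process_` and resolve it up front in `sample_d2s` via the same helper. The decoding loop needs a concrete pad value because tokens of sequences that have already finished are overwritten with it on every step. A small illustrative sketch of that masking step, with made-up tensor values (not code from the PR):

import paddle

pad_token_id = 0  # resolved beforehand, e.g. from eos_token_id via set_pad_token_id
next_tokens = paddle.to_tensor([[5], [2], [7]], dtype="int64")  # freshly sampled tokens
unfinished_flag = paddle.to_tensor([[True], [False], [True]])   # the second sequence already hit EOS

# Same pattern as in _post_process_: finished rows keep emitting the pad token.
next_tokens = paddle.where(
    unfinished_flag, next_tokens, paddle.full_like(next_tokens, pad_token_id)
)
print(next_tokens.numpy().tolist())  # [[5], [0], [7]]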
paddlenlp/transformers/llama/modeling.py (4 changes: 2 additions & 2 deletions)
@@ -255,7 +255,7 @@ def scaled_dot_product_attention(
alibi = alibi.reshape([bsz, num_heads, 1, -1])
attn_weights = attn_weights + alibi

if attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
if paddle.in_dynamic_mode() and attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
raise ValueError(
f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.shape}"
@@ -271,7 +271,7 @@ def scaled_dot_product_attention(
if attention_mask is None:
attention_mask = get_triangle_upper_mask(attn_weights)
attention_mask = attention_mask.reshape([bsz, 1, q_len, kv_seq_len])
if attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
if paddle.in_dynamic_mode() and attention_mask.shape != [bsz, 1, q_len, kv_seq_len]:
raise ValueError(
f"Attention mask should be of shape {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.shape}"
)
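Both shape assertions in this file are now gated on `paddle.in_dynamic_mode()`. The PR does not state the motivation, but a plausible reading is that under static-graph capture (e.g. `to_static` export) one or more dimensions can be unknown, so an eager equality check against concrete sizes would trip spuriously; treat that as an assumption. A minimal sketch of the gated check, where `check_attn_shape` is a hypothetical helper rather than model code:

import paddle


def check_attn_shape(attn_weights, bsz, num_heads, q_len, kv_seq_len):
    # Only validate when shapes are fully known, i.e. in dynamic (eager) mode.
    if paddle.in_dynamic_mode() and attn_weights.shape != [bsz, num_heads, q_len, kv_seq_len]:
        raise ValueError(
            f"Attention weights should be of shape {(bsz, num_heads, q_len, kv_seq_len)}, but is"
            f" {attn_weights.shape}"
        )


w = paddle.rand([2, 8, 16, 16])
check_attn_shape(w, 2, 8, 16, 16)  # passes in eager mode; skipped under static capture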