[NEW Feature] Add hook-based refined_recompute support #9396
Changes from all commits
@@ -51,6 +51,7 @@ def swiglu(x, y=None):
 except:
     flash_attention = None

+from paddlenlp.transformers.refined_recompute import no_recompute
 from paddlenlp.transformers.ring_flash_attention import RingFlashAttention

Review comments on the new no_recompute import:
- Why is it called no_recompute? The name feels odd.
- Renaming it to skip_recompute would also work.
- recompute(func, xxxxx) vs no_recompute(func, xxxxxx)
@@ -174,6 +175,7 @@ def fusion_flash_attention(
     sequence_parallel=False,
     reshard_layer=None,
     npu_is_casual=False,
+    skip_recompute=False,
 ):
     bsz, q_len, num_heads, head_dim = query_states.shape
     _, kv_seq_len, _, _ = value_states.shape
@@ -257,28 +259,34 @@ def fusion_flash_attention(
             attn_mask_startend_row_indices = paddle.unsqueeze(attn_mask_startend_row_indices, axis=1)

         if hasattr(F, "flashmask_attention"):
-            attn_output = F.flashmask_attention(
+            attn_output = no_recompute(
+                F.flashmask_attention,
                 query_states,
                 key_states,
                 value_states,
                 startend_row_indices=attn_mask_startend_row_indices.unsqueeze(-1),
                 causal=True,
+                enable=skip_recompute,
             )
         else:
-            attn_output = F.flash_attention_with_sparse_mask(
+            attn_output = no_recompute(
+                F.flash_attention_with_sparse_mask,
                 query_states,
                 key_states,
                 value_states,
                 attn_mask_start_row_indices=attn_mask_startend_row_indices,
                 is_causal=True,
+                enable=skip_recompute,
             )
     else:
-        attn_output = F.scaled_dot_product_attention(
+        attn_output = no_recompute(
+            F.scaled_dot_product_attention,
             query_states,
             key_states,
             value_states,
             attn_mask=attention_mask,
             is_causal=query_states.shape[1] != 1,
+            enable=skip_recompute,
         )
     attn_weights = None
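For context on the pattern above: every fused attention call is now routed through no_recompute(func, ...) with enable=skip_recompute, so default behavior is unchanged unless refined recompute is explicitly requested for that op. The sketch below only illustrates the call-forwarding shape implied by the diff; the helper name no_recompute_sketch is made up here, and the real enable=True path in paddlenlp.transformers.refined_recompute is hook-based and not reproduced.

```python
from typing import Any, Callable


def no_recompute_sketch(func: Callable[..., Any], *args: Any, enable: bool = False, **kwargs: Any) -> Any:
    """Illustrative stand-in for no_recompute(func, ...); not the PR's implementation."""
    if not enable:
        # Degrades to a plain call, which is why the existing call sites can be
        # wrapped without changing their default behavior.
        return func(*args, **kwargs)
    # Assumption: when enabled, the actual implementation registers hooks so this
    # op's outputs are kept and an enclosing recompute region does not re-run the
    # op during the backward pass.
    return func(*args, **kwargs)
```

Seen next to recompute(func, ...), which deliberately re-runs func in the backward pass to save activation memory, no_recompute(func, ...) opts one expensive op out of an enclosing recompute region; that symmetry is what the naming discussion above is about.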
Review comments on the configuration handling:
- Will this configuration be passed down to downstream tasks?
- Do we need _set_unsavable_keys?
- Not needed. zhonghui knows this usage best; I checked the implementation and it can meet the requirement: 1) llmmetaclass is added, and 2) LlmMetaConfig.set_llm_config(model_config, training_args) is called.
@dataclass
@llmmetaclass
@add_start_docstrings(TrainingArguments.__doc__)
class TrainingArguments(TrainingArguments):
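As a rough usage sketch of the flow described in the reply above: only LlmMetaConfig.set_llm_config(model_config, training_args) is taken from the review; the import locations, the checkpoint name, and the surrounding setup are assumptions.

```python
from paddlenlp.trainer import TrainingArguments  # stand-in for the @llmmetaclass-decorated class above
from paddlenlp.transformers import AutoConfig
from paddlenlp.transformers.configuration_utils import LlmMetaConfig  # assumed import location

training_args = TrainingArguments(output_dir="./checkpoints")  # illustrative arguments
model_config = AutoConfig.from_pretrained("facebook/llama-7b")  # illustrative checkpoint

# Copy the LLM-specific fields registered via @llmmetaclass (including the
# refined-recompute switches introduced by this PR) from the training arguments
# onto the model config, so the model can read them at construction time.
LlmMetaConfig.set_llm_config(model_config, training_args)
```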