
Conversation

zheliuyu
Contributor

@zheliuyu zheliuyu commented Dec 9, 2024

What does this PR do?

Add feature: Ascend NPU support for SDPA.

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Did you read the contributor guideline, Pull Request section?
  • Was this discussed/approved via a GitHub issue or the forum? Please add a link to it if that's the case.
  • Did you make sure to update the documentation with your changes? Here are the documentation guidelines, and here are tips on formatting docstrings.
  • Did you write any new necessary tests?

Notes

Ascend NPU requires torch>=2.1.0 to use SDPA in Transformers.
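
For context, a minimal sketch of the kind of version gate this requirement implies (`is_torch_npu_available` is an existing helper in transformers.utils; the gating function itself is illustrative, not the exact code added by this PR):

from packaging import version

import torch
from transformers.utils import is_torch_npu_available


def npu_sdpa_supported() -> bool:
    # Illustrative check: SDPA on Ascend NPU needs torch>=2.1.0
    # (with torch_npu installed so the "npu" device is registered).
    if not is_torch_npu_available():
        return False
    return version.parse(torch.__version__) >= version.parse("2.1.0")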

@zheliuyu zheliuyu changed the title Ascend NPU support SDPA NPU support SDPA Dec 14, 2024
@zheliuyu zheliuyu marked this pull request as draft December 21, 2024 07:16
Member

@SunMarc SunMarc left a comment


Thanks for the PR! Is there a link that shows that NPU is compatible with SDPA from torch 2.1.0? Also, let us know when this is ready to be reviewed!

@zheliuyu
Contributor Author

zheliuyu commented Dec 24, 2024

Thanks for the PR! Is there a link that shows that NPU is compatible with SDPA from torch 2.1.0? Also, let us know when this is ready to be reviewed!

@SunMarc Thanks for your reply; this PR is ready to be reviewed. Some explanations and tests are as follows.

NPU supports SDPA in torch>=2.1.0

To use SDPA on NPU, simply import torch_npu.
For GPU

import torch
import torch.nn.functional as F


query = torch.ones(1, 2, dtype=torch.float16, device="cuda")
key = torch.ones(1, 2, dtype=torch.float16, device="cuda")
value = torch.ones(1, 2, dtype=torch.float16, device="cuda")

output = F.scaled_dot_product_attention(query, key, value)
print("torch version: ", torch.__version__)
print("result: ", output)

Output:
torch version:  2.1.0+cu121
result:  tensor([[1., 1.]], device='cuda:0', dtype=torch.float16)

For NPU

import torch
import torch_npu
import torch.nn.functional as F


query = torch.ones(1, 2, dtype=torch.float16, device="npu:0")
key = torch.ones(1, 2, dtype=torch.float16, device="npu:0")
value = torch.ones(1, 2, dtype=torch.float16, device="npu:0")

output = F.scaled_dot_product_attention(query, key, value)
print("torch version: ", torch.__version__)
print("torch_npu version: ", torch_npu.__version__)
print("result: ", output)

Output:
torch version:  2.1.0
torch_npu version:  2.1.0
result:  tensor([[1., 1.]], device='npu:0', dtype=torch.float16)
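
With this support in place, SDPA can be requested through the usual Transformers API on NPU as well. A minimal sketch (the gpt2 checkpoint is only an example; any SDPA-capable model should behave the same way):

import torch
import torch_npu  # registers the "npu" device with PyTorch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # example checkpoint, substitute any SDPA-capable model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
).to("npu:0")

inputs = tokenizer("Hello from Ascend NPU!", return_tensors="pt").to("npu:0")
with torch.no_grad():
    output_ids = model.generate(**inputs, max_new_tokens=8)
print(tokenizer.decode(output_ids[0], skip_special_tokens=True))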

NPU handles non-contiguous inputs correctly in torch>=2.1.0

According to issue 112577, Transformers requires torch>=2.1.1 to avoid a numerical issue in SDPA with non-contiguous inputs.

Running the same test code shows that NPU avoids this issue from torch>=2.1.0; a simplified sketch of the comparison follows the log below.

query_sdpa torch.Size([1, 1, 2048])
key_sdpa torch.Size([1, 16, 128])
value_sdpa torch.Size([1, 16, 128])
attention_mask_sdpa torch.Size([1, 1, 1, 16])
attention_mask_sdpa tensor([[[[True, True, True, True, True, True, True, True, True, True, True,
           True, True, True, True, True]]]])
---- non_contig_cpu_math
query contiguous True
key contiguous False
value contiguous False
---- contig_cpu_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_npu_math
query contiguous True
key contiguous False
value contiguous False
---- contig_npu_math
query contiguous True
key contiguous True
value contiguous True
---- non_contig_npu_memeff
query contiguous True
key contiguous False
value contiguous False
---- contig_npu_memeff
query contiguous True
key contiguous True
value contiguous True


cpu non-contig/contig: mean abs-diff tensor(0.)
cpu non-contig/contig: mean rel-diff tensor(0.)
npu non-contig/contig math: mean abs-diff tensor(0., device='npu:0')
npu non-contig/contig math: mean rel-diff tensor(0., device='npu:0')
npu non-contig/contig memeff: mean abs-diff tensor(0., device='npu:0')
npu non-contig/contig memeff: mean rel-diff tensor(0., device='npu:0')

Allclose CPU non-contig/contig: True
Allclose NPU math non-contig/contig: True
Allclose NPU memeff non-contig/contig: True
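
For reference, a simplified sketch of the kind of non-contiguous vs. contiguous comparison behind the log above (the shapes and the use of transposed views are illustrative; the original script from issue 112577 also exercises attention masks and the memory-efficient backend):

import torch
import torch_npu
import torch.nn.functional as F

torch.manual_seed(0)
query = torch.randn(1, 16, 16, 128, dtype=torch.float16, device="npu:0")
# Build non-contiguous key/value as transposed views of differently laid-out tensors.
key_base = torch.randn(1, 16, 128, 16, dtype=torch.float16, device="npu:0")
value_base = torch.randn(1, 16, 128, 16, dtype=torch.float16, device="npu:0")
key_noncontig = key_base.transpose(-1, -2)      # non-contiguous view
value_noncontig = value_base.transpose(-1, -2)
key_contig = key_noncontig.contiguous()         # same values, contiguous layout
value_contig = value_noncontig.contiguous()

print("key contiguous:", key_noncontig.is_contiguous())  # False
out_noncontig = F.scaled_dot_product_attention(query, key_noncontig, value_noncontig)
out_contig = F.scaled_dot_product_attention(query, key_contig, value_contig)
print("Allclose NPU non-contig/contig:",
      torch.allclose(out_noncontig, out_contig, atol=1e-3, rtol=1e-3))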

@zheliuyu zheliuyu closed this Dec 24, 2024
@zheliuyu zheliuyu reopened this Dec 24, 2024
@zheliuyu zheliuyu marked this pull request as ready for review December 24, 2024 08:06
Member

@SunMarc SunMarc left a comment


Nice, thanks for the explanation!

@SunMarc SunMarc requested a review from ArthurZucker December 24, 2024 11:33
@zheliuyu
Contributor Author

zheliuyu commented Jan 5, 2025

Nice, thanks for the explanation!

@SunMarc @ArthurZucker
Hi, is this PR okay to be merged? Is there anything I can help with? ^^

@SunMarc
Member

SunMarc commented Jan 6, 2025

Okay to be merged for me! cc @ArthurZucker

@HuggingFaceDocBuilderDev

The docs for this PR live here. All of your documentation changes will be reflected on that endpoint. The docs are available until 30 days after the last update.

Collaborator

@ArthurZucker ArthurZucker left a comment


Sounds good

@ArthurZucker ArthurZucker merged commit ed73ae2 into huggingface:main Jan 7, 2025
25 checks passed
AlanPonnachan pushed a commit to AlanPonnachan/transformers that referenced this pull request Jan 7, 2025
