CodeGen Fix causal mask for half precision #18467
Conversation
- Small hotfix for the causal mask for half-precision models
- Explicitly cast the causal mask to uint8 for compatibility with `torch.where`
- check huggingface#18467
The documentation is not available anymore as the PR was closed or merged.
  # compute causal mask from causal mask buffer
  query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
+ causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)
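For context, a minimal sketch (not from the PR) of why the condition passed to `torch.where` needs to be a boolean/integer mask rather than a half-precision tensor:

```python
import torch

# The causal mask buffer is a lower-triangular matrix of ones.
mask = torch.tril(torch.ones(4, 4))
scores = torch.randn(4, 4)
fill = torch.tensor(float("-inf"))

# A boolean condition is always accepted; the PR casts to uint8, which the
# PyTorch versions of the time also accepted as a condition dtype.
masked = torch.where(mask.bool(), scores, fill)

# Once torch_dtype=torch.float16 has cast the buffer to fp16, passing it
# directly as the condition fails.
try:
    torch.where(mask.half(), scores, fill)
except RuntimeError as err:
    print("torch.where rejected the fp16 mask:", err)
```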
Let's have a comment here to explain why we need .to(torch.uint8) 🙏
Added more comments on 8b81ac1
Let me know if anything is unclear!
Stupid question, why is that not
- causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)
+ self.causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)
feels like something you want to do only once, no?
It's a no-op when the tensor is already in the correct dtype.
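For reference, a quick check (not from the thread) that `.to()` returns the original tensor unchanged when the dtype already matches:

```python
import torch

mask = torch.zeros(2, 2, dtype=torch.uint8)
# Tensor.to returns self when no dtype/device conversion is needed, so casting
# an already-uint8 mask on every forward pass is a no-op.
print(mask.to(torch.uint8) is mask)  # True
```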
It won't be needed anymore as we found the root cause with Younes :-)
sgugger left a comment:
LGTM, thanks for fixing!
sgugger left a comment:
Let's close this one to focus on the right fix @younesbelkada :-)
Yeah let's move the discussion to: #18471
What does this PR do?
This PR forces the causal mask to stay in `torch.uint8`. An error occurs when loading a model in half precision, since `torch_dtype=torch.float16` also casts the buffers to fp16. A minimal script to reproduce the error is sketched below.
In a future PR we could address not casting the buffers (i.e. keeping them in their native dtype).
Can also confirm the slow tests pass!
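A hypothetical reproduction sketch (the original script was collapsed in the PR page; the checkpoint name, prompt, and generation call are assumptions):

```python
# Hypothetical reproduction sketch -- not the original script from the PR.
# The checkpoint name is an assumption; any CodeGen checkpoint loaded in fp16
# should trigger the issue. A CUDA device is assumed for fp16 inference.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

checkpoint = "Salesforce/codegen-350M-mono"  # assumed checkpoint
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

# torch_dtype=torch.float16 also casts the registered buffers (including the
# causal mask) to fp16, which torch.where later rejects as a condition tensor.
model = AutoModelForCausalLM.from_pretrained(
    checkpoint, torch_dtype=torch.float16
).to("cuda")

inputs = tokenizer("def hello_world():", return_tensors="pt").to("cuda")
outputs = model.generate(**inputs, max_new_tokens=16)  # raised before this fix
print(tokenizer.decode(outputs[0]))
```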
cc @ydshieh