
Conversation

@younesbelkada (Contributor) commented Aug 4, 2022

What does this PR do?

This PR forces the causal mask to stay in torch.uint8. An error occurs when loading a model in half precision, since torch_dtype=torch.float16 also casts the registered buffers to fp16. Here is a minimal script to reproduce the error:

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("Salesforce/codegen-2B-mono")
model = AutoModelForCausalLM.from_pretrained("Salesforce/codegen-2B-mono", device_map="auto", torch_dtype=torch.float16)

text = "def quicksort(l):"

encoded_input = tokenizer(text, return_tensors='pt')
output_sequences = model.generate(input_ids=encoded_input['input_ids'], attention_mask=encoded_input['attention_mask'])
print(tokenizer.decode(output_sequences[0], skip_special_tokens=True))

In a future PR we could address not casting the buffers at all (i.e. keeping them in their native dtype).
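
For context, here is a minimal sketch (not from the PR) of why the fp16-cast buffer breaks: torch.where expects a bool (or uint8) condition tensor, so a causal-mask buffer that has been cast to fp16 is rejected. The shapes below are illustrative assumptions, and the exact error message depends on the torch version:

import torch

# Illustrative 4x4 lower-triangular causal mask, analogous to the uint8 buffer
# registered in the model.
causal_mask = torch.tril(torch.ones(4, 4, dtype=torch.uint8)).view(1, 1, 4, 4)
attn_weights = torch.randn(1, 1, 4, 4, dtype=torch.float16)
mask_value = torch.tensor(torch.finfo(torch.float16).min, dtype=torch.float16)

# Works: a bool/uint8 condition is what torch.where accepts.
masked = torch.where(causal_mask.bool(), attn_weights, mask_value)

# Fails: if from_pretrained(..., torch_dtype=torch.float16) has cast the buffer
# to fp16, the condition dtype is no longer valid.
try:
    torch.where(causal_mask.to(torch.float16), attn_weights, mask_value)
except RuntimeError as e:
    print(e)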

Can also confirm the slow tests pass!

cc @ydshieh

- Small hotfix for causal mask for half-precision models
- Explicitly cast the causal mask to uint8 for compatibility with `torch.where`
@younesbelkada younesbelkada requested a review from sgugger August 4, 2022 08:19
younesbelkada added a commit to younesbelkada/transformers that referenced this pull request Aug 4, 2022
@HuggingFaceDocBuilderDev commented Aug 4, 2022

The documentation is not available anymore as the PR was closed or merged.

# compute causal mask from causal mask buffer
query_length, key_length = query.size(-2), key.size(-2)
- causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length]
+ causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)
Collaborator

Let's have a comment here to explain why we need .to(torch.uint8) 🙏

Contributor Author

Added more comments on 8b81ac1
Let me know if anything is unclear!

@thomasw21 (Contributor) commented Aug 4, 2022

Stupid question, why is that not

Suggested change
- causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)
+ self.causal_mask = self.causal_mask[:, :, key_length - query_length : key_length, :key_length].to(torch.uint8)

feels like something you want to do only once, no?

It's a no-op when the tensor is already in the correct dtype.
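
As a quick sanity check (not part of the PR), Tensor.to returns the same tensor object when it already has the requested dtype, so the repeated cast does not allocate a new tensor:

import torch

mask = torch.tril(torch.ones(4, 4, dtype=torch.uint8))
# .to() with a matching dtype returns self; no copy is made.
assert mask.to(torch.uint8) is mask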

@sgugger (Collaborator) commented Aug 4, 2022

It won't be needed anymore as we found the root cause with Younes :-)

@sgugger (Collaborator) left a comment

LGTM, thanks for fixing!

@sgugger (Collaborator) left a comment

Let's close this one to focus on the right fix @younesbelkada :-)

@sgugger sgugger self-requested a review August 4, 2022 12:23
@younesbelkada (Contributor Author)

Yeah let's move the discussion to: #18471
