Add an int64 path for mlp kernels #3614
Conversation
unsloth/kernels/geglu.py
Outdated
    device = gate.device
    out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    if n_elements <= (2**31) - 1024:
Why -1024? Is it maybe hd?
Yes, I forgot to account for hd. The idea was to add a buffer just to be safe.
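A minimal sketch of what that reply suggests, with hd as the headroom instead of a fixed 1024 (the helper name `_fits_int32` is hypothetical, not from this PR):

```python
INT32_MAX = (2**31) - 1

def _fits_int32(n_elements: int, hd: int) -> bool:
    # Leave one row (hd elements) of headroom so every offset the kernel
    # forms near the end of the tensor still fits in a signed 32-bit int.
    return n_elements <= INT32_MAX - hd
```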
unsloth/kernels/geglu.py
Outdated
    batch_seq_len, hd = e.shape
    n_elements = e.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    if n_elements <= (2**31) - 1024:
Maybe move (2**31) to a global var
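One way that suggestion could look, as a sketch (the constant and function names are illustrative, not from the PR):

```python
# Module-level limit shared by geglu.py and swiglu.py instead of
# repeating (2**31) - 1024 at every dispatch site.
MAX_INT32_ELEMENTS = (2**31) - 1024

def use_int64_path(n_elements: int) -> bool:
    return n_elements > MAX_INT32_ELEMENTS
```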
unsloth/kernels/swiglu.py
Outdated
    e,
    g,
    n_elements,
    BLOCK_SIZE: tl.constexpr,
There is actually a way to use only one kernel and dispatch, but for now this is fine; we can refactor later.
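A rough sketch of that single-kernel idea, assuming a constexpr flag chosen at launch time; the kernel body and all names here are illustrative rather than the PR's actual code:

```python
import torch
import triton
import triton.language as tl

@triton.jit
def _gelu_mul_kernel(gate_ptr, up_ptr, out_ptr, n_elements,
                     BLOCK_SIZE: tl.constexpr, USE_INT64: tl.constexpr):
    pid = tl.program_id(0)
    if USE_INT64:
        # Widen the program id before the multiply so the block offset
        # cannot wrap past 2**31 - 1.
        pid = pid.to(tl.int64)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    gate = tl.load(gate_ptr + offsets, mask=mask, other=0.0).to(tl.float32)
    up   = tl.load(up_ptr   + offsets, mask=mask, other=0.0).to(tl.float32)
    # Exact GELU(gate) * up, as in a GEGLU forward pass.
    gelu = 0.5 * gate * (1.0 + tl.math.erf(gate * 0.7071067811865476))
    tl.store(out_ptr + offsets,
             (gelu * up).to(out_ptr.dtype.element_ty), mask=mask)

def geglu_forward(gate: torch.Tensor, up: torch.Tensor) -> torch.Tensor:
    out = torch.empty_like(up)
    n_elements = out.numel()
    grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
    _gelu_mul_kernel[grid](
        gate, up, out, n_elements,
        BLOCK_SIZE = 1024,
        USE_INT64  = n_elements > (2**31) - 1024,  # same threshold as the PR
    )
    return out
```

Since USE_INT64 is a constexpr, Triton compiles a separate specialization for each value, so the common int32 path pays no runtime cost while the math is written only once.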
Force-pushed from c008eca to 262ada3.
Force-pushed from 262ada3 to 833d91f.
The llama MLP kernels produce NaNs at extremely long context lengths. This happens when num_elements is greater than 2**31; in those cases the offsets must be calculated with tl.int64 instead of int32. This PR routes to the int64 kernels when num_elements is large enough.
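For intuition, a small self-contained illustration (plain Python, not the kernel code) of how a 32-bit offset wraps once the tensor passes 2**31 elements, which is what leads to the bad loads and NaNs described above:

```python
BLOCK_SIZE = 1024
n_elements = 3 * (2**30)                 # 3,221,225,472 elements, > 2**31
last_block = (n_elements - 1) // BLOCK_SIZE

off64 = last_block * BLOCK_SIZE          # correct 64-bit block offset
off32 = (off64 + 2**31) % 2**32 - 2**31  # what a wrapped int32 would hold

print(off64)   # 3221224448
print(off32)   # -1073742848 -> points outside the tensor, hence garbage/NaNs
```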