
Conversation

@mmathew23 (Collaborator)

The Llama MLP kernels produce NaNs at extremely long context lengths. This happens when num_elements exceeds 2**31, at which point int32 offset arithmetic overflows; in those cases offsets should be calculated with tl.int64 instead of tl.int32. This PR routes to int64 kernels when num_elements is large enough.
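The failure mode is easiest to see on a toy elementwise kernel. The sketch below uses hypothetical kernels, not the MLP kernels changed in this PR: in the int32 variant, pid * BLOCK_SIZE is computed in int32 and wraps once the flat index passes 2**31 - 1, while the int64 variant promotes the program id to tl.int64 before any offset arithmetic, so the pointer math stays in range.

```python
# Minimal sketch of the int32-overflow problem and the int64 fix.
# These toy kernels are illustrative only; they are not the PR's MLP kernels.
import triton
import triton.language as tl

@triton.jit
def _copy_kernel_int32(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)                              # int32 program id
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)    # int32 math: wraps past 2**31 - 1
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

@triton.jit
def _copy_kernel_int64(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0).to(tl.int64)                 # promote before multiplying
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)    # computed in int64, no overflow
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)
```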

```python
device = gate.device
out = torch.empty((batch, seq_len, hd), dtype = gate.dtype, device = device)
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
```
Contributor

Why -1024? Is it maybe hd?

Collaborator (author)

Yes, I forgot to account for hd. The idea was to add a buffer just to be safe.

```python
batch_seq_len, hd = e.shape
n_elements = e.numel()
grid = lambda meta: (triton.cdiv(n_elements, meta["BLOCK_SIZE"]),)
if n_elements <= (2**31) - 1024:
```
Contributor

Maybe move (2**31) to a global var
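Both suggestions could be folded into a small host-side helper. This is only a sketch with hypothetical names (MAX_INT32_ELEMENTS, _use_int64_offsets), not the PR's final code: the limit lives in one module-level constant, and the margin covering hd plus some slack is explicit instead of the bare 1024.

```python
# Hypothetical host-side dispatch helper; names and margins are illustrative,
# not taken from the PR.
MAX_INT32_ELEMENTS = 2**31 - 1  # largest flat index representable in int32

def _use_int64_offsets(n_elements: int, hd: int, slack: int = 1024) -> bool:
    # Route to the int64 kernel once the flattened element count gets close to
    # the int32 limit; the margin leaves room for hd plus a safety buffer.
    return n_elements > MAX_INT32_ELEMENTS - hd - slack
```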

```python
e,
g,
n_elements,
BLOCK_SIZE: tl.constexpr,
```
Contributor

There is actually a way to use one kernel only and dispatch, but for now this is fine; we can refactor later.
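One possible reading of the single-kernel suggestion (a sketch, not necessarily what the reviewer had in mind) is to keep one Triton kernel and select the offset dtype with a constexpr flag, so the int32 and int64 paths become compile-time specializations of the same source rather than two copied kernels.

```python
# Sketch of a single kernel with a constexpr dtype switch (illustrative only).
import triton
import triton.language as tl

@triton.jit
def _copy_kernel(x_ptr, out_ptr, n_elements, BLOCK_SIZE: tl.constexpr, USE_INT64: tl.constexpr):
    pid = tl.program_id(axis=0)
    if USE_INT64:
        pid = pid.to(tl.int64)  # constexpr branch: resolved at compile time
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < n_elements
    x = tl.load(x_ptr + offsets, mask=mask)
    tl.store(out_ptr + offsets, x, mask=mask)

# Host side: Triton compiles a separate specialization per USE_INT64 value, e.g.
# _copy_kernel[grid](x, out, n_elements, BLOCK_SIZE=1024, USE_INT64=n_elements > 2**31 - 1024)
```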

mmathew23 force-pushed the tiled/contextlen branch 2 times, most recently from c008eca to 262ada3 on November 19, 2025 at 17:24.