@timmy-feng
Added support for paged attention by doing the following:

  • Pre-allocate pages in the scheduler thread before calling run_batch. Since we do not know the fill status of the most recent page (it is still being written by the in-flight GPU batch), we allocate for the worst case, assuming every new token starts on a fresh page.
  • Alter the assign_draft_cache_locs kernel in the draft decode to prepend the remaining unused cache locs from the previous page. We do not have to free the excess here because the allocator state is restored after the draft.
  • Add a merge_cache_loc kernel to the verify step to prepend the remaining unused cache locs from the previous page. The excess pages are stored in an evict_cache_loc tensor, which is combined with the other pages evicted after tokens are accepted.
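The first and third bullets can be sketched in plain Python. This is a minimal illustration, not the PR's actual kernels: `worst_case_num_pages` and `merge_cache_locs` are hypothetical helper names, and the page size is an assumed constant (SGLang takes it from server configuration).

```python
import math

PAGE_SIZE = 4  # illustrative value; real page size comes from server config


def worst_case_num_pages(num_new_tokens: int, page_size: int = PAGE_SIZE) -> int:
    # The fill status of a request's most recent page is unknown while the
    # GPU batch is in flight, so assume the page is full: every new token
    # may land on a freshly allocated page.
    return math.ceil(num_new_tokens / page_size)


def merge_cache_locs(prev_page_unused: list[int],
                     new_locs: list[int],
                     num_new_tokens: int) -> tuple[list[int], list[int]]:
    # Prepend the leftover slots of the previous page, then keep only as
    # many slots as new tokens actually need; the remainder is the excess
    # that gets evicted (verify) or rolled back by allocator restore (draft).
    merged = prev_page_unused + new_locs
    return merged[:num_new_tokens], merged[num_new_tokens:]


# Example: 5 new tokens, 2 unused slots left on the previous page, and a
# worst-case allocation of 2 fresh pages (slots 8..15).
used, excess = merge_cache_locs([2, 3], [8, 9, 10, 11, 12, 13, 14, 15], 5)
```

With these numbers, the 5 tokens fill slots `[2, 3, 8, 9, 10]` and slots `[11, 12, 13, 14, 15]` are excess, mirroring how the over-allocated pages are later returned or evicted.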

TODO

Correctness has been verified for all attention backends except FA3.

With FA3, the code is correct for the draft decode and extend phases, but not for verify.
