Fix illegal memory access through off-by-one error in num_splits_dynamic_ptr init #1747
There is an off-by-one error in flash_api.cpp / set_params_fprop() which can lead to memory access violations in our codebase.
Error description
In set_params_fprop, if scheduler_needs_semaphore == False and use_dynamic_split == True, the tile_count_semaphore tensor is allocated here with a size equal to the batch size, which is sufficient in principle:
flash-attention/hopper/flash_api.cpp
Line 959 in adf27d1
But on line 975, a bit further below, an offset of 1 is applied to the raw data pointer of the tile_count_semaphore tensor to initialize num_splits_dynamic_ptr, even if scheduler_needs_semaphore == False.
If num_splits_dynamic_ptr is then accessed at its supposedly last valid element, at index batch size - 1, the access lands one element past the end of the allocation and an illegal memory access occurs. Since it is only an off-by-one error, it is rarely detectable, but it led to (rare) crashes and numerical issues in our CI. We could detect it by running some of our tests with "compute-sanitizer --padding 128 ... " while setting PYTORCH_NO_CUDA_MEMORY_CACHING=1 to disable PyTorch's caching allocator (without that, the out-of-bounds access usually still hit memory that belonged to a valid allocation).
flash-attention/hopper/flash_api.cpp
Line 975 in adf27d1
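The index arithmetic behind the error can be sketched in isolation. The following is a minimal host-side model, not the code from flash_api.cpp: the struct and function names (`MetadataLayout`, `described_layout`, `guarded_layout`, `covers_all_batches`) are hypothetical, and the guarded variant is only one possible way to make the offset consistent with the allocation, not necessarily the change made in this PR.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical model of the metadata buffer: how many int32 elements are
// allocated, and at which element offset the per-batch split counts start.
struct MetadataLayout {
    std::size_t size;
    std::size_t splits_offset;
};

// Layout as described above: one slot for the semaphore only if it is needed,
// plus one slot per batch entry, but the split counts always start at offset 1.
inline MetadataLayout described_layout(bool needs_semaphore, bool dynamic_split, std::size_t batch) {
    std::size_t size = (needs_semaphore ? 1 : 0) + (dynamic_split ? batch : 0);
    std::size_t offset = dynamic_split ? 1 : 0;
    return MetadataLayout{size, offset};
}

// Assumed alternative: skip the semaphore slot only when it actually exists.
inline MetadataLayout guarded_layout(bool needs_semaphore, bool dynamic_split, std::size_t batch) {
    std::size_t size = (needs_semaphore ? 1 : 0) + (dynamic_split ? batch : 0);
    std::size_t offset = (dynamic_split && needs_semaphore) ? 1 : 0;
    return MetadataLayout{size, offset};
}

// True if indices 0 .. batch-1 behind the offset all stay inside the buffer.
inline bool covers_all_batches(const MetadataLayout& layout, std::size_t batch) {
    return layout.splits_offset + batch <= layout.size;
}

// Small self-check of the four relevant cases for a batch size of 4.
inline void check_layouts() {
    const std::size_t batch = 4;
    // No semaphore slot, but offset 1: element batch-1 is past the end.
    assert(!covers_all_batches(described_layout(false, true, batch), batch));
    // With a semaphore slot the offset of 1 is correct.
    assert(covers_all_batches(described_layout(true, true, batch), batch));
    // The guarded offset stays in bounds in both cases.
    assert(covers_all_batches(guarded_layout(false, true, batch), batch));
    assert(covers_all_batches(guarded_layout(true, true, batch), batch));
}
```

Calling `check_layouts()` from a test exercises all four cases. In the failing case the buffer holds exactly batch elements while the split counts start at element 1, so the last entry falls one element outside the allocation, which matches the access that compute-sanitizer flagged.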