3 changes: 3 additions & 0 deletions csrc/src/flash.h
@@ -136,6 +136,9 @@ struct Flash_fwd_params : public QKV_params, public Mask_params, public Bias_par
 
     bool unpadded_lse;  // For varlen paths: LSE is in [nheads, total_seqlen_q] format instead of [b, nheads, seqlen_q].
     bool seqlenq_ngroups_swapped;  // q has been transposed from (b, 1, (nheads_kv ngroups), d) to (b, ngroups, nheads_kv, d).
+
+    // Unified sparse mask for advanced masking strategies
+    void * __restrict__ sparse_mask_ptr;  // Pointer to UnifiedSparseMask object
 };
 
 ////////////////////////////////////////////////////////////////////////////////////////////////////
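Note: this hunk only adds the field; the diff does not show where sparse_mask_ptr is populated on the host. A minimal sketch of one way to wire it up, assuming UnifiedSparseMask is a trivially copyable descriptor (the helper name and the cudaMalloc/cudaMemcpy plumbing are illustrative, not part of this PR; error handling omitted):

    #include <cuda_runtime.h>

    // Hypothetical host-side helper: copy a host-built descriptor to device
    // memory and hand the kernel a raw pointer through the params struct.
    void set_sparse_mask(Flash_fwd_params &params, const UnifiedSparseMask &host_mask) {
        UnifiedSparseMask *device_mask = nullptr;
        cudaMalloc(&device_mask, sizeof(UnifiedSparseMask));
        cudaMemcpy(device_mask, &host_mask, sizeof(UnifiedSparseMask), cudaMemcpyHostToDevice);
        params.sparse_mask_ptr = device_mask;  // kernels recover the type via static_cast
    }

Keeping the field a raw void* spares flash.h a dependency on unified_sparse_mask.h; the kernel hunks below recover the concrete type with static_cast.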
2 changes: 2 additions & 0 deletions csrc/src/flash_bwd_kernel.h
@@ -16,6 +16,8 @@
 #include "utils.h"
 #include "softmax.h"
 #include "mask.h"
+#include "unified_sparse_mask.h"
+#include "mask_factory.h"
 
 namespace FLASH_NAMESPACE {
 
37 changes: 31 additions & 6 deletions csrc/src/flash_fwd_kernel.h
@@ -17,6 +17,8 @@
 #include "utils.h"
 #include "softmax.h"
 #include "mask.h"
+#include "unified_sparse_mask.h"
+#include "mask_factory.h"
 
 namespace FLASH_NAMESPACE {
 
@@ -394,9 +396,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
     clear(acc_o);
 
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
-    // Init dynamic mask processor
+    // Init dynamic mask processor with optional unified sparse mask (first instance)
+    const UnifiedSparseMask* sparse_mask_ptr = nullptr;
+    // Check if a unified sparse mask is provided in params
+    if (params.sparse_mask_ptr != nullptr) {
+        sparse_mask_ptr = static_cast<const UnifiedSparseMask*>(params.sparse_mask_ptr);
+    }
+
     FLASH_NAMESPACE::Mask<Is_causal> mask(
-        binfo.actual_seqlen_k, binfo.actual_seqlen_q
+        binfo.actual_seqlen_k, binfo.actual_seqlen_q, sparse_mask_ptr
     );
 
@@ -459,11 +468,20 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
         Tensor tSrBias_copy_view = smem_thr_copy_Bias.retile_D(tSrBias);
         cute::copy(smem_tiled_copy_Bias, tSsBias, tSrBias_copy_view);
 
-        // Scale attention scores and apply mask/bias
-        mask.template apply_mask<Is_causal, Is_even_MN>(
+        // Scale attention scores and apply mask/bias with unified sparse mask block-level skipping
+        bool block_has_activity = mask.template apply_mask_with_skip_check<Is_causal, Is_even_MN>(
             acc_s, tSrMask, tSrBias, params.scale_softmax,
-            n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16
+            n_block * kBlockN, m_block * kBlockM + (tidx / 32) * 16 + (tidx % 32) / 4, kNWarps * 16,
+            m_block, n_block, kBlockM, kBlockN
         );
 
+        // If the unified sparse mask reports no activity, skip softmax/output work
+        // for this block. The predicate depends only on (m_block, n_block), so all
+        // threads in the CTA take the same branch.
+        if (!block_has_activity) {
+            clear(acc_s);  // zero the score tile so no stale values survive
+            // Still drain pending async copies and hit the barrier so the smem
+            // pipeline stays consistent with the non-skipped path.
+            FLASH_NAMESPACE::cp_async_wait<0>();
+            __syncthreads();
+            continue;
+        }
 
         FLASH_NAMESPACE::cp_async_wait<0>();
         __syncthreads();
@@ -1045,8 +1063,15 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     FLASH_NAMESPACE::Softmax<2 * size<1>(acc_o)> softmax;
 
-    // Init dynamic mask processor
+    // Init dynamic mask processor with optional unified sparse mask (second instance)
+    const UnifiedSparseMask* sparse_mask_ptr2 = nullptr;
+    // Check if a unified sparse mask is provided in params
+    if (params.sparse_mask_ptr != nullptr) {
+        sparse_mask_ptr2 = static_cast<const UnifiedSparseMask*>(params.sparse_mask_ptr);
+    }
+
     FLASH_NAMESPACE::Mask<Is_causal> mask(
-        binfo.actual_seqlen_k, binfo.actual_seqlen_q
+        binfo.actual_seqlen_k, binfo.actual_seqlen_q, sparse_mask_ptr2
     );
 
     // For performance reason, we separate out two kinds of iterations:
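The early-continue above only pays off if is_block_active is cheap. For parametric masks the answer follows from block indices alone, with no tile traversal. A self-contained illustration of that predicate for the causal case (standalone example, not code from this PR):

    #include <cstdio>

    // A (q_block, k_block) tile survives causal masking iff some entry has
    // key index <= query index, i.e. the tile's smallest column index is
    // <= its largest row index.
    bool causal_block_active(int q_block, int k_block, int block_m, int block_n) {
        int q_hi = (q_block + 1) * block_m - 1;  // largest row in the tile
        int k_lo = k_block * block_n;            // smallest column in the tile
        return k_lo <= q_hi;
    }

    int main() {
        printf("%d\n", causal_block_active(0, 1, 128, 128));  // 0: tile lies above the diagonal
        printf("%d\n", causal_block_active(2, 1, 128, 128));  // 1: tile touches/crosses the diagonal
    }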
1,668 changes: 1,668 additions & 0 deletions csrc/src/flash_fwd_kernel.h.bak

Large diffs are not rendered by default.

58 changes: 55 additions & 3 deletions csrc/src/mask.h
@@ -4,6 +4,7 @@
 
 #pragma once
 #include "namespace_config.h"
+#include "unified_sparse_mask.h"
 
 #include <cute/tensor.hpp>
 
@@ -57,15 +58,66 @@ __forceinline__ __device__ void apply_mask(
 template <bool Is_causal>
 struct Mask {
     const int max_seqlen_k, max_seqlen_q;
+    const UnifiedSparseMask* sparse_mask;  // Optional unified sparse mask
 
     __forceinline__ __device__ Mask(
         const int max_seqlen_k,
-        const int max_seqlen_q
+        const int max_seqlen_q,
+        const UnifiedSparseMask* sparse_mask_ptr = nullptr
     ) // Constructor
         : max_seqlen_k(max_seqlen_k)
-        , max_seqlen_q(max_seqlen_q) {
+        , max_seqlen_q(max_seqlen_q)
+        , sparse_mask(sparse_mask_ptr) {
     };
 
+    // New unified mask application with block-level skipping. Returns true when
+    // the (query_block_idx, key_block_idx) tile has any active entries.
+    // The mask-tile template parameter is named MaskTensorType so it does not
+    // shadow the MaskType enum from unified_sparse_mask.h used in Step 3 below.
+    template <bool Causal_mask=false, bool Is_even_MN=true, typename TensorType, typename MaskTensorType, typename BiasType>
+    __forceinline__ __device__ bool apply_mask_with_skip_check(
+        TensorType &tensor_,               // acc_s (attention scores, MMA=4, MMA_M, MMA_N)
+        MaskTensorType &tSrMask,           // Attention mask tile (MMA=4, MMA_M, MMA_N)
+        BiasType &tSrBias,                 // Attention bias tile (MMA=4, MMA_M, MMA_N)
+        const float scale_softmax,         // Scale for softmax
+        const int col_idx_offset_,         // Column index offset
+        const int row_idx_offset,          // Row index offset
+        const int warp_row_stride,         // Warp row stride
+        const int query_block_idx,         // Query block index for the sparse mask
+        const int key_block_idx,           // Key block index for the sparse mask
+        const int block_size_m = 128,      // Block size M
+        const int block_size_n = 128       // Block size N
+    ) {
+        static_assert(TensorType::rank == 3, "tensor_ must be a 3D tensor");
+        static_assert(MaskTensorType::rank == 3, "Mask must be a 3D tensor");
+        static_assert(BiasType::rank == 3, "Bias must be a 3D tensor");
+        static_assert(decltype(size<0>(tensor_))::value == 4, "First dimension must be 4");
+
+        // Step 1: Check whether the unified sparse mask lets us skip this block entirely
+        bool any_active = true;
+        if (sparse_mask != nullptr) {
+            any_active = sparse_mask->is_block_active(query_block_idx, key_block_idx);
+            if (!any_active) {
+                // Block is completely masked - skip all computation
+                return false;
+            }
+        }
+
+        // Step 2: Apply the traditional mask logic for active blocks
+        apply_mask<Causal_mask, Is_even_MN>(
+            tensor_, tSrMask, tSrBias, scale_softmax,
+            col_idx_offset_, row_idx_offset, warp_row_stride
+        );
+
+        // Step 3: If we have a unified sparse mask, perform a more detailed activity check.
+        // For non-parametric masks, do an OR reduction over the actual mask tile.
+        if (sparse_mask != nullptr) {
+            MaskType mask_type = sparse_mask->get_mask_type();
+            if (mask_type != MaskType::PARAMETRIC_CAUSAL && mask_type != MaskType::PARAMETRIC_WINDOW) {
+                any_active = sparse_mask->compute_block_activity_fast(tSrMask, query_block_idx, key_block_idx);
+            }
+        }
+
+        return any_active;
+    }
+
     template <bool Causal_mask=false, bool Is_even_MN=true, typename TensorType, typename MaskType, typename BiasType>
     __forceinline__ __device__ void apply_mask(
         TensorType &tensor_,               // acc_s (attention scores, MMA=4, MMA_M, MMA_N)
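unified_sparse_mask.h itself is not in this diff, so the contract apply_mask_with_skip_check relies on (get_mask_type, is_block_active, compute_block_activity_fast, and the MaskType enum) has to be inferred. One plausible shape, sketched under the assumption that activity for non-parametric masks is stored as a bitmap with one bit per (q_block, k_block) tile; the field names and bitmap layout are guesses, and only the member names match the call sites above (compile with nvcc):

    #include <cstdint>

    // Assumed enum; only PARAMETRIC_CAUSAL and PARAMETRIC_WINDOW appear in the diff.
    enum class MaskType { PARAMETRIC_CAUSAL, PARAMETRIC_WINDOW, BLOCK_BITMAP };

    struct UnifiedSparseMask {
        MaskType type;
        const uint32_t* block_bitmap;  // hypothetical: 1 bit per (q_block, k_block) tile, row-major
        int num_k_blocks;

        __host__ __device__ MaskType get_mask_type() const { return type; }

        __host__ __device__ bool is_block_active(int q_block, int k_block) const {
            if (type == MaskType::PARAMETRIC_CAUSAL || type == MaskType::PARAMETRIC_WINDOW) {
                // Parametric masks can be answered analytically from block indices
                // (see the causal predicate sketched earlier); conservatively never skip here.
                return true;
            }
            const int idx = q_block * num_k_blocks + k_block;
            return (block_bitmap[idx / 32] >> (idx % 32)) & 1u;
        }
    };

compute_block_activity_fast, which ORs over the in-register mask tile, is omitted here because its signature depends on the cute tensor types passed in by the kernel.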