|
| 1 | +import torch |
| 2 | +import triton |
| 3 | +import triton.language as tl |
| 4 | + |
| 5 | + |
| 6 | +@triton.jit |
| 7 | +def _triton_qwen2vl_mrope( |
| 8 | + q_ptr, |
| 9 | + k_ptr, |
| 10 | + cos, |
| 11 | + sin, |
| 12 | + sl, |
| 13 | + n_qh: tl.constexpr, |
| 14 | + n_kh: tl.constexpr, |
| 15 | + hd: tl.constexpr, |
| 16 | + pad_n_qh: tl.constexpr, |
| 17 | + pad_n_kh: tl.constexpr, |
| 18 | + pad_hd: tl.constexpr, |
| 19 | + mrope_section_t: tl.constexpr, |
| 20 | + mrope_section_h: tl.constexpr, |
| 21 | + BLOCK_SIZE: tl.constexpr, |
| 22 | + BACKWARD_PASS: tl.constexpr = False, |
| 23 | +): |
| 24 | + pid = tl.program_id(0) |
| 25 | + |
| 26 | + # locate start address |
| 27 | + q_ptr = q_ptr + pid * (n_qh * hd) |
| 28 | + k_ptr = k_ptr + pid * (n_kh * hd) |
| 29 | + |
| 30 | + # #################################################################### |
| 31 | + # get the cos(mθ_{i...d/2}) and sin(mθ_{i...d/2}) for token position |
| 32 | + # m of this program instance |
| 33 | + # #################################################################### |
| 34 | + |
| 35 | + # 1. program instances are laid out in a 1D vector of size bsz * seq_len, which |
| 36 | + # effectively represents a 2D grid of size [bsz, seq_len] with seq_len dimension |
| 37 | + # being the fastest changing dimension. Thus we can simply do pid // sl to get the batch index |
| 38 | + # and pid % sl to get the sequence index. |
| 39 | + # 2. We only need the left half of cos and sin matrix because the right half is just |
| 40 | + # a clone of the left half. |
| 41 | + t_end = mrope_section_t |
| 42 | + h_end = t_end + mrope_section_h |
| 43 | + |
| 44 | + cos_row_idx = pid % sl |
| 45 | + t_cos = cos + cos_row_idx * hd |
| 46 | + h_cos = t_cos + sl * hd |
| 47 | + w_cos = h_cos + sl * hd |
| 48 | + t_sin = sin + cos_row_idx * hd |
| 49 | + h_sin = t_sin + sl * hd |
| 50 | + w_sin = h_sin + sl * hd |
| 51 | + |
| 52 | + cos_offsets = tl.arange(0, pad_hd // 2) |
| 53 | + t_mask = cos_offsets < t_end |
| 54 | + h_mask = (t_end <= cos_offsets) & (cos_offsets < h_end) |
| 55 | + w_mask = (h_end <= cos_offsets) & (cos_offsets < hd // 2) |
| 56 | + t_cos_row = tl.load(t_cos + cos_offsets, mask=t_mask, other=0) |
| 57 | + h_cos_row = tl.load(h_cos + cos_offsets, mask=h_mask, other=0) |
| 58 | + w_cos_row = tl.load(w_cos + cos_offsets, mask=w_mask, other=0) |
| 59 | + t_sin_row = tl.load(t_sin + cos_offsets, mask=t_mask, other=0) |
| 60 | + h_sin_row = tl.load(h_sin + cos_offsets, mask=h_mask, other=0) |
| 61 | + w_sin_row = tl.load(w_sin + cos_offsets, mask=w_mask, other=0) |
| 62 | + cos_row = t_cos_row + h_cos_row + w_cos_row |
| 63 | + sin_row = t_sin_row + h_sin_row + w_sin_row |
| 64 | + |
| 65 | + # #################################################################### |
| 66 | + # Load the left and right half of q and k for the current |
| 67 | + # program instance (i.e. for the current token) separately |
| 68 | + # #################################################################### |
| 69 | + # left half of the head |
| 70 | + first_half_q_offsets = ( |
| 71 | + tl.arange(0, pad_n_qh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] |
| 72 | + ) |
| 73 | + first_half_k_offsets = ( |
| 74 | + tl.arange(0, pad_n_kh)[:, None] * hd + tl.arange(0, pad_hd // 2)[None, :] |
| 75 | + ) |
| 76 | + first_q_mask = (tl.arange(0, pad_n_qh)[:, None] < n_qh) & ( |
| 77 | + tl.arange(0, pad_hd // 2)[None, :] < hd // 2 |
| 78 | + ) |
| 79 | + first_k_mask = (tl.arange(0, pad_n_kh)[:, None] < n_kh) & ( |
| 80 | + tl.arange(0, pad_hd // 2)[None, :] < hd // 2 |
| 81 | + ) |
| 82 | + q_tile_1 = tl.load(q_ptr + first_half_q_offsets, mask=first_q_mask, other=0).to( |
| 83 | + sin_row.dtype |
| 84 | + ) |
| 85 | + k_tile_1 = tl.load(k_ptr + first_half_k_offsets, mask=first_k_mask, other=0).to( |
| 86 | + sin_row.dtype |
| 87 | + ) |
| 88 | + |
| 89 | + # right half of the head |
| 90 | + second_half_q_offsets = first_half_q_offsets + (hd // 2) |
| 91 | + second_half_k_offsets = first_half_k_offsets + (hd // 2) |
| 92 | + second_q_mask = first_q_mask |
| 93 | + second_k_mask = first_k_mask |
| 94 | + q_tile_2 = tl.load(q_ptr + second_half_q_offsets, mask=second_q_mask, other=0).to( |
| 95 | + sin_row.dtype |
| 96 | + ) |
| 97 | + k_tile_2 = tl.load(k_ptr + second_half_k_offsets, mask=second_k_mask, other=0).to( |
| 98 | + sin_row.dtype |
| 99 | + ) |
| 100 | + |
| 101 | + if not BACKWARD_PASS: |
| 102 | + # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin] |
| 103 | + new_q_tile_1 = q_tile_1 * cos_row - q_tile_2 * sin_row |
| 104 | + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) |
| 105 | + new_q_tile_2 = q_tile_2 * cos_row + q_tile_1 * sin_row |
| 106 | + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) |
| 107 | + |
| 108 | + new_k_tile_1 = k_tile_1 * cos_row - k_tile_2 * sin_row |
| 109 | + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) |
| 110 | + new_k_tile_2 = k_tile_2 * cos_row + k_tile_1 * sin_row |
| 111 | + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) |
| 112 | + else: |
| 113 | + # with some math, we can get: |
| 114 | + # dy = [dx1, dx2] * [cos, cos] + [-dx2, dx1] * [-sin, -sin] |
| 115 | + new_q_tile_1 = q_tile_1 * cos_row + q_tile_2 * sin_row |
| 116 | + tl.store(q_ptr + first_half_q_offsets, new_q_tile_1, mask=first_q_mask) |
| 117 | + new_q_tile_2 = q_tile_2 * cos_row - q_tile_1 * sin_row |
| 118 | + tl.store(q_ptr + second_half_q_offsets, new_q_tile_2, mask=second_q_mask) |
| 119 | + |
| 120 | + new_k_tile_1 = k_tile_1 * cos_row + k_tile_2 * sin_row |
| 121 | + tl.store(k_ptr + first_half_k_offsets, new_k_tile_1, mask=first_k_mask) |
| 122 | + new_k_tile_2 = k_tile_2 * cos_row - k_tile_1 * sin_row |
| 123 | + tl.store(k_ptr + second_half_k_offsets, new_k_tile_2, mask=second_k_mask) |
| 124 | + |
| 125 | + |
| 126 | +def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section): |
| 127 | + |
| 128 | + # transpose it back to the physical shape because Triton looks at the physical storage |
| 129 | + # note: q and k are incontiguous before the transformation and will become contiguous after transpose |
| 130 | + q = q.transpose(1, 2) |
| 131 | + k = k.transpose(1, 2) |
| 132 | + |
| 133 | + batch_size, seq_len, n_q_head, head_dim = q.shape |
| 134 | + n_kv_head = k.shape[2] |
| 135 | + pad_hd = triton.next_power_of_2(head_dim) |
| 136 | + pad_n_q_head = triton.next_power_of_2(n_q_head) |
| 137 | + pad_n_kv_head = triton.next_power_of_2(n_kv_head) |
| 138 | + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) |
| 139 | + |
| 140 | + n_row = batch_size * seq_len |
| 141 | + |
| 142 | + # ensure tensors passed into the kernel are contiguous. It will be no-op if they are already contiguous |
| 143 | + q = q.contiguous() |
| 144 | + k = k.contiguous() |
| 145 | + cos = cos.contiguous() |
| 146 | + sin = sin.contiguous() |
| 147 | + |
| 148 | + _triton_qwen2vl_mrope[(n_row,)]( |
| 149 | + q, |
| 150 | + k, |
| 151 | + cos, |
| 152 | + sin, |
| 153 | + seq_len, |
| 154 | + n_q_head, |
| 155 | + n_kv_head, |
| 156 | + head_dim, |
| 157 | + pad_n_q_head, |
| 158 | + pad_n_kv_head, |
| 159 | + pad_hd, |
| 160 | + mrope_section[0], |
| 161 | + mrope_section[1], |
| 162 | + BLOCK_SIZE=BLOCK_SIZE, |
| 163 | + BACKWARD_PASS=False, |
| 164 | + ) |
| 165 | + return q.transpose(1, 2), k.transpose(1, 2), cos, sin |
| 166 | + |
| 167 | + |
| 168 | +def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section): |
| 169 | + dq = dq.transpose(1, 2) |
| 170 | + dk = dk.transpose(1, 2) |
| 171 | + |
| 172 | + batch_size, seq_len, n_q_head, head_dim = dq.shape |
| 173 | + n_kv_head = dk.shape[2] |
| 174 | + pad_hd = triton.next_power_of_2(head_dim) |
| 175 | + pad_n_q_head = triton.next_power_of_2(n_q_head) |
| 176 | + pad_n_kv_head = triton.next_power_of_2(n_kv_head) |
| 177 | + BLOCK_SIZE = max(pad_n_q_head, pad_n_kv_head) |
| 178 | + |
| 179 | + n_row = batch_size * seq_len |
| 180 | + |
| 181 | + # ensure dq and dk are contiguous |
| 182 | + dq = dq.contiguous() |
| 183 | + dk = dk.contiguous() |
| 184 | + |
| 185 | + # backward is similar to forward except swapping few ops |
| 186 | + _triton_qwen2vl_mrope[(n_row,)]( |
| 187 | + dq, |
| 188 | + dk, |
| 189 | + cos, |
| 190 | + sin, |
| 191 | + seq_len, |
| 192 | + n_q_head, |
| 193 | + n_kv_head, |
| 194 | + head_dim, |
| 195 | + pad_n_q_head, |
| 196 | + pad_n_kv_head, |
| 197 | + pad_hd, |
| 198 | + mrope_section[0], |
| 199 | + mrope_section[1], |
| 200 | + BLOCK_SIZE=BLOCK_SIZE, |
| 201 | + BACKWARD_PASS=True, |
| 202 | + ) |
| 203 | + return dq.transpose(1, 2), dk.transpose(1, 2) |
| 204 | + |
| 205 | + |
| 206 | +class LigerQwen2VLMRopeFunction(torch.autograd.Function): |
| 207 | + """ |
| 208 | + Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation. |
| 209 | +
|
| 210 | + Please find the corresponding HuggingFace implementation here: |
| 211 | + https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py |
| 212 | + """ |
| 213 | + |
| 214 | + @staticmethod |
| 215 | + def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1): |
| 216 | + """ |
| 217 | + q size: (bsz, n_q_head, seq_len, head_dim) |
| 218 | + k size: (bsz, n_kv_head, seq_len, head_dim) |
| 219 | + cos size: (3, 1, seq_len, head_dim) |
| 220 | + sin size: (3, 1, seq_len, head_dim) |
| 221 | + """ |
| 222 | + q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section) |
| 223 | + ctx.save_for_backward(cos, sin) |
| 224 | + ctx.mrope_section = mrope_section |
| 225 | + return q, k |
| 226 | + |
| 227 | + def backward(ctx, dq, dk): |
| 228 | + """ |
| 229 | + dq size: (bsz, n_q_head, seq_len, head_dim) |
| 230 | + dk size: (bsz, n_kv_head, seq_len, head_dim) |
| 231 | + cos size: (3, 1, seq_len, head_dim) |
| 232 | + sin size: (3, 1, seq_len, head_dim) |
| 233 | + """ |
| 234 | + |
| 235 | + cos, sin = ctx.saved_tensors |
| 236 | + mrope_section = ctx.mrope_section |
| 237 | + dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section) |
| 238 | + return dq, dk, None, None, None, None |
0 commit comments