Skip to content

Commit b903e7e

Browse files
committed
Support Qwen2-VL's multimodal RoPE implementation
1 parent 1aa3d83 commit b903e7e

File tree

5 files changed

+413
-2
lines changed

5 files changed

+413
-2
lines changed

src/liger_kernel/ops/qwen2vl_mrope.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import torch
2+
import triton
3+
import triton.language as tl
4+
5+
6+
@triton.jit
def _triton_qwen2vl_mrope(
    q_ptr,
    k_ptr,
    cos,
    sin,
    sl,
    n_qh: tl.constexpr,
    n_kh: tl.constexpr,
    hd: tl.constexpr,
    pad_n_qh: tl.constexpr,
    pad_n_kh: tl.constexpr,
    pad_hd: tl.constexpr,
    mrope_section_t: tl.constexpr,
    mrope_section_h: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
    BACKWARD_PASS: tl.constexpr = False,
):
    """
    Rotate all q and k heads of one token in place with multimodal RoPE.

    One program instance handles one token: it reads ``n_qh`` query heads and
    ``n_kh`` key heads of ``hd`` elements each, starting ``pid * n_qh * hd`` /
    ``pid * n_kh * hd`` elements after ``q_ptr`` / ``k_ptr``, and writes the
    rotated values back to the same addresses.

    ``cos`` and ``sin`` are addressed as three back-to-back sections
    (temporal, height, width) of ``sl`` rows with ``hd`` elements per row; the
    row used by a program is ``pid % sl``. Only the first ``hd // 2`` entries
    of a row are read. Channels ``[0, mrope_section_t)`` come from the temporal
    section, the next ``mrope_section_h`` channels from the height section and
    the remaining channels up to ``hd // 2`` from the width section.

    ``pad_n_qh``, ``pad_n_kh`` and ``pad_hd`` are the tile extents handed to
    ``tl.arange``; the masks built from ``n_qh``, ``n_kh`` and ``hd`` disable
    the padded lanes. ``BLOCK_SIZE`` is accepted but not referenced in the
    body. ``BACKWARD_PASS`` selects the rotation with the sign of ``sin``
    flipped, which is what the gradient needs.

    NOTE(review): the pointer arithmetic assumes densely packed (contiguous)
    q, k, cos and sin buffers -- confirm at the call sites.
    """
    token = tl.program_id(0)

    # Step the base pointers to the heads that belong to this token.
    q_base = q_ptr + token * (n_qh * hd)
    k_base = k_ptr + token * (n_kh * hd)

    # ------------------------------------------------------------------
    # Assemble the cos / sin row for this token from the three sections.
    # Only the left half of each row is needed: per the original author's
    # note, the right half of the cos/sin matrix is a clone of the left.
    # ------------------------------------------------------------------
    half_idx = tl.arange(0, pad_hd // 2)

    t_stop = mrope_section_t
    h_stop = t_stop + mrope_section_h
    in_t = half_idx < t_stop
    in_h = (half_idx >= t_stop) & (half_idx < h_stop)
    in_w = (half_idx >= h_stop) & (half_idx < hd // 2)

    # Program ids enumerate rows with the sl dimension changing fastest, so
    # pid % sl picks the cos/sin row inside a section.
    row_start = (token % sl) * hd
    section_stride = sl * hd

    t_cos_ptr = cos + row_start
    h_cos_ptr = t_cos_ptr + section_stride
    w_cos_ptr = h_cos_ptr + section_stride
    t_sin_ptr = sin + row_start
    h_sin_ptr = t_sin_ptr + section_stride
    w_sin_ptr = h_sin_ptr + section_stride

    # The three masks are disjoint and masked-off lanes load 0, so adding the
    # partial rows stitches the sections together.
    cos_t_part = tl.load(t_cos_ptr + half_idx, mask=in_t, other=0)
    cos_h_part = tl.load(h_cos_ptr + half_idx, mask=in_h, other=0)
    cos_w_part = tl.load(w_cos_ptr + half_idx, mask=in_w, other=0)
    sin_t_part = tl.load(t_sin_ptr + half_idx, mask=in_t, other=0)
    sin_h_part = tl.load(h_sin_ptr + half_idx, mask=in_h, other=0)
    sin_w_part = tl.load(w_sin_ptr + half_idx, mask=in_w, other=0)
    cos_row = cos_t_part + cos_h_part + cos_w_part
    sin_row = sin_t_part + sin_h_part + sin_w_part

    # ------------------------------------------------------------------
    # Load the left and right halves of every q / k head of this token.
    # ------------------------------------------------------------------
    q_heads = tl.arange(0, pad_n_qh)[:, None]
    k_heads = tl.arange(0, pad_n_kh)[:, None]
    cols = half_idx[None, :]
    col_ok = cols < hd // 2

    q_lo = q_heads * hd + cols
    k_lo = k_heads * hd + cols
    q_hi = q_lo + (hd // 2)
    k_hi = k_lo + (hd // 2)

    # The same mask serves both halves: it only filters padded heads/columns.
    q_ok = (q_heads < n_qh) & col_ok
    k_ok = (k_heads < n_kh) & col_ok

    q1 = tl.load(q_base + q_lo, mask=q_ok, other=0).to(sin_row.dtype)
    k1 = tl.load(k_base + k_lo, mask=k_ok, other=0).to(sin_row.dtype)
    q2 = tl.load(q_base + q_hi, mask=q_ok, other=0).to(sin_row.dtype)
    k2 = tl.load(k_base + k_hi, mask=k_ok, other=0).to(sin_row.dtype)

    if not BACKWARD_PASS:
        # y = [x1, x2] * [cos, cos] + [-x2, x1] * [sin, sin]
        q1_out = q1 * cos_row - q2 * sin_row
        tl.store(q_base + q_lo, q1_out, mask=q_ok)
        q2_out = q2 * cos_row + q1 * sin_row
        tl.store(q_base + q_hi, q2_out, mask=q_ok)

        k1_out = k1 * cos_row - k2 * sin_row
        tl.store(k_base + k_lo, k1_out, mask=k_ok)
        k2_out = k2 * cos_row + k1 * sin_row
        tl.store(k_base + k_hi, k2_out, mask=k_ok)
    else:
        # Gradient of the rotation, i.e. the same formula with sin negated:
        # dx = [dy1, dy2] * [cos, cos] + [-dy2, dy1] * [-sin, -sin]
        q1_out = q1 * cos_row + q2 * sin_row
        tl.store(q_base + q_lo, q1_out, mask=q_ok)
        q2_out = q2 * cos_row - q1 * sin_row
        tl.store(q_base + q_hi, q2_out, mask=q_ok)

        k1_out = k1 * cos_row + k2 * sin_row
        tl.store(k_base + k_lo, k1_out, mask=k_ok)
        k2_out = k2 * cos_row - k1 * sin_row
        tl.store(k_base + k_hi, k2_out, mask=k_ok)
124+
125+
126+
def _launch_qwen2vl_mrope(q, k, cos, sin, mrope_section, backward_pass):
    """Validate the cos/sin layout and launch the M-RoPE kernel in place on q and k.

    Args:
        q: contiguous tensor of shape (bsz, seq_len, n_q_head, head_dim).
        k: contiguous tensor of shape (bsz, seq_len, n_kv_head, head_dim).
        cos, sin: contiguous tensors of shape (3, 1, seq_len, head_dim),
            (3, bsz, seq_len, head_dim) or (3, seq_len, head_dim).
        mrope_section: sequence whose first two entries are the number of
            channels taken from the temporal and height sections.
        backward_pass: forwarded to the kernel as ``BACKWARD_PASS``.

    Raises:
        ValueError: if cos/sin do not have one of the supported layouts.
    """
    batch_size, seq_len, n_q_head, head_dim = q.shape
    n_kv_head = k.shape[2]

    # Validate before touching the kernel: it does raw pointer arithmetic, so
    # a wrong layout would silently read the wrong cos/sin values.
    if cos.shape != sin.shape:
        raise ValueError(
            f"cos and sin must have the same shape, got {tuple(cos.shape)} "
            f"and {tuple(sin.shape)}"
        )
    cos_batch = cos.shape[1] if cos.dim() == 4 else 1
    layout_ok = (
        cos.dim() in (3, 4)
        and cos.shape[0] == 3
        and tuple(cos.shape[-2:]) == (seq_len, head_dim)
        and cos_batch in (1, batch_size)
    )
    if not layout_ok:
        raise ValueError(
            "cos/sin must have shape (3, 1 or bsz, seq_len, head_dim) with "
            f"bsz={batch_size}, seq_len={seq_len}, head_dim={head_dim}; "
            f"got {tuple(cos.shape)}"
        )

    # The kernel uses its `sl` argument only to address cos/sin: `pid % sl`
    # selects the row and `sl * head_dim` is the distance between the t/h/w
    # sections. Passing the real number of rows per section therefore handles
    # both layouts: with a broadcast batch (cos_batch == 1) the row is the
    # sequence position, with a full batch the row is the program id itself.
    cos_rows_per_section = cos_batch * seq_len

    pad_hd = triton.next_power_of_2(head_dim)
    pad_n_q_head = triton.next_power_of_2(n_q_head)
    pad_n_kv_head = triton.next_power_of_2(n_kv_head)
    block_size = max(pad_n_q_head, pad_n_kv_head)

    # One program instance per token.
    n_row = batch_size * seq_len

    _triton_qwen2vl_mrope[(n_row,)](
        q,
        k,
        cos,
        sin,
        cos_rows_per_section,
        n_q_head,
        n_kv_head,
        head_dim,
        pad_n_q_head,
        pad_n_kv_head,
        pad_hd,
        mrope_section[0],
        mrope_section[1],
        BLOCK_SIZE=block_size,
        BACKWARD_PASS=backward_pass,
    )


def qwen2vl_mrope_forward(q, k, cos, sin, mrope_section):
    """Apply Qwen2-VL multimodal RoPE to q and k.

    Args:
        q: tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k: tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos, sin: tensors of shape (3, 1, seq_len, head_dim) or
            (3, bsz, seq_len, head_dim).
        mrope_section: channel counts of the temporal / height / width
            sections; only the first two entries are read.

    Returns:
        (q, k, cos, sin): rotated q and k in the input layout, plus the
        contiguous cos and sin that were handed to the kernel.

    Raises:
        ValueError: if cos/sin do not match the layout described above.

    Note:
        The kernel writes in place. ``.contiguous()`` is a no-op on tensors
        that are already contiguous, so q/k whose transposed view is
        contiguous are overwritten.
    """
    # Transpose back to the physical (bsz, seq_len, n_head, head_dim) layout
    # because Triton looks at the underlying storage; make sure everything the
    # kernel touches is densely packed.
    q = q.transpose(1, 2).contiguous()
    k = k.transpose(1, 2).contiguous()
    cos = cos.contiguous()
    sin = sin.contiguous()

    _launch_qwen2vl_mrope(q, k, cos, sin, mrope_section, backward_pass=False)
    return q.transpose(1, 2), k.transpose(1, 2), cos, sin


def qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section):
    """Propagate gradients through the multimodal RoPE rotation.

    Args:
        dq: gradient w.r.t. rotated q, shape (bsz, n_q_head, seq_len, head_dim).
        dk: gradient w.r.t. rotated k, shape (bsz, n_kv_head, seq_len, head_dim).
        cos, sin: the tensors used in the forward pass.
        mrope_section: same value as in the forward pass.

    Returns:
        (dq, dk): gradients w.r.t. the un-rotated q and k, in the input layout.

    Raises:
        ValueError: if cos/sin do not match the supported layout.
    """
    dq = dq.transpose(1, 2).contiguous()
    dk = dk.transpose(1, 2).contiguous()
    # No-op for the tensors saved by the forward pass; protects direct callers.
    cos = cos.contiguous()
    sin = sin.contiguous()

    # Backward is the same kernel with the sign of sin flipped.
    _launch_qwen2vl_mrope(dq, dk, cos, sin, mrope_section, backward_pass=True)
    return dq.transpose(1, 2), dk.transpose(1, 2)
204+
205+
206+
class LigerQwen2VLMRopeFunction(torch.autograd.Function):
    """
    Triton implementation of the Qwen2VL Multimodal Rotary Positional Embedding (M-RoPE) operation.

    Please find the corresponding HuggingFace implementation here:
    https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py
    """

    @staticmethod
    def forward(ctx, q, k, cos, sin, mrope_section, unsqueeze_dim=1):
        """
        Rotate q and k and stash what the backward pass needs.

        q size: (bsz, n_q_head, seq_len, head_dim)
        k size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, 1, seq_len, head_dim)
        sin size: (3, 1, seq_len, head_dim)

        ``unsqueeze_dim`` is accepted but not used by this implementation.

        Returns the rotated (q, k).
        """
        q, k, cos, sin = qwen2vl_mrope_forward(q, k, cos, sin, mrope_section)
        # Save the cos/sin returned by the forward helper so the backward
        # pass reuses exactly the tensors the kernel read.
        ctx.save_for_backward(cos, sin)
        # mrope_section is stored as a plain attribute rather than through
        # save_for_backward.
        ctx.mrope_section = mrope_section
        return q, k

    # torch.autograd.Function requires backward to be a static method, just
    # like forward; the decorator was missing here.
    @staticmethod
    def backward(ctx, dq, dk):
        """
        Compute the gradients w.r.t. the un-rotated q and k.

        dq size: (bsz, n_q_head, seq_len, head_dim)
        dk size: (bsz, n_kv_head, seq_len, head_dim)
        cos size: (3, 1, seq_len, head_dim)
        sin size: (3, 1, seq_len, head_dim)

        Returns one entry per forward input; cos, sin, mrope_section and
        unsqueeze_dim receive no gradient.
        """
        cos, sin = ctx.saved_tensors
        mrope_section = ctx.mrope_section
        dq, dk = qwen2vl_mrope_backward(dq, dk, cos, sin, mrope_section)
        return dq, dk, None, None, None, None

src/liger_kernel/transformers/functional.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from liger_kernel.ops.jsd import LigerJSDFunction
1111
from liger_kernel.ops.kl_div import LigerKLDivLossFunction
1212
from liger_kernel.ops.layer_norm import LigerLayerNormFunction
13+
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
1314
from liger_kernel.ops.rms_norm import LigerRMSNormFunction
1415
from liger_kernel.ops.rope import LigerRopeFunction
1516
from liger_kernel.ops.swiglu import LigerSiLUMulFunction
@@ -19,6 +20,7 @@
1920
liger_geglu = LigerGELUMulFunction.apply
2021
liger_rms_norm = LigerRMSNormFunction.apply
2122
liger_rope = LigerRopeFunction.apply
23+
liger_qwen2vl_mrope = LigerQwen2VLMRopeFunction.apply
2224
liger_layer_norm = LigerLayerNormFunction.apply
2325
liger_kl_div = LigerKLDivLossFunction.apply
2426
liger_jsd = LigerJSDFunction.apply

src/liger_kernel/transformers/monkey_patch.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636
from liger_kernel.transformers.model.qwen2 import (
3737
lce_forward_deprecated as qwen2_lce_forward_deprecated,
3838
)
39+
from liger_kernel.transformers.qwen2vl_mrope import liger_multimodal_rotary_pos_emb
3940
from liger_kernel.transformers.rms_norm import LigerRMSNorm
4041
from liger_kernel.transformers.rope import liger_rotary_pos_emb
4142
from liger_kernel.transformers.swiglu import (
@@ -641,6 +642,7 @@ def apply_liger_kernel_to_qwen2(
641642

642643

643644
def apply_liger_kernel_to_qwen2_vl(
645+
rope: bool = True,
644646
cross_entropy: bool = False,
645647
fused_linear_cross_entropy: bool = True,
646648
rms_norm: bool = True,
@@ -675,8 +677,10 @@ def apply_liger_kernel_to_qwen2_vl(
675677
lce_forward as qwen2_vl_lce_forward,
676678
)
677679

678-
# TODO: Support Qwen2-VL's multimodal RoPE implementation
679-
680+
if rope:
681+
modeling_qwen2_vl.apply_multimodal_rotary_pos_emb = (
682+
liger_multimodal_rotary_pos_emb
683+
)
680684
if rms_norm:
681685
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/qwen2_vl/modeling_qwen2_vl.py#L439
682686
modeling_qwen2_vl.Qwen2RMSNorm = LigerRMSNorm
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
from liger_kernel.ops.qwen2vl_mrope import LigerQwen2VLMRopeFunction
2+
3+
4+
def liger_multimodal_rotary_pos_emb(q, k, cos, sin, mrope_section, unsqueeze_dim=1):
    """
    Apply Multimodal Rotary Positional Embedding (M-RoPE) to the query and key states.

    Thin functional wrapper that forwards all arguments, in order, to
    ``LigerQwen2VLMRopeFunction.apply``.

    Args:
        q (torch.Tensor): Query tensor of shape (bsz, n_q_head, seq_len, head_dim).
        k (torch.Tensor): Key tensor of shape (bsz, n_kv_head, seq_len, head_dim).
        cos (torch.Tensor): Cosine tensor of shape (3, 1, seq_len, head_dim).
        sin (torch.Tensor): Sine tensor of shape (3, 1, seq_len, head_dim).
        mrope_section (List[int]): Multimodal rope sections for the channel dimension
            (temporal, height and width) used in the rope calculation.
        unsqueeze_dim (int, optional): The dimension to unsqueeze. Defaults to 1.

    Returns:
        Tuple[torch.Tensor, torch.Tensor]: The query and key tensors after the M-RoPE operation.
    """
    call_args = (q, k, cos, sin, mrope_section, unsqueeze_dim)
    return LigerQwen2VLMRopeFunction.apply(*call_args)

0 commit comments

Comments
 (0)