Skip to content

Commit bd71d04

Browse files
authored
[BugFix] Fix RoPE kernel on long sequences(vllm-project#2164)
1 parent 0f7e679 commit bd71d04

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

csrc/pos_encoding_kernels.cu

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
4343
scalar_t* __restrict__ key, // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
4444
const scalar_t* __restrict__ cos_sin_cache, // [max_position, 2, rot_dim // 2]
4545
const int rot_dim,
46-
const int query_stride,
47-
const int key_stride,
46+
const int64_t query_stride,
47+
const int64_t key_stride,
4848
const int num_heads,
4949
const int num_kv_heads,
5050
const int head_size) {
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
6060
const int nq = num_heads * embed_dim;
6161
for (int i = threadIdx.x; i < nq; i += blockDim.x) {
6262
const int head_idx = i / embed_dim;
63-
const int token_head = token_idx * query_stride + head_idx * head_size;
63+
const int64_t token_head = token_idx * query_stride + head_idx * head_size;
6464
const int rot_offset = i % embed_dim;
6565
apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
6666
sin_ptr, rot_offset, embed_dim);
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
6969
const int nk = num_kv_heads * embed_dim;
7070
for (int i = threadIdx.x; i < nk; i += blockDim.x) {
7171
const int head_idx = i / embed_dim;
72-
const int token_head = token_idx * key_stride + head_idx * head_size;
72+
const int64_t token_head = token_idx * key_stride + head_idx * head_size;
7373
const int rot_offset = i % embed_dim;
7474
apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
7575
sin_ptr, rot_offset, embed_dim);
@@ -89,8 +89,8 @@ void rotary_embedding(
8989
int rot_dim = cos_sin_cache.size(1);
9090
int num_heads = query.size(-1) / head_size;
9191
int num_kv_heads = key.size(-1) / head_size;
92-
int query_stride = query.stride(-2);
93-
int key_stride = key.stride(-2);
92+
int64_t query_stride = query.stride(-2);
93+
int64_t key_stride = key.stride(-2);
9494

9595
dim3 grid(num_tokens);
9696
dim3 block(std::min(num_heads * rot_dim / 2, 512));

0 commit comments

Comments
 (0)