@@ -43,8 +43,8 @@ __global__ void rotary_embedding_kernel(
   scalar_t* __restrict__ key,                   // [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,   // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int query_stride,
-  const int key_stride,
+  const int64_t query_stride,
+  const int64_t key_stride,
   const int num_heads,
   const int num_kv_heads,
   const int head_size) {
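Widening the two stride parameters is enough to make the offset arithmetic inside the kernel 64-bit: under C++'s usual arithmetic conversions, an int multiplied by an int64_t is computed in int64_t. A minimal standalone sketch of that promotion (the values here are hypothetical, not taken from the commit):

// Sketch only: demonstrates the integer promotion the widened
// parameters rely on; the values are made up for illustration.
#include <cstdint>
#include <type_traits>

int main() {
  int token_idx = 1 << 20;          // a large block index
  std::int64_t query_stride = 4096; // stride now arrives as int64_t

  // int * int64_t is performed in 64 bits, so the offset cannot wrap
  // even though the product (2^32) exceeds INT32_MAX.
  auto offset = token_idx * query_stride;
  static_assert(std::is_same_v<decltype(offset), std::int64_t>);
  return offset > 0 ? 0 : 1;
}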
@@ -60,7 +60,7 @@ __global__ void rotary_embedding_kernel(
   const int nq = num_heads * embed_dim;
   for (int i = threadIdx.x; i < nq; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * query_stride + head_idx * head_size;
+    const int64_t token_head = token_idx * query_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
@@ -69,7 +69,7 @@ __global__ void rotary_embedding_kernel(
   const int nk = num_kv_heads * embed_dim;
   for (int i = threadIdx.x; i < nk; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * key_stride + head_idx * head_size;
+    const int64_t token_head = token_idx * key_stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
     apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr,
                                               sin_ptr, rot_offset, embed_dim);
@@ -89,8 +89,8 @@ void rotary_embedding(
   int rot_dim = cos_sin_cache.size(1);
   int num_heads = query.size(-1) / head_size;
   int num_kv_heads = key.size(-1) / head_size;
-  int query_stride = query.stride(-2);
-  int key_stride = key.stride(-2);
+  int64_t query_stride = query.stride(-2);
+  int64_t key_stride = key.stride(-2);
 
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
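On the host side the strides come straight from PyTorch, whose Tensor::stride() already returns int64_t, so the widened locals simply stop truncating that value. The failure mode the change guards against can be seen in a standalone sketch with made-up shapes (roughly 600k tokens of a 32-head, 128-dim query); with the old int strides the kernel's multiply itself ran in 32-bit arithmetic, where overflow is undefined behavior, and the sketch shows the wrapped value such a computation typically produces:

// Sketch only: the shapes are hypothetical, not from the commit.
#include <cstdint>
#include <cstdio>

int main() {
  std::int64_t token_idx = 600000;       // late token in a large batch
  std::int64_t query_stride = 32 * 128;  // 4096 elements per token

  std::int64_t good = token_idx * query_stride;        // 2'457'600'000, fits in 64 bits
  std::int32_t bad = static_cast<std::int32_t>(good);  // narrows past INT32_MAX,
                                                       // wraps to -1'837'367'296

  std::printf("32-bit offset: %d\n", bad);
  std::printf("64-bit offset: %lld\n", static_cast<long long>(good));
  return 0;
}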