Commit 51a13c0: fix dsk diff
1 parent: 5a1c4ac

3 files changed: 45 additions & 23 deletions

csrc/gpu/step.cu

Lines changed: 16 additions & 10 deletions
@@ -31,6 +31,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
                                         int *used_list_len,
                                         int *free_list,
                                         int *free_list_len,
+                                        int64_t *first_token_ids,
                                         const int bsz,
                                         const int block_size,
                                         const int block_num_per_seq,
@@ -43,6 +44,7 @@ __global__ void free_and_dispatch_block(bool *stop_flags,
         int *block_table_now = block_tables + tid * block_num_per_seq;
         if (stop_flags[tid] && !is_block_step[tid]) {
             // reclaim the blocks
+            first_token_ids[tid] = -1;
             const int encoder_block_len = encoder_block_lens[tid];
             const int decoder_used_len = used_list_len[tid];
             if (decoder_used_len > 0) {
@@ -166,11 +168,11 @@ __global__ void recover_block(int *recover_block_list,  // [bsz]
                               int *encoder_block_lens,
                               int *used_list_len,
                               const int64_t *next_tokens,
+                              const int64_t *first_token_ids,
                               const int bsz,
                               const int block_num_per_seq,
                               const int length,
-                              const int pre_id_length,
-                              const int first_token_id) {
+                              const int pre_id_length) {
     const int bid = blockIdx.x;
     const int tid = threadIdx.x;
     __shared__ int ori_free_list_len;
@@ -189,7 +191,8 @@ __global__ void recover_block(int *recover_block_list,  // [bsz]
         seq_lens_encoder[recover_id] = seq_len;
         stop_flags[recover_id] = false;
         input_ids_now[ori_seq_len_encoder + step_idx_now - 1] = next_tokens[recover_id];  // next tokens
-        input_ids_now[0] = first_token_id;  // set first prompt token
+        input_ids_now[0] =
+            first_token_ids[recover_id];  // set first prompt token
         const int ori_free_list_len_tid0 = atomicSub(free_list_len, decoder_used_len);
         ori_free_list_len = ori_free_list_len_tid0;
 #ifdef DEBUG_STEP
@@ -234,9 +237,9 @@ void StepPaddle(const paddle::Tensor& stop_flags,
                 const paddle::Tensor& pre_ids,
                 const paddle::Tensor& step_idx,
                 const paddle::Tensor& next_tokens,
+                const paddle::Tensor& first_token_ids,
                 const int block_size,
                 const int encoder_decoder_block_num,
-                const int64_t first_token_id,
                 const int speculate_step_token_num) {
     auto cu_stream = seq_lens_this_time.stream();
     const int bsz = seq_lens_this_time.shape()[0];
@@ -264,6 +267,7 @@ void StepPaddle(const paddle::Tensor& stop_flags,
         const_cast<int*>(used_list_len.data<int>()),
         const_cast<int*>(free_list.data<int>()),
         const_cast<int*>(free_list_len.data<int>()),
+        const_cast<int64_t*>(first_token_ids.data<int64_t>()),
         bsz,
         block_size,
         block_num_per_seq,
@@ -300,11 +304,11 @@ void StepPaddle(const paddle::Tensor& stop_flags,
         const_cast<int*>(encoder_block_lens.data<int>()),
         const_cast<int*>(used_list_len.data<int>()),
         next_tokens.data<int64_t>(),
+        first_token_ids.data<int64_t>(),
         bsz,
         block_num_per_seq,
         length,
-        pre_id_length,
-        first_token_id
+        pre_id_length
     );
 #ifdef DEBUG_STEP
 #ifdef PADDLE_WITH_HIP
@@ -337,10 +341,10 @@ PD_BUILD_OP(step_paddle)
             "input_ids",
             "pre_ids",
             "step_idx",
-            "next_tokens"})
+            "next_tokens",
+            "first_token_ids"})
     .Attrs({"block_size: int",
             "encoder_decoder_block_num: int",
-            "first_token_id: int64_t",
             "speculate_step_token_num: int"})
     .Outputs({"stop_flags_out",
               "seq_lens_this_time_out",
@@ -358,7 +362,8 @@ PD_BUILD_OP(step_paddle)
               "used_list_len_out",
               "free_list_out",
               "free_list_len_out",
-              "input_ids_out"})
+              "input_ids_out",
+              "first_token_ids_out"})
     .SetInplaceMap({{"stop_flags", "stop_flags_out"},
                     {"seq_lens_this_time", "seq_lens_this_time_out"},
                     {"seq_lens_encoder", "seq_lens_encoder_out"},
@@ -375,5 +380,6 @@ PD_BUILD_OP(step_paddle)
                     {"used_list_len", "used_list_len_out"},
                     {"free_list", "free_list_out"},
                     {"free_list_len", "free_list_len_out"},
-                    {"input_ids", "input_ids_out"}})
+                    {"input_ids", "input_ids_out"},
+                    {"first_token_ids", "first_token_ids_out"}})
     .SetKernelFn(PD_KERNEL(StepPaddle));
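Note: the net effect of this file's change is that step_paddle now carries a per-sequence first_token_ids tensor through the kernels instead of a single scalar attribute: free_and_dispatch_block resets a finished sequence's slot to -1 when its blocks are reclaimed, and recover_block restores each sequence's own saved first prompt token. A minimal plain-Python sketch of those semantics (all names are illustrative stand-ins for the CUDA logic, not a real API):

    # Hedged sketch of the per-sequence first-token bookkeeping introduced here.
    def free_block(tid, stop_flags, is_block_step, first_token_ids):
        # mirrors free_and_dispatch_block: a finished sequence clears its slot
        if stop_flags[tid] and not is_block_step[tid]:
            first_token_ids[tid] = -1

    def recover_block(recover_id, input_ids, first_token_ids):
        # mirrors recover_block: before this commit every recovered sequence was
        # given the same scalar first_token_id; now each gets its own token back
        input_ids[recover_id][0] = first_token_ids[recover_id]

    first_token_ids = [101, 202, 303]  # one saved prompt token per sequence
    input_ids = [[0, 5, 6], [0, 7, 8], [0, 9, 10]]
    recover_block(1, input_ids, first_token_ids)
    assert input_ids[1][0] == 202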

paddlenlp/experimental/transformers/deepseek_v2/modeling.py

Lines changed: 28 additions & 12 deletions
@@ -92,10 +92,12 @@ def __init__(
             * attn_factor
         )

-        cache = self._compute_cos_sin_cache()
+        cos_cache, sin_cache = self._compute_cos_sin_cache()

-        self.cos_sin_cache: paddle.Tensor
-        self.register_buffer("cos_sin_cache", cache, persistable=True)
+        self.cos_cache: paddle.Tensor
+        self.register_buffer("cos_cache", cos_cache, persistable=True)
+        self.sin_cache: paddle.Tensor
+        self.register_buffer("sin_cache", sin_cache, persistable=True)

     def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
         pos_freqs = self.base ** (paddle.arange(0, self.rotary_dim, 2, dtype=paddle.float32) / self.rotary_dim)
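Note: the single cos_sin_cache buffer is split into separate persistable cos_cache and sin_cache buffers, built from duplicated frequencies in the next hunk. A rough numpy sketch (numpy standing in for paddle, toy sizes assumed) of the resulting layout:

    import numpy as np

    # Each cache becomes [num_positions, rotary_dim], with the frequency table
    # repeated along the last axis so it lines up with the half layout that
    # rotate_half (below) expects.
    rotary_dim, num_positions = 8, 4
    inv_freq = 1.0 / (10000.0 ** (np.arange(0, rotary_dim, 2) / rotary_dim))
    t = np.arange(num_positions, dtype=np.float32)
    freqs = np.outer(t, inv_freq)                  # [num_positions, rotary_dim // 2]
    emb = np.concatenate([freqs, freqs], axis=-1)  # [num_positions, rotary_dim]
    cos_cache, sin_cache = np.cos(emb), np.sin(emb)
    print(cos_cache.shape, sin_cache.shape)        # (4, 8) (4, 8)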
@@ -114,23 +116,37 @@ def _compute_inv_freq(self, scaling_factor: float) -> paddle.Tensor:
     def _compute_cos_sin_cache(self) -> paddle.Tensor:
         inv_freq = self._compute_inv_freq(self.scaling_factor)
         t = paddle.arange(self.max_position_embeddings * self.scaling_factor, dtype=paddle.float32)
-        freqs = paddle.einsum("i,j->ij", t, inv_freq)
-        cos = freqs.cos() * self.mscale
-        sin = freqs.sin() * self.mscale
-        cache = paddle.concat((cos, sin), axis=-1)
-        return cache.cast(self._dtype)
+
+        freqs = paddle.outer(t, inv_freq)
+        emb = paddle.concat((freqs, freqs), axis=-1)
+        cos = emb.cos() * self.mscale
+        sin = emb.sin() * self.mscale
+
+        return cos.cast(self._dtype), sin.cast(self._dtype)

     def forward(
         self,
         position_ids: paddle.Tensor,
         query: paddle.Tensor,
         key: paddle.Tensor,
     ) -> Tuple[paddle.Tensor, paddle.Tensor]:
-        from paddlenlp_ops import fused_rotary_position_encoding
+        cos = self.cos_cache[position_ids].unsqueeze(1)
+        sin = self.sin_cache[position_ids].unsqueeze(1)
+
+        def rotate_half(x):
+            """Rotates half the hidden axes of the input."""
+            x1 = x[..., : x.shape[-1] // 2]
+            x2 = x[..., x.shape[-1] // 2 :]
+            return paddle.concat([-x2, x1], axis=-1)  # shape is the same as x
+
+        s, h, d = query.shape
+        query = query.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])
+
+        s, h, d = key.shape
+        key = key.reshape([s, h, d // 2, 2]).transpose([0, 1, 3, 2]).reshape([s, h, d])

-        # In-place operations that update the query and key tensors.
-        os.environ["stride_in_no_check_dy2st_diff"] = "1"
-        fused_rotary_position_encoding(query, key, position_ids, self.cos_sin_cache, self.rotary_dim, False)
+        query = (query * cos) + (rotate_half(query) * sin)
+        key = (key * cos) + (rotate_half(key) * sin)

         return query, key
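Note: forward() no longer calls the fused_rotary_position_encoding custom op (and drops the stride_in_no_check_dy2st_diff environment workaround with it); rotary embedding is applied eagerly with rotate_half. The reshape/transpose pair first converts interleaved rotary pairs into the half layout. A small numpy sketch (numpy standing in for paddle) of that permutation and the rotation formula:

    import numpy as np

    # The permutation moves interleaved pairs (x0,x1),(x2,x3),... into the
    # half layout [x0, x2, ..., x1, x3, ...] that rotate_half operates on.
    s, h, d = 1, 1, 8
    x = np.arange(s * h * d).reshape(s, h, d)
    half = x.reshape(s, h, d // 2, 2).transpose(0, 1, 3, 2).reshape(s, h, d)
    print(half[0, 0])  # [0 2 4 6 1 3 5 7]

    def rotate_half(x):
        x1, x2 = x[..., : x.shape[-1] // 2], x[..., x.shape[-1] // 2 :]
        return np.concatenate([-x2, x1], axis=-1)

    # per position: q_rot = q * cos + rotate_half(q) * sin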

paddlenlp/experimental/transformers/fused_transformer_layers.py

Lines changed: 1 addition & 1 deletion
@@ -1160,7 +1160,7 @@ def compute_fused_moe(self, tmp_out, i):
 def get_moe_scores(
     gating_output: paddle.Tensor,
     config: MoeConfig,
-) -> (paddle.Tensor, paddle.Tensor):
+) -> tuple[paddle.Tensor, paddle.Tensor]:

     num_token = gating_output.shape[0]
     num_expert_group = config.num_expert_group
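Note: the old annotation (paddle.Tensor, paddle.Tensor) is a tuple of types, not a type, so static checkers reject it; the PEP 585 builtin generic tuple[...] is the correct spelling on Python 3.9+. A tiny illustration:

    # valid syntax but not a valid type annotation; flagged by type checkers
    def pair_old() -> (int, int): ...

    # correct: a parameterized builtin generic (PEP 585, Python 3.9+)
    def pair_new() -> tuple[int, int]: ...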
