@@ -358,11 +358,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
358358 nullptr ,
359359 reinterpret_cast <const XPUSCType*>(sin_data),
360360 reinterpret_cast <const XPUSCType*>(cos_data),
361+ nullptr ,
361362 reinterpret_cast <XPUType*>(out_q->data ()),
362363 nullptr ,
363364 {batch_size, seq_len, num_heads, head_dim},
364365 {batch_size, seq_len, 1 , head_dim},
365366 {},
367+ 0 ,
366368 " BLHD" ,
367369 -1 ,
368370 10000 .0f );
@@ -374,11 +376,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
374376 reinterpret_cast <const XPUType*>(in_k->data ()),
375377 reinterpret_cast <const XPUSCType*>(sin_data),
376378 reinterpret_cast <const XPUSCType*>(cos_data),
379+ nullptr ,
377380 reinterpret_cast <XPUType*>(out_q->data ()),
378381 reinterpret_cast <XPUType*>(out_k->data ()),
379382 {batch_size, seq_len, num_heads, head_dim},
380383 {batch_size, seq_len, 1 , head_dim},
381384 {},
385+ 0 ,
382386 " BLHD" ,
383387 num_heads_k,
384388 10000 .0f );
@@ -392,11 +396,13 @@ void XPUFusedRotaryHalf(const Context& dev_ctx,
392396 nullptr ,
393397 reinterpret_cast <const XPUSCType*>(sin_data),
394398 reinterpret_cast <const XPUSCType*>(cos_data),
399+ nullptr ,
395400 reinterpret_cast <XPUType*>(out_v->data ()),
396401 nullptr ,
397402 {batch_size, seq_len, num_heads_v, head_dim},
398403 {batch_size, seq_len, 1 , head_dim},
399404 {},
405+ 0 ,
400406 " BLHD" ,
401407 -1 ,
402408 10000 .0f );
0 commit comments