@@ -263,6 +263,8 @@ impl EvalOp for VPTQGemm {
263263 . into_tensor ( ) ;
264264 }
265265
266+ let data_type = * fdtypes. iter ( ) . next ( ) . unwrap ( ) ;
267+
266268 if enable_norm {
267269 qweight = ( qweight. into_array :: < f32 > ( ) ? * weight_scale. into_array :: < f32 > ( ) ?
268270 + weight_bias. into_array :: < f32 > ( ) ?)
@@ -278,33 +280,36 @@ impl EvalOp for VPTQGemm {
278280
279281 let & n = qweight. shape ( ) . last ( ) . unwrap ( ) ;
280282
281- let ( & [ m, k] , out_shape, offset ) = match ishape. len ( ) {
283+ let ( & [ m, k] , out_shape) = match ishape. len ( ) {
282284 2 => {
283285 let & [ m, k] = ishape else {
284- bail ! ( "unexpected rank: {:?}" , input . len( ) ) ;
286+ bail ! ( "unexpected rank: {:?}" , ishape . len( ) ) ;
285287 } ;
286- ( & [ m, k] , vec ! [ m, n] , 0usize )
288+ ( & [ m, k] , vec ! [ m, n] )
287289 }
288290 3 => {
289291 let & [ b, m, k] = ishape else {
290- bail ! ( "unexpected rank: {:?}" , input . len( ) ) ;
292+ bail ! ( "unexpected rank: {:?}" , ishape . len( ) ) ;
291293 } ;
292- ( & [ m, k] , vec ! [ b, m, n] , 1usize )
294+ ( & [ m, k] , vec ! [ b, m, n] )
293295 }
294296 _ => {
295297 bail ! ( "unexpected rank {:?}" , ishape. len( ) )
296298 }
297299 } ;
298300
299- let mmm = op. mmm ( * fdtypes. iter ( ) . next ( ) . unwrap ( ) , Some ( m) , Some ( k) , Some ( n) ) . unwrap ( ) ;
301+ let input_offset = input. rank ( ) - 2 ;
302+ let weight_offset = qweight. rank ( ) - 2 ;
300303
304+ let mmm = op. mmm ( data_type, Some ( m) , Some ( k) , Some ( n) ) . unwrap ( ) ;
301305 let ( pack_a, pack_b) = & mmm. packings ( ) [ 0 ] ;
302- let cstore = unsafe { mmm. c_view ( 0 + offset, 1 + offset) } ;
303306
304- let a = pack_a. prepare_tensor ( & input, 1 + offset, 0 + offset) ?;
305- let b = pack_b. prepare_tensor ( & qweight, 0 + offset, 1 + offset) ?;
306- unsafe {
307- let out = Tensor :: uninitialized :: < f32 > ( out_shape. iter ( ) . as_slice ( ) . try_into ( ) ?) ?;
307+ let cstore = unsafe { mmm. c_view ( input_offset, 1 + input_offset) } ;
308+
309+ let a = pack_a. prepare_tensor ( & input, 1 + input_offset, input_offset) ?;
310+ let b = pack_b. prepare_tensor ( & qweight, weight_offset, 1 + weight_offset) ?;
311+ let last = unsafe {
312+ let out = Tensor :: uninitialized :: < f32 > ( out_shape. iter ( ) . as_slice ( ) ) ?;
308313 let non_linear = & [
309314 FusedSpec :: AddMatMul {
310315 a : tract_linalg:: mmm:: AsInputValue :: Owned ( a) ,
@@ -315,8 +320,11 @@ impl EvalOp for VPTQGemm {
315320 ] ;
316321 mmm. run ( m, n, non_linear) ?;
317322
318- Ok ( tvec ! ( out. into( ) ) )
319- }
323+ out
324+ } ;
325+ // force down cast for now
326+ let last_cdt = last. cast_to_dt ( input. datum_type ( ) ) ?. into_owned ( ) . into_tvalue ( ) ;
327+ Ok ( tvec ! ( last_cdt) )
320328 }
321329}
322330
0 commit comments