@@ -57,14 +57,21 @@ impl VPTQGemm {
         pre_shift_pack_tensor_shape.push(1);
 
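         // presumably: shift each packed word right by the offsets in wf and mask
         // with 1, yielding one 0/1 lane per bit position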
         let mut out = shift_right_zero_and_1(
-            pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(),
+            pack_tensor
+                .clone()
+                .into_shape(&pre_shift_pack_tensor_shape)?
+                .into(),
             wf.into(),
         )?;
 
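         // each u32 word was expanded into 32 one-bit lanes, so the last axis grows 32x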
         let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone();
         let pval = post_shift_pack_tensor_shape.pop().unwrap();
         post_shift_pack_tensor_shape.push(32 * pval);
-        out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue();
+        out = out
+            .into_tensor()
+            .clone()
+            .into_shape(&post_shift_pack_tensor_shape)?
+            .into_tvalue();
 
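         // bits left over once the stream is cut into whole indices
         // (index_bits bits each, num_elements indices per group)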
         let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements);
         if pad_size > 0 {
@@ -78,10 +85,15 @@ impl VPTQGemm {
         let auto = out.shape().last().unwrap() / index_bits;
         post_pad_pack_tensor_shape.push(auto);
         post_pad_pack_tensor_shape.push(index_bits);
-        out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into();
+        out = out
+            .into_tensor()
+            .into_shape(&post_pad_pack_tensor_shape)?
+            .into();
 
         let wf1 = Tensor::from(
-            Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(),
+            Array1::from_iter(0..(index_bits as i32))
+                .to_shape([1, 1, 1, index_bits])?
+                .into_owned(),
         );
 
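         // shifting bit lane i left by i lines the lanes up so they can be recombined
         // into index_bits-bit index values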
         out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap();
@@ -146,7 +158,7 @@ impl VPTQGemm {
             .into_shape(&[num_codebooks, remain, group_size, vector_len])?
             .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory)
             .into_shape(&[num_codebooks, remain * vector_len, group_size])?
-            .permute_axes(&[1, 0, 2])?// NOTE: costly in tract (applied in memory)
+            .permute_axes(&[1, 0, 2])? // NOTE: costly in tract (applied in memory)
             .into_shape(&[vector_len * remain, num_codebooks * group_size])?;
 
         let dim0 = qweight.shape()[0];
@@ -210,14 +222,27 @@ impl EvalOp for VPTQGemm {
             assert_eq!(outlier_centroids.rank(), 3);
             assert!(outlier_centroids.datum_type().is_float());
         }
-        let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+        let _fdtypes = [
+            input.datum_type(),
+            centroids.datum_type(),
+            outlier_centroids.datum_type(),
+        ];
         let fdtypes = HashSet::from(_fdtypes);
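         // a single-element set means input, centroids and outlier_centroids share a dtype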
         if fdtypes.len() != 1 {
-            log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type());
+            log::warn!(
+                "force cast centroids to be same type as input: {:?}",
+                input.datum_type()
+            );
             centroids = centroids.cast_to_dt(input.datum_type())?.into_owned();
-            outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned();
+            outlier_centroids = outlier_centroids
+                .cast_to_dt(input.datum_type())?
+                .into_owned();
         }
-        let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+        let _fdtypes = [
+            input.datum_type(),
+            centroids.datum_type(),
+            outlier_centroids.datum_type(),
+        ];
         let fdtypes = HashSet::from(_fdtypes);
         assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}");
 
@@ -245,16 +270,23 @@ impl EvalOp for VPTQGemm {
         if enable_perm {
             let axis = 0;
             let dim = perm.shape()[0];
-            let top_k = Topk { axis, largest: false, fallback_k: dim.into() };
-            let invert_perm =
-                top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(0);
+            let top_k = Topk {
+                axis,
+                largest: false,
+                fallback_k: dim.into(),
+            };
+            let invert_perm = top_k
+                .eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?
+                .remove(0);
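+            // a full ascending Topk over the permutation presumably acts as an argsort,
+            // recovering the inverse permutation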
             // TODO: manage case with quant dim == 'in' ?
             // if self.vector_quant_dim == "in":
             //     assert True, "Not implemented"
             //     qweight = qweight[invert_perm, :]
 
             let perm_gather_axis = 1;
-            let gather_perm = Gather { axis: perm_gather_axis };
+            let gather_perm = Gather {
+                axis: perm_gather_axis,
+            };
             qweight = gather_perm
                 .eval(tvec!(qweight.into(), invert_perm))?
                 .pop()
@@ -280,19 +312,9 @@ impl EvalOp for VPTQGemm {
 
         let &n = qweight.shape().last().unwrap();
 
-        let (&[m, k], out_shape) = match ishape.len() {
-            2 => {
-                let &[m, k] = ishape else {
-                    bail!("unexpected rank: {:?}", ishape.len());
-                };
-                (&[m, k], vec![m, n])
-            }
-            3 => {
-                let &[b, m, k] = ishape else {
-                    bail!("unexpected rank: {:?}", ishape.len());
-                };
-                (&[m, k], vec![b, m, n])
-            }
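+        // rank 2 is a plain matmul; rank 3 threads a leading batch dim through to the output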
+        let (m, k, out_shape) = match ishape {
+            &[m, k] => (m, k, vec![m, n]),
+            &[b, m, k] => (m, k, vec![b, m, n]),
             _ => {
                 bail!("unexpected rank {:?}", ishape.len())
             }
@@ -301,15 +323,25 @@ impl EvalOp for VPTQGemm {
         let input_offset = input.rank() - 2;
         let weight_offset = qweight.rank() - 2;
 
+        /* this would be better for Intel where there is no f16 support, but the kernel selection
+        APIs are not up to the task (yet)
+
+        let acc_type = if tract_linalg::has_fp16() {
+            f16::datum_type()
+        } else {
+            f32::datum_type()
+        };
+
+        */
         let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap();
         let (pack_a, pack_b) = &mmm.packings()[0];
 
         let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) };
 
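         // pack a (activations) and b (dequantized weights) into the layouts the
         // selected kernel expects; the integer arguments select the two matmul axes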
         let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?;
         let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?;
-        let last = unsafe {
-            let out = Tensor::uninitialized::<f32>(out_shape.iter().as_slice())?;
+        unsafe {
+            let out = Tensor::uninitialized_dt(data_type, &out_shape)?;
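+            // allocating the output directly in the working dtype avoids a trailing cast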
             let non_linear = &[
                 FusedSpec::AddMatMul {
                     a: tract_linalg::mmm::AsInputValue::Owned(a),
@@ -319,12 +351,8 @@ impl EvalOp for VPTQGemm {
                 FusedSpec::Store(cstore.wrap(&out.view())),
             ];
             mmm.run(m, n, non_linear)?;
-
-            out
-        };
-        // force down cast for now
-        let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue();
-        Ok(tvec!(last_cdt))
+            Ok(tvec!(out.into_tvalue()))
+        }
     }
 }
 