Commit 40d2e63

make vptq work on top of available kernels
1 parent 044f72b commit 40d2e63
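The functional change sits in the last two hunks below: the matmul output is now allocated with the input's datum type (Tensor::uninitialized_dt(data_type, ...)) and returned as-is, instead of always accumulating into an f32 tensor and downcasting at the end, so op.mmm can select whichever kernel is available for that type. The commented-out block in the diff sketches the eventual refinement of choosing the accumulator from hardware capabilities. A dependency-free sketch of that selection, where DType and has_fp16 are illustrative stand-ins rather than tract's API:

#[derive(Clone, Copy, Debug, PartialEq)]
enum DType {
    F16,
    F32,
}

// Stand-in capability probe, named after the call in the diff's commented-out block.
fn has_fp16() -> bool {
    cfg!(target_arch = "aarch64")
}

fn main() {
    // Mirrors the commented-out selection in the diff: accumulate in f16 only when the
    // hardware supports it natively, otherwise fall back to f32.
    let acc_type = if has_fp16() { DType::F16 } else { DType::F32 };
    println!("would request kernels accumulating in {acc_type:?}");
}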

1 file changed: +62 −34 lines

core/src/ops/vptq.rs

Lines changed: 62 additions & 34 deletions
@@ -57,14 +57,21 @@ impl VPTQGemm {
         pre_shift_pack_tensor_shape.push(1);

         let mut out = shift_right_zero_and_1(
-            pack_tensor.clone().into_shape(&pre_shift_pack_tensor_shape)?.into(),
+            pack_tensor
+                .clone()
+                .into_shape(&pre_shift_pack_tensor_shape)?
+                .into(),
             wf.into(),
         )?;

         let mut post_shift_pack_tensor_shape = pack_tensor_shape.clone();
         let pval = post_shift_pack_tensor_shape.pop().unwrap();
         post_shift_pack_tensor_shape.push(32 * pval);
-        out = out.into_tensor().clone().into_shape(&post_shift_pack_tensor_shape)?.into_tvalue();
+        out = out
+            .into_tensor()
+            .clone()
+            .into_shape(&post_shift_pack_tensor_shape)?
+            .into_tvalue();

         let pad_size = (pack_tensor_shape.last().unwrap_or(&0) * 32) % (index_bits * num_elements);
         if pad_size > 0 {
@@ -78,10 +85,15 @@ impl VPTQGemm {
         let auto = out.shape().last().unwrap() / index_bits;
         post_pad_pack_tensor_shape.push(auto);
         post_pad_pack_tensor_shape.push(index_bits);
-        out = out.into_tensor().into_shape(&post_pad_pack_tensor_shape)?.into();
+        out = out
+            .into_tensor()
+            .into_shape(&post_pad_pack_tensor_shape)?
+            .into();

         let wf1 = Tensor::from(
-            Array1::from_iter(0..(index_bits as i32)).to_shape([1, 1, 1, index_bits])?.into_owned(),
+            Array1::from_iter(0..(index_bits as i32))
+                .to_shape([1, 1, 1, index_bits])?
+                .into_owned(),
         );

         out = shift_left().eval(tvec!(out, wf1.into()))?.pop().unwrap();
@@ -146,7 +158,7 @@ impl VPTQGemm {
             .into_shape(&[num_codebooks, remain, group_size, vector_len])?
             .permute_axes(&[0, 1, 3, 2])? // NOTE: costly in tract (applied in memory)
             .into_shape(&[num_codebooks, remain * vector_len, group_size])?
-            .permute_axes(&[1, 0, 2])?// NOTE: costly in tract (applied in memory)
+            .permute_axes(&[1, 0, 2])? // NOTE: costly in tract (applied in memory)
             .into_shape(&[vector_len * remain, num_codebooks * group_size])?;

         let dim0 = qweight.shape()[0];
@@ -210,14 +222,27 @@ impl EvalOp for VPTQGemm {
             assert_eq!(outlier_centroids.rank(), 3);
             assert!(outlier_centroids.datum_type().is_float());
         }
-        let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+        let _fdtypes = [
+            input.datum_type(),
+            centroids.datum_type(),
+            outlier_centroids.datum_type(),
+        ];
         let fdtypes = HashSet::from(_fdtypes);
         if fdtypes.len() != 1 {
-            log::warn!("force cast centroids to be same type as input: {:?}", input.datum_type());
+            log::warn!(
+                "force cast centroids to be same type as input: {:?}",
+                input.datum_type()
+            );
             centroids = centroids.cast_to_dt(input.datum_type())?.into_owned();
-            outlier_centroids = outlier_centroids.cast_to_dt(input.datum_type())?.into_owned();
+            outlier_centroids = outlier_centroids
+                .cast_to_dt(input.datum_type())?
+                .into_owned();
         }
-        let _fdtypes = [input.datum_type(), centroids.datum_type(), outlier_centroids.datum_type()];
+        let _fdtypes = [
+            input.datum_type(),
+            centroids.datum_type(),
+            outlier_centroids.datum_type(),
+        ];
         let fdtypes = HashSet::from(_fdtypes);
         assert!(fdtypes.len() == 1, "mixed dtypes: {_fdtypes:?}");

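The hunk above only reflows the dtype check, but the pattern is easy to lose in diff form: the three datum types go into a HashSet, a set of length one means every tensor agrees, and otherwise the centroids are cast to the input's type. A standalone sketch of the same check, with a plain enum standing in for tract's DatumType:

use std::collections::HashSet;

#[derive(Clone, Copy, PartialEq, Eq, Hash, Debug)]
enum DType {
    F16,
    F32,
}

fn main() {
    let (input, centroids, outlier_centroids) = (DType::F16, DType::F32, DType::F16);
    // Same idea as _fdtypes / fdtypes in vptq.rs: one distinct element means all types match.
    let dtypes: HashSet<DType> = HashSet::from([input, centroids, outlier_centroids]);
    if dtypes.len() != 1 {
        // This is the branch where vptq.rs warns and casts centroids to the input type.
        println!("mixed dtypes, casting centroids to {input:?}");
    }
}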
@@ -245,16 +270,23 @@ impl EvalOp for VPTQGemm {
         if enable_perm {
             let axis = 0;
             let dim = perm.shape()[0];
-            let top_k = Topk { axis, largest: false, fallback_k: dim.into() };
-            let invert_perm =
-                top_k.eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?.remove(0);
+            let top_k = Topk {
+                axis,
+                largest: false,
+                fallback_k: dim.into(),
+            };
+            let invert_perm = top_k
+                .eval(tvec!(perm.into_tvalue(), tensor0(dim as u16).into()))?
+                .remove(0);
             // TODO: manage case with quant dim == 'in' ?
             // if self.vector_quant_dim == "in":
             //     assert True, "Not implemented"
             //     qweight = qweight[invert_perm, :]

             let perm_gather_axis = 1;
-            let gather_perm = Gather { axis: perm_gather_axis };
+            let gather_perm = Gather {
+                axis: perm_gather_axis,
+            };
             qweight = gather_perm
                 .eval(tvec!(qweight.into(), invert_perm))?
                 .pop()
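The Topk with largest: false and k equal to the permutation's length is effectively an argsort of perm, and the argsort of a permutation is its inverse, which the Gather that follows then applies to qweight. The same idea on plain Vecs, without tract ops:

fn main() {
    // perm[i] is the position row i was moved to; the values form a permutation of 0..len.
    let perm: Vec<u16> = vec![2, 0, 3, 1];
    // Argsort ascending (what Topk { largest: false, k: len } computes) is the inverse permutation.
    let mut invert_perm: Vec<usize> = (0..perm.len()).collect();
    invert_perm.sort_by_key(|&i| perm[i]);
    assert_eq!(invert_perm, vec![1, 3, 0, 2]);
    // Gathering rows of qweight with invert_perm undoes the original permutation.
    println!("{invert_perm:?}");
}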
@@ -280,19 +312,9 @@ impl EvalOp for VPTQGemm {

         let &n = qweight.shape().last().unwrap();

-        let (&[m, k], out_shape) = match ishape.len() {
-            2 => {
-                let &[m, k] = ishape else {
-                    bail!("unexpected rank: {:?}", ishape.len());
-                };
-                (&[m, k], vec![m, n])
-            }
-            3 => {
-                let &[b, m, k] = ishape else {
-                    bail!("unexpected rank: {:?}", ishape.len());
-                };
-                (&[m, k], vec![b, m, n])
-            }
+        let (m, k, out_shape) = match ishape {
+            &[m, k] => (m, k, vec![m, n]),
+            &[b, m, k] => (m, k, vec![b, m, n]),
             _ => {
                 bail!("unexpected rank {:?}", ishape.len())
             }
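The rewritten match binds m, k (and an optional leading batch dimension b) directly through slice patterns instead of first matching on ishape.len() and then destructuring. A dependency-free sketch of the same dispatch, returning a plain error where vptq.rs uses bail!:

fn output_shape(ishape: &[usize], n: usize) -> Result<(usize, usize, Vec<usize>), String> {
    match ishape {
        // rank 2: [m, k] times [k, n] -> [m, n]
        &[m, k] => Ok((m, k, vec![m, n])),
        // rank 3: a leading batch dimension is carried through to the output
        &[b, m, k] => Ok((m, k, vec![b, m, n])),
        _ => Err(format!("unexpected rank {:?}", ishape.len())),
    }
}

fn main() {
    assert_eq!(output_shape(&[4, 8], 16).unwrap().2, vec![4, 16]);
    assert_eq!(output_shape(&[2, 4, 8], 16).unwrap().2, vec![2, 4, 16]);
}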
@@ -301,15 +323,25 @@ impl EvalOp for VPTQGemm {
         let input_offset = input.rank() - 2;
         let weight_offset = qweight.rank() - 2;

+        /* this would be better for Intel where there is no f16 support, but the kernel selection
+        APIs are not up to the task (yet)
+
+        let acc_type = if tract_linalg::has_fp16() {
+            f16::datum_type()
+        } else {
+            f32::datum_type()
+        };
+
+        */
         let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap();
         let (pack_a, pack_b) = &mmm.packings()[0];

         let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) };

         let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?;
         let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?;
-        let last = unsafe {
-            let out = Tensor::uninitialized::<f32>(out_shape.iter().as_slice())?;
+        unsafe {
+            let out = Tensor::uninitialized_dt(data_type, &out_shape)?;
             let non_linear = &[
                 FusedSpec::AddMatMul {
                     a: tract_linalg::mmm::AsInputValue::Owned(a),
@@ -319,12 +351,8 @@ impl EvalOp for VPTQGemm {
                 FusedSpec::Store(cstore.wrap(&out.view())),
             ];
             mmm.run(m, n, non_linear)?;
-
-            out
-        };
-        // force down cast for now
-        let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue();
-        Ok(tvec!(last_cdt))
+            Ok(tvec!(out.into_tvalue()))
+        }
     }
 }

0 commit comments