Skip to content

Commit 044f72b

Browse files
JulienBalianSonoskali
authored and committed
fix: vptq working on large model
1 parent 8658ec5 commit 044f72b

File tree

1 file changed

+21
-13
lines changed

1 file changed

+21
-13
lines changed

core/src/ops/vptq.rs

Lines changed: 21 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -263,6 +263,8 @@ impl EvalOp for VPTQGemm {
263263
.into_tensor();
264264
}
265265

266+
let data_type = *fdtypes.iter().next().unwrap();
267+
266268
if enable_norm {
267269
qweight = (qweight.into_array::<f32>()? * weight_scale.into_array::<f32>()?
268270
+ weight_bias.into_array::<f32>()?)
@@ -278,33 +280,36 @@ impl EvalOp for VPTQGemm {
278280

279281
let &n = qweight.shape().last().unwrap();
280282

281-
let (&[m, k], out_shape, offset) = match ishape.len() {
283+
let (&[m, k], out_shape) = match ishape.len() {
282284
2 => {
283285
let &[m, k] = ishape else {
284-
bail!("unexpected rank: {:?}", input.len());
286+
bail!("unexpected rank: {:?}", ishape.len());
285287
};
286-
(&[m, k], vec![m, n], 0usize)
288+
(&[m, k], vec![m, n])
287289
}
288290
3 => {
289291
let &[b, m, k] = ishape else {
290-
bail!("unexpected rank: {:?}", input.len());
292+
bail!("unexpected rank: {:?}", ishape.len());
291293
};
292-
(&[m, k], vec![b, m, n], 1usize)
294+
(&[m, k], vec![b, m, n])
293295
}
294296
_ => {
295297
bail!("unexpected rank {:?}", ishape.len())
296298
}
297299
};
298300

299-
let mmm = op.mmm(*fdtypes.iter().next().unwrap(), Some(m), Some(k), Some(n)).unwrap();
301+
let input_offset = input.rank() - 2;
302+
let weight_offset = qweight.rank() - 2;
300303

304+
let mmm = op.mmm(data_type, Some(m), Some(k), Some(n)).unwrap();
301305
let (pack_a, pack_b) = &mmm.packings()[0];
302-
let cstore = unsafe { mmm.c_view(0 + offset, 1 + offset) };
303306

304-
let a = pack_a.prepare_tensor(&input, 1 + offset, 0 + offset)?;
305-
let b = pack_b.prepare_tensor(&qweight, 0 + offset, 1 + offset)?;
306-
unsafe {
307-
let out = Tensor::uninitialized::<f32>(out_shape.iter().as_slice().try_into()?)?;
307+
let cstore = unsafe { mmm.c_view(input_offset, 1 + input_offset) };
308+
309+
let a = pack_a.prepare_tensor(&input, 1 + input_offset, input_offset)?;
310+
let b = pack_b.prepare_tensor(&qweight, weight_offset, 1 + weight_offset)?;
311+
let last = unsafe {
312+
let out = Tensor::uninitialized::<f32>(out_shape.iter().as_slice())?;
308313
let non_linear = &[
309314
FusedSpec::AddMatMul {
310315
a: tract_linalg::mmm::AsInputValue::Owned(a),
@@ -315,8 +320,11 @@ impl EvalOp for VPTQGemm {
315320
];
316321
mmm.run(m, n, non_linear)?;
317322

318-
Ok(tvec!(out.into()))
319-
}
323+
out
324+
};
325+
// force down cast for now
326+
let last_cdt = last.cast_to_dt(input.datum_type())?.into_owned().into_tvalue();
327+
Ok(tvec!(last_cdt))
320328
}
321329
}
322330

0 commit comments

Comments
 (0)