Skip to content

Commit 4991383

Browse files
[PHI] add int4 weight only quant kernel, add int4 weight only permute kernel (#64094)
* Add int4 quantize kernel and permute kernel * Update weight_quantize_kernel_gpu_impl.h * don't reshape it version * update kernel * fix int4 quant kernel * Update weight_quantize_kernel_gpu_impl.h * fix conflicts * fix int4 per channel quant row pack error * fix int4 dequant launch kernel * remove printf * add int4 gpucpu check * Update test_weight_only_linear.py * Update weight_dequantize_kernel.cu * fix compile error * fix * fix ci * recommit * fix code --------- Co-authored-by: yuanlehome <[email protected]>
1 parent 347bad6 commit 4991383

File tree

4 files changed

+448
-42
lines changed

4 files changed

+448
-42
lines changed

paddle/phi/kernels/funcs/weight_dequant_functor.h

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,
187187

188188
int warp_id = threadIdx.x / 32, lane_id = threadIdx.x % 32;
189189
int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
190-
// Every two rows of the original weights are interleaved into a row with
191-
// stride of 64, so if each thread processes 16 elements(for int8, we can use
192-
// ldg.128 to load weights), then every group of four adjacent threads will
193-
// alternately process two different row weights for example every 128
190+
// Every 4 rows of the original weights are interleaved into a row with
191+
// stride of 32, so if each thread processes 16 elements(for int8, we can use
192+
// ldg.128 to load weights), then every group of two adjacent threads will
193+
// alternately process four different row weights for example every 128
194194
// consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave
195195
// layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before
196196
// interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1
@@ -383,6 +383,7 @@ void WeightDequantize(const Context& dev_ctx,
383383
k,
384384
group_size);
385385
} else if (algo == "weight_only_int4" && group_size == -1) {
386+
k *= 2;
386387
grid.x /= 2;
387388
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
388389
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),
@@ -391,6 +392,7 @@ void WeightDequantize(const Context& dev_ctx,
391392
n,
392393
k);
393394
} else if (algo == "weight_only_int4" && group_size > 0) {
395+
k *= 2;
394396
grid.x /= 2;
395397
int4_weight_only_dequant<DataType><<<grid, block, 0, stream>>>(
396398
reinterpret_cast<const uint8_t*>(x.data<int8_t>()),

paddle/phi/kernels/gpu/weight_quantize_kernel.cu

Lines changed: 22 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,40 @@ void WeightQuantizeKernel(const Context& dev_ctx,
5959
x.data<T>(),
6060
quanted_x.data<int8_t>(),
6161
scale->data<float>(),
62-
weight_shape);
62+
weight_shape,
63+
arch,
64+
algo);
6365
trans(dev_ctx, quanted_x, out, axis);
6466
} else if (algo == "weight_only_int8") {
6567
dev_ctx.template Alloc<T>(scale);
6668
weight_quant_gpu<T, Context>(dev_ctx,
6769
x.data<T>(),
6870
quanted_x.data<int8_t>(),
6971
scale->data<T>(),
70-
weight_shape);
72+
weight_shape,
73+
arch,
74+
algo);
7175
weight_permute_gpu<Context>(dev_ctx,
7276
quanted_x.data<int8_t>(),
7377
out->data<int8_t>(),
7478
weight_shape,
75-
arch);
79+
arch,
80+
algo);
7681
} else if (algo == "weight_only_int4") {
77-
PADDLE_FATAL(
78-
"Weight quant gpu kernel currently don't support weight_only_int4 "
79-
"algo, please use cpu version.");
82+
dev_ctx.template Alloc<T>(scale);
83+
weight_quant_gpu<T, Context>(dev_ctx,
84+
x.data<T>(),
85+
quanted_x.data<int8_t>(),
86+
scale->data<T>(),
87+
weight_shape,
88+
arch,
89+
algo);
90+
weight_permute_gpu<Context>(dev_ctx,
91+
quanted_x.data<int8_t>(),
92+
out->data<int8_t>(),
93+
weight_shape,
94+
arch,
95+
algo);
8096
} else {
8197
PADDLE_FATAL(
8298
"The algo must be in ['weight_only_int8', 'weight_only_int4', "

0 commit comments

Comments
 (0)