@@ -187,10 +187,10 @@ __global__ void int4_weight_only_dequant(const uint8_t* weight,
187187
188188 int warp_id = threadIdx.x / 32 , lane_id = threadIdx.x % 32 ;
189189 int tile_id = blockIdx.x * blockDim.x / 32 + warp_id;
190- // Every two rows of the original weights are interleaved into a row with
191- // stride of 64 , so if each thread processes 16 elements(for int8, we can use
192- // ldg.128 to load weights), then every group of four adjacent threads will
193- // alternately process two different row weights for example every 128
190+ // Every 4 rows of the original weights are interleaved into a row with
191+ // stride of 32, so if each thread processes 16 elements (for int8, we can use
192+ // ldg.128 to load weights), then every group of two adjacent threads will
193+ // alternately process four different row weights. For example, every 128
194194 // consecutive int8 elements [128*i, 128*(i+1)-1] of row N under interleave
195195 // layout, the first 64 are from [64*i, 64*(i+1)-1] of row 2N before
196196 // interleaving, and the last 64 are from [64*i, 64*(i+1)-1] of row 2N+1
@@ -383,6 +383,7 @@ void WeightDequantize(const Context& dev_ctx,
383383 k,
384384 group_size);
385385 } else if (algo == " weight_only_int4" && group_size == -1 ) {
386+ k *= 2 ;
386387 grid.x /= 2 ;
387388 int4_weight_only_dequant<DataType><<<grid, block, 0 , stream>>>(
388389 reinterpret_cast <const uint8_t *>(x.data <int8_t >()),
@@ -391,6 +392,7 @@ void WeightDequantize(const Context& dev_ctx,
391392 n,
392393 k);
393394 } else if (algo == " weight_only_int4" && group_size > 0 ) {
395+ k *= 2 ;
394396 grid.x /= 2 ;
395397 int4_weight_only_dequant<DataType><<<grid, block, 0 , stream>>>(
396398 reinterpret_cast <const uint8_t *>(x.data <int8_t >()),
0 commit comments